/[escript]/trunk/escript/py_src/pdetools.py
ViewVC logotype

Diff of /trunk/escript/py_src/pdetools.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 782 by bcumming, Tue Jul 18 00:47:47 2006 UTC revision 2683 by gross, Tue Sep 29 02:20:22 2009 UTC
# Line 1  Line 1 
1  # $Id$  
2    ########################################################
3    #
4    # Copyright (c) 2003-2009 by University of Queensland
5    # Earth Systems Science Computational Center (ESSCC)
6    # http://www.uq.edu.au/esscc
7    #
8    # Primary Business: Queensland, Australia
9    # Licensed under the Open Software License version 3.0
10    # http://www.opensource.org/licenses/osl-3.0.php
11    #
12    ########################################################
13    
14    __copyright__="""Copyright (c) 2003-2009 by University of Queensland
15    Earth Systems Science Computational Center (ESSCC)
16    http://www.uq.edu.au/esscc
17    Primary Business: Queensland, Australia"""
18    __license__="""Licensed under the Open Software License version 3.0
19    http://www.opensource.org/licenses/osl-3.0.php"""
20    __url__="https://launchpad.net/escript-finley"
21    
22  """  """
23  Provides some tools related to PDEs.  Provides some tools related to PDEs.
24    
25  Currently includes:  Currently includes:
26      - Projector - to project a discontinuous      - Projector - to project a discontinuous function onto a continuous function
27      - Locator - to trace values in data objects at a certain location      - Locator - to trace values in data objects at a certain location
28      - TimeIntegrationManager - to handel extraplotion in time      - TimeIntegrationManager - to handle extrapolation in time
29        - SaddlePointProblem - solver for Saddle point problems using the inexact uszawa scheme
30    
31  @var __author__: name of author  :var __author__: name of author
32  @var __copyright__: copyrights  :var __copyright__: copyrights
33  @var __license__: licence agreement  :var __license__: licence agreement
34  @var __url__: url entry point on documentation  :var __url__: url entry point on documentation
35  @var __version__: version  :var __version__: version
36  @var __date__: date of the version  :var __date__: date of the version
37  """  """
38    
39  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
 __copyright__="""  Copyright (c) 2006 by ACcESS MNRF  
                     http://www.access.edu.au  
                 Primary Business: Queensland, Australia"""  
 __license__="""Licensed under the Open Software License version 3.0  
              http://www.opensource.org/licenses/osl-3.0.php"""  
 __url__="http://www.iservo.edu.au/esys"  
 __version__="$Revision$"  
 __date__="$Date$"  
40    
41    
42  import escript  import escript
43  import linearPDEs  import linearPDEs
44  import numarray  import numpy
45  import util  import util
46    import math
47    
48  class TimeIntegrationManager:  class TimeIntegrationManager:
49    """    """
50    a simple mechanism to manage time dependend values.    A simple mechanism to manage time dependend values.
51    
52    typical usage is::    Typical usage is::
53    
54       dt=0.1 # time increment       dt=0.1 # time increment
55       tm=TimeIntegrationManager(inital_value,p=1)       tm=TimeIntegrationManager(inital_value,p=1)
# Line 46  class TimeIntegrationManager: Line 59  class TimeIntegrationManager:
59           tm.checkin(dt,v)           tm.checkin(dt,v)
60           t+=dt           t+=dt
61    
62    @note: currently only p=1 is supported.    :note: currently only p=1 is supported.
63    """    """
64    def __init__(self,*inital_values,**kwargs):    def __init__(self,*inital_values,**kwargs):
65       """       """
66       sets up the value manager where inital_value is the initial value and p is order used for extrapolation       Sets up the value manager where ``inital_values`` are the initial values
67         and p is the order used for extrapolation.
68       """       """
69       if kwargs.has_key("p"):       if kwargs.has_key("p"):
70              self.__p=kwargs["p"]              self.__p=kwargs["p"]
# Line 67  class TimeIntegrationManager: Line 81  class TimeIntegrationManager:
81    
82    def getTime(self):    def getTime(self):
83        return self.__t        return self.__t
84    
85    def getValue(self):    def getValue(self):
86        out=self.__v_mem[0]        out=self.__v_mem[0]
87        if len(out)==1:        if len(out)==1:
# Line 76  class TimeIntegrationManager: Line 91  class TimeIntegrationManager:
91    
92    def checkin(self,dt,*values):    def checkin(self,dt,*values):
93        """        """
94        adds new values to the manager. the p+1 last value get lost        Adds new values to the manager. The p+1 last values are lost.
95        """        """
96        o=min(self.__order+1,self.__p)        o=min(self.__order+1,self.__p)
97        self.__order=min(self.__order+1,self.__p)        self.__order=min(self.__order+1,self.__p)
# Line 93  class TimeIntegrationManager: Line 108  class TimeIntegrationManager:
108    
109    def extrapolate(self,dt):    def extrapolate(self,dt):
110        """        """
111        extrapolates to dt forward in time.        Extrapolates to ``dt`` forward in time.
112        """        """
113        if self.__order==0:        if self.__order==0:
114           out=self.__v_mem[0]           out=self.__v_mem[0]
# Line 108  class TimeIntegrationManager: Line 123  class TimeIntegrationManager:
123           return out[0]           return out[0]
124        else:        else:
125           return out           return out
126    
127    
128  class Projector:  class Projector:
129    """    """
130    The Projector is a factory which projects a discontiuous function onto a    The Projector is a factory which projects a discontinuous function onto a
131    continuous function on the a given domain.    continuous function on a given domain.
132    """    """
133    def __init__(self, domain, reduce = True, fast=True):    def __init__(self, domain, reduce=True, fast=True):
134      """      """
135      Create a continuous function space projector for a domain.      Creates a continuous function space projector for a domain.
136    
137      @param domain: Domain of the projection.      :param domain: Domain of the projection.
138      @param reduce: Flag to reduce projection order (default is True)      :param reduce: Flag to reduce projection order
139      @param fast: Flag to use a fast method based on matrix lumping (default is true)      :param fast: Flag to use a fast method based on matrix lumping
140      """      """
141      self.__pde = linearPDEs.LinearPDE(domain)      self.__pde = linearPDEs.LinearPDE(domain)
142      if fast:      if fast:
143        self.__pde.setSolverMethod(linearPDEs.LinearPDE.LUMPING)          self.__pde.getSolverOptions().setSolverMethod(linearPDEs.SolverOptions.LUMPING)
144      self.__pde.setSymmetryOn()      self.__pde.setSymmetryOn()
145      self.__pde.setReducedOrderTo(reduce)      self.__pde.setReducedOrderTo(reduce)
146      self.__pde.setValue(D = 1.)      self.__pde.setValue(D = 1.)
147      return      return
148      def getSolverOptions(self):
149    def __del__(self):      """
150      return      Returns the solver options of the PDE solver.
151        
152        :rtype: `linearPDEs.SolverOptions`
153        """
154    
155    def __call__(self, input_data):    def __call__(self, input_data):
156      """      """
157      Projects input_data onto a continuous function      Projects ``input_data`` onto a continuous function.
158    
159      @param input_data: The input_data to be projected.      :param input_data: the data to be projected
160      """      """
161      out=escript.Data(0.,input_data.getShape(),self.__pde.getFunctionSpaceForSolution())      out=escript.Data(0.,input_data.getShape(),self.__pde.getFunctionSpaceForSolution())
162        self.__pde.setValue(Y = escript.Data(), Y_reduced = escript.Data())
163      if input_data.getRank()==0:      if input_data.getRank()==0:
164          self.__pde.setValue(Y = input_data)          self.__pde.setValue(Y = input_data)
165          out=self.__pde.getSolution()          out=self.__pde.getSolution()
# Line 170  class Projector: Line 189  class Projector:
189    
190  class NoPDE:  class NoPDE:
191       """       """
192       solves the following problem for u:       Solves the following problem for u:
193    
194       M{kronecker[i,j]*D[j]*u[j]=Y[i]}       *kronecker[i,j]*D[j]*u[j]=Y[i]*
195    
196       with constraint       with constraint
197    
198       M{u[j]=r[j]}  where M{q[j]>0}       *u[j]=r[j]*  where *q[j]>0*
199    
200       where D, Y, r and q are given functions of rank 1.       where *D*, *Y*, *r* and *q* are given functions of rank 1.
201    
202       In the case of scalars this takes the form       In the case of scalars this takes the form
203    
204       M{D*u=Y}       *D*u=Y*
205    
206       with constraint       with constraint
207    
208       M{u=r}  where M{q>0}       *u=r* where *q>0*
209    
210       where D, Y, r and q are given scalar functions.       where *D*, *Y*, *r* and *q* are given scalar functions.
211    
212       The constraint is overwriting any other condition.       The constraint overwrites any other condition.
213    
214       @note: This class is similar to the L{linearPDEs.LinearPDE} class with A=B=C=X=0 but has the intention       :note: This class is similar to the `linearPDEs.LinearPDE` class with
215              that all input parameter are given in L{Solution} or L{ReducedSolution}. The whole              A=B=C=X=0 but has the intention that all input parameters are given
216              thing is a bit strange and I blame Robert.Woodcock@csiro.au for this.              in `Solution` or `ReducedSolution`.
217       """       """
218         # The whole thing is a bit strange and I blame Rob Woodcock (CSIRO) for
219         # this.
220       def __init__(self,domain,D=None,Y=None,q=None,r=None):       def __init__(self,domain,D=None,Y=None,q=None,r=None):
221           """           """
222           initialize the problem           Initializes the problem.
223    
224           @param domain: domain of the PDE.           :param domain: domain of the PDE
225           @type domain: L{Domain}           :type domain: `Domain`
226           @param D: coefficient of the solution.           :param D: coefficient of the solution
227           @type D: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
228           @param Y: right hand side           :param Y: right hand side
229           @type Y: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
230           @param q: location of constraints           :param q: location of constraints
231           @type q: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
232           @param r: value of solution at locations of constraints           :param r: value of solution at locations of constraints
233           @type r: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
234           """           """
235           self.__domain=domain           self.__domain=domain
236           self.__D=D           self.__D=D
# Line 218  class NoPDE: Line 239  class NoPDE:
239           self.__r=r           self.__r=r
240           self.__u=None           self.__u=None
241           self.__function_space=escript.Solution(self.__domain)           self.__function_space=escript.Solution(self.__domain)
242    
243       def setReducedOn(self):       def setReducedOn(self):
244           """           """
245           sets the L{FunctionSpace} of the solution to L{ReducedSolution}           Sets the `FunctionSpace` of the solution to `ReducedSolution`.
246           """           """
247           self.__function_space=escript.ReducedSolution(self.__domain)           self.__function_space=escript.ReducedSolution(self.__domain)
248           self.__u=None           self.__u=None
249    
250       def setReducedOff(self):       def setReducedOff(self):
251           """           """
252           sets the L{FunctionSpace} of the solution to L{Solution}           Sets the `FunctionSpace` of the solution to `Solution`.
253           """           """
254           self.__function_space=escript.Solution(self.__domain)           self.__function_space=escript.Solution(self.__domain)
255           self.__u=None           self.__u=None
256            
257       def setValue(self,D=None,Y=None,q=None,r=None):       def setValue(self,D=None,Y=None,q=None,r=None):
258           """           """
259           assigns values to the parameters.           Assigns values to the parameters.
260    
261           @param D: coefficient of the solution.           :param D: coefficient of the solution
262           @type D: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
263           @param Y: right hand side           :param Y: right hand side
264           @type Y: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
265           @param q: location of constraints           :param q: location of constraints
266           @type q: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
267           @param r: value of solution at locations of constraints           :param r: value of solution at locations of constraints
268           @type r: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
269           """           """
270           if not D==None:           if not D==None:
271              self.__D=D              self.__D=D
# Line 260  class NoPDE: Line 282  class NoPDE:
282    
283       def getSolution(self):       def getSolution(self):
284           """           """
285           returns the solution           Returns the solution.
286            
287           @return: the solution of the problem           :return: the solution of the problem
288           @rtype: L{Data} object in the L{FunctionSpace} L{Solution} or L{ReducedSolution}.           :rtype: `Data` object in the `FunctionSpace` `Solution` or
289                     `ReducedSolution`
290           """           """
291           if self.__u==None:           if self.__u==None:
292              if self.__D==None:              if self.__D==None:
# Line 280  class NoPDE: Line 303  class NoPDE:
303                  self.__u*=(1.-q)                  self.__u*=(1.-q)
304                  if not self.__r==None: self.__u+=q*self.__r                  if not self.__r==None: self.__u+=q*self.__r
305           return self.__u           return self.__u
306                
307  class Locator:  class Locator:
308       """       """
309       Locator provides access to the values of data objects at a given       Locator provides access to the values of data objects at a given spatial
310       spatial coordinate x.         coordinate x.
311        
312       In fact, a Locator object finds the sample in the set of samples of a       In fact, a Locator object finds the sample in the set of samples of a
313       given function space or domain where which is closest to the given       given function space or domain which is closest to the given point x.
      point x.  
314       """       """
315    
316       def __init__(self,where,x=numarray.zeros((3,))):       def __init__(self,where,x=numpy.zeros((3,))):
317         """         """
318         Initializes a Locator to access values in Data objects on the Doamin         Initializes a Locator to access values in Data objects on the Doamin
319         or FunctionSpace where for the sample point which         or FunctionSpace for the sample point which is closest to the given
320         closest to the given point x.         point x.
321    
322           :param where: function space
323           :type where: `escript.FunctionSpace`
324           :param x: location(s) of the Locator
325           :type x: ``numpy.ndarray`` or ``list`` of ``numpy.ndarray``
326         """         """
327         if isinstance(where,escript.FunctionSpace):         if isinstance(where,escript.FunctionSpace):
328            self.__function_space=where            self.__function_space=where
329         else:         else:
330            self.__function_space=escript.ContinuousFunction(where)            self.__function_space=escript.ContinuousFunction(where)
331         self.__id=util.length(self.__function_space.getX()-x[:self.__function_space.getDim()]).mindp()         iterative=False
332           if isinstance(x, list):
333               if len(x)==0:
334                  raise "ValueError", "At least one point must be given."
335               try:
336                 iter(x[0])
337                 iterative=True
338               except TypeError:
339                 iterative=False
340           if iterative:
341               self.__id=[]
342               for p in x:
343                  self.__id.append(util.length(self.__function_space.getX()-p[:self.__function_space.getDim()]).minGlobalDataPoint())
344           else:
345               self.__id=util.length(self.__function_space.getX()-x[:self.__function_space.getDim()]).minGlobalDataPoint()
346    
347       def __str__(self):       def __str__(self):
348         """         """
349         Returns the coordinates of the Locator as a string.         Returns the coordinates of the Locator as a string.
350         """         """
351         return "<Locator %s>"%str(self.getX())         x=self.getX()
352           if isinstance(x,list):
353              out="["
354              first=True
355              for xx in x:
356                if not first:
357                    out+=","
358                else:
359                    first=False
360                out+=str(xx)
361              out+="]>"
362           else:
363              out=str(x)
364           return out
365    
366         def getX(self):
367            """
368            Returns the exact coordinates of the Locator.
369            """
370            return self(self.getFunctionSpace().getX())
371    
372       def getFunctionSpace(self):       def getFunctionSpace(self):
373          """          """
374      Returns the function space of the Locator.          Returns the function space of the Locator.
375      """          """
376          return self.__function_space          return self.__function_space
377    
378       def getId(self):       def getId(self,item=None):
379          """          """
380      Returns the identifier of the location.          Returns the identifier of the location.
     """  
         return self.__id  
   
      def getX(self):  
381          """          """
382      Returns the exact coordinates of the Locator.          if item == None:
383      """             return self.__id
384          return self(self.getFunctionSpace().getX())          else:
385               if isinstance(self.__id,list):
386                  return self.__id[item]
387               else:
388                  return self.__id
389    
390    
391       def __call__(self,data):       def __call__(self,data):
392          """          """
393      Returns the value of data at the Locator of a Data object otherwise          Returns the value of data at the Locator of a Data object.
394      the object is returned.          """
     """  
395          return self.getValue(data)          return self.getValue(data)
396    
397       def getValue(self,data):       def getValue(self,data):
398          """          """
399      Returns the value of data at the Locator if data is a Data object          Returns the value of ``data`` at the Locator if ``data`` is a `Data`
400      otherwise the object is returned.          object otherwise the object is returned.
401      """          """
402          if isinstance(data,escript.Data):          if isinstance(data,escript.Data):
403             if data.getFunctionSpace()==self.getFunctionSpace():             dat=util.interpolate(data,self.getFunctionSpace())
404               out=data.convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1])             id=self.getId()
405               #out=data.convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1],self.getId()[2])             r=data.getRank()
406             else:             if isinstance(id,list):
407               out=data.interpolate(self.getFunctionSpace()).convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1])                 out=[]
408               #out=data.interpolate(self.getFunctionSpace()).convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1],self.getId()[2])                 for i in id:
409             if data.getRank()==0:                    o=numpy.array(dat.getTupleForGlobalDataPoint(*i))
410                return out[0]                    if data.getRank()==0:
411                         out.append(o[0])
412                      else:
413                         out.append(o)
414                   return out
415             else:             else:
416                return out               out=numpy.array(dat.getTupleForGlobalDataPoint(*id))
417                 if data.getRank()==0:
418                    return out[0]
419                 else:
420                    return out
421          else:          else:
422             return data             return data
423    
424  # vim: expandtab shiftwidth=4:  class SolverSchemeException(Exception):
425       """
426       This is a generic exception thrown by solvers.
427       """
428       pass
429    
430    class IndefinitePreconditioner(SolverSchemeException):
431       """
432       Exception thrown if the preconditioner is not positive definite.
433       """
434       pass
435    
436    class MaxIterReached(SolverSchemeException):
437       """
438       Exception thrown if the maximum number of iteration steps is reached.
439       """
440       pass
441    
442    class CorrectionFailed(SolverSchemeException):
443       """
444       Exception thrown if no convergence has been achieved in the solution
445       correction scheme.
446       """
447       pass
448    
449    class IterationBreakDown(SolverSchemeException):
450       """
451       Exception thrown if the iteration scheme encountered an incurable breakdown.
452       """
453       pass
454    
455    class NegativeNorm(SolverSchemeException):
456       """
457       Exception thrown if a norm calculation returns a negative norm.
458       """
459       pass
460    
461    def PCG(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100, initial_guess=True, verbose=False):
462       """
463       Solver for
464    
465       *Ax=b*
466    
467       with a symmetric and positive definite operator A (more details required!).
468       It uses the conjugate gradient method with preconditioner M providing an
469       approximation of A.
470    
471       The iteration is terminated if
472    
473       *|r| <= atol+rtol*|r0|*
474    
475       where *r0* is the initial residual and *|.|* is the energy norm. In fact
476    
477       *|r| = sqrt( bilinearform(Msolve(r),r))*
478    
479       For details on the preconditioned conjugate gradient method see the book:
480    
481       I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
482       T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
483       C. Romine, and H. van der Vorst}.
484    
485       :param r: initial residual *r=b-Ax*. ``r`` is altered.
486       :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
487       :param x: an initial guess for the solution
488       :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
489       :param Aprod: returns the value Ax
490       :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
491                    argument ``x``. The returned object needs to be of the same type
492                    like argument ``r``.
493       :param Msolve: solves Mx=r
494       :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
495                     argument ``r``. The returned object needs to be of the same
496                     type like argument ``x``.
497       :param bilinearform: inner product ``<x,r>``
498       :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
499                           type like argument ``x`` and ``r`` is. The returned value
500                           is a ``float``.
501       :param atol: absolute tolerance
502       :type atol: non-negative ``float``
503       :param rtol: relative tolerance
504       :type rtol: non-negative ``float``
505       :param iter_max: maximum number of iteration steps
506       :type iter_max: ``int``
507       :return: the solution approximation and the corresponding residual
508       :rtype: ``tuple``
509       :warning: ``r`` and ``x`` are altered.
510       """
511       iter=0
512       rhat=Msolve(r)
513       d = rhat
514       rhat_dot_r = bilinearform(rhat, r)
515       if rhat_dot_r<0: raise NegativeNorm,"negative norm."
516       norm_r0=math.sqrt(rhat_dot_r)
517       atol2=atol+rtol*norm_r0
518       if atol2<=0:
519          raise ValueError,"Non-positive tolarance."
520       atol2=max(atol2, 100. * util.EPSILON * norm_r0)
521    
522       if verbose: print "PCG: initial residual norm = %e (absolute tolerance = %e)"%(norm_r0, atol2)
523    
524    
525       while not math.sqrt(rhat_dot_r) <= atol2:
526           iter+=1
527           if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
528    
529           q=Aprod(d)
530           alpha = rhat_dot_r / bilinearform(d, q)
531           x += alpha * d
532           if isinstance(q,ArithmeticTuple):
533           r += q * (-alpha)      # Doing it the other way calls the float64.__mul__ not AT.__rmul__
534           else:
535               r += (-alpha) * q
536           rhat=Msolve(r)
537           rhat_dot_r_new = bilinearform(rhat, r)
538           beta = rhat_dot_r_new / rhat_dot_r
539           rhat+=beta * d
540           d=rhat
541    
542           rhat_dot_r = rhat_dot_r_new
543           if rhat_dot_r<0: raise NegativeNorm,"negative norm."
544           if verbose: print "PCG: iteration step %s: residual norm = %e"%(iter, math.sqrt(rhat_dot_r))
545       if verbose: print "PCG: tolerance reached after %s steps."%iter
546       return x,r,math.sqrt(rhat_dot_r)
547    
548    class Defect(object):
549        """
550        Defines a non-linear defect F(x) of a variable x.
551        """
552        def __init__(self):
553            """
554            Initializes defect.
555            """
556            self.setDerivativeIncrementLength()
557    
558        def bilinearform(self, x0, x1):
559            """
560            Returns the inner product of x0 and x1
561    
562            :param x0: value for x0
563            :param x1: value for x1
564            :return: the inner product of x0 and x1
565            :rtype: ``float``
566            """
567            return 0
568    
569        def norm(self,x):
570            """
571            Returns the norm of argument ``x``.
572    
573            :param x: a value
574            :return: norm of argument x
575            :rtype: ``float``
576            :note: by default ``sqrt(self.bilinearform(x,x)`` is returned.
577            """
578            s=self.bilinearform(x,x)
579            if s<0: raise NegativeNorm,"negative norm."
580            return math.sqrt(s)
581    
582        def eval(self,x):
583            """
584            Returns the value F of a given ``x``.
585    
586            :param x: value for which the defect ``F`` is evaluated
587            :return: value of the defect at ``x``
588            """
589            return 0
590    
591        def __call__(self,x):
592            return self.eval(x)
593    
594        def setDerivativeIncrementLength(self,inc=1000.*math.sqrt(util.EPSILON)):
595            """
596            Sets the relative length of the increment used to approximate the
597            derivative of the defect. The increment is inc*norm(x)/norm(v)*v in the
598            direction of v with x as a starting point.
599    
600            :param inc: relative increment length
601            :type inc: positive ``float``
602            """
603            if inc<=0: raise ValueError,"positive increment required."
604            self.__inc=inc
605    
606        def getDerivativeIncrementLength(self):
607            """
608            Returns the relative increment length used to approximate the
609            derivative of the defect.
610            :return: value of the defect at ``x``
611            :rtype: positive ``float``
612            """
613            return self.__inc
614    
615        def derivative(self, F0, x0, v, v_is_normalised=True):
616            """
617            Returns the directional derivative at ``x0`` in the direction of ``v``.
618    
619            :param F0: value of this defect at x0
620            :param x0: value at which derivative is calculated
621            :param v: direction
622            :param v_is_normalised: True to indicate that ``v`` is nomalized
623                                    (self.norm(v)=0)
624            :return: derivative of this defect at x0 in the direction of ``v``
625            :note: by default numerical evaluation (self.eval(x0+eps*v)-F0)/eps is
626                   used but this method maybe overwritten to use exact evaluation.
627            """
628            normx=self.norm(x0)
629            if normx>0:
630                 epsnew = self.getDerivativeIncrementLength() * normx
631            else:
632                 epsnew = self.getDerivativeIncrementLength()
633            if not v_is_normalised:
634                normv=self.norm(v)
635                if normv<=0:
636                   return F0*0
637                else:
638                   epsnew /= normv
639            F1=self.eval(x0 + epsnew * v)
640            return (F1-F0)/epsnew
641    
642    ######################################
643    def NewtonGMRES(defect, x, iter_max=100, sub_iter_max=20, atol=0,rtol=1.e-4, sub_tol_max=0.5, gamma=0.9, verbose=False):
644       """
645       Solves a non-linear problem *F(x)=0* for unknown *x* using the stopping
646       criterion:
647    
648       *norm(F(x) <= atol + rtol * norm(F(x0)*
649    
650       where *x0* is the initial guess.
651    
652       :param defect: object defining the function *F*. ``defect.norm`` defines the
653                      *norm* used in the stopping criterion.
654       :type defect: `Defect`
655       :param x: initial guess for the solution, ``x`` is altered.
656       :type x: any object type allowing basic operations such as
657                ``numpy.ndarray``, `Data`
658       :param iter_max: maximum number of iteration steps
659       :type iter_max: positive ``int``
660       :param sub_iter_max: maximum number of inner iteration steps
661       :type sub_iter_max: positive ``int``
662       :param atol: absolute tolerance for the solution
663       :type atol: positive ``float``
664       :param rtol: relative tolerance for the solution
665       :type rtol: positive ``float``
666       :param gamma: tolerance safety factor for inner iteration
667       :type gamma: positive ``float``, less than 1
668       :param sub_tol_max: upper bound for inner tolerance
669       :type sub_tol_max: positive ``float``, less than 1
670       :return: an approximation of the solution with the desired accuracy
671       :rtype: same type as the initial guess
672       """
673       lmaxit=iter_max
674       if atol<0: raise ValueError,"atol needs to be non-negative."
675       if rtol<0: raise ValueError,"rtol needs to be non-negative."
676       if rtol+atol<=0: raise ValueError,"rtol or atol needs to be non-negative."
677       if gamma<=0 or gamma>=1: raise ValueError,"tolerance safety factor for inner iteration (gamma =%s) needs to be positive and less than 1."%gamma
678       if sub_tol_max<=0 or sub_tol_max>=1: raise ValueError,"upper bound for inner tolerance for inner iteration (sub_tol_max =%s) needs to be positive and less than 1."%sub_tol_max
679    
680       F=defect(x)
681       fnrm=defect.norm(F)
682       stop_tol=atol + rtol*fnrm
683       sub_tol=sub_tol_max
684       if verbose: print "NewtonGMRES: initial residual = %e."%fnrm
685       if verbose: print "             tolerance = %e."%sub_tol
686       iter=1
687       #
688       # main iteration loop
689       #
690       while not fnrm<=stop_tol:
691                if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
692                #
693            #   adjust sub_tol_
694            #
695                if iter > 1:
696               rat=fnrm/fnrmo
697                   sub_tol_old=sub_tol
698               sub_tol=gamma*rat**2
699               if gamma*sub_tol_old**2 > .1: sub_tol=max(sub_tol,gamma*sub_tol_old**2)
700               sub_tol=max(min(sub_tol,sub_tol_max), .5*stop_tol/fnrm)
701            #
702            # calculate newton increment xc
703                #     if iter_max in __FDGMRES is reached MaxIterReached is thrown
704                #     if iter_restart -1 is returned as sub_iter
705                #     if  atol is reached sub_iter returns the numer of steps performed to get there
706                #
707                #
708                if verbose: print "             subiteration (GMRES) is called with relative tolerance %e."%sub_tol
709                try:
710                   xc, sub_iter=__FDGMRES(F, defect, x, sub_tol*fnrm, iter_max=iter_max-iter, iter_restart=sub_iter_max)
711                except MaxIterReached:
712                   raise MaxIterReached,"maximum number of %s steps reached."%iter_max
713                if sub_iter<0:
714                   iter+=sub_iter_max
715                else:
716                   iter+=sub_iter
717                # ====
718            x+=xc
719                F=defect(x)
720            iter+=1
721                fnrmo, fnrm=fnrm, defect.norm(F)
722                if verbose: print "             step %s: residual %e."%(iter,fnrm)
723       if verbose: print "NewtonGMRES: completed after %s steps."%iter
724       return x
725    
726    def __givapp(c,s,vin):
727        """
728        Applies a sequence of Givens rotations (c,s) recursively to the vector
729        ``vin``
730    
731        :warning: ``vin`` is altered.
732        """
733        vrot=vin
734        if isinstance(c,float):
735            vrot=[c*vrot[0]-s*vrot[1],s*vrot[0]+c*vrot[1]]
736        else:
737            for i in range(len(c)):
738                w1=c[i]*vrot[i]-s[i]*vrot[i+1]
739            w2=s[i]*vrot[i]+c[i]*vrot[i+1]
740                vrot[i]=w1
741                vrot[i+1]=w2
742        return vrot
743    
744    def __FDGMRES(F0, defect, x0, atol, iter_max=100, iter_restart=20):
745       h=numpy.zeros((iter_restart,iter_restart),numpy.float64)
746       c=numpy.zeros(iter_restart,numpy.float64)
747       s=numpy.zeros(iter_restart,numpy.float64)
748       g=numpy.zeros(iter_restart,numpy.float64)
749       v=[]
750    
751       rho=defect.norm(F0)
752       if rho<=0.: return x0*0
753    
754       v.append(-F0/rho)
755       g[0]=rho
756       iter=0
757       while rho > atol and iter<iter_restart-1:
758            if iter  >= iter_max:
759                raise MaxIterReached,"maximum number of %s steps reached."%iter_max
760    
761            p=defect.derivative(F0,x0,v[iter], v_is_normalised=True)
762            v.append(p)
763    
764            v_norm1=defect.norm(v[iter+1])
765    
766            # Modified Gram-Schmidt
767            for j in range(iter+1):
768                h[j,iter]=defect.bilinearform(v[j],v[iter+1])
769                v[iter+1]-=h[j,iter]*v[j]
770    
771            h[iter+1,iter]=defect.norm(v[iter+1])
772            v_norm2=h[iter+1,iter]
773    
774            # Reorthogonalize if needed
775            if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
776                for j in range(iter+1):
777                    hr=defect.bilinearform(v[j],v[iter+1])
778                    h[j,iter]=h[j,iter]+hr
779                    v[iter+1] -= hr*v[j]
780    
781                v_norm2=defect.norm(v[iter+1])
782                h[iter+1,iter]=v_norm2
783            #   watch out for happy breakdown
784            if not v_norm2 == 0:
785                v[iter+1]=v[iter+1]/h[iter+1,iter]
786    
787            #   Form and store the information for the new Givens rotation
788            if iter > 0 :
789                hhat=numpy.zeros(iter+1,numpy.float64)
790                for i in range(iter+1) : hhat[i]=h[i,iter]
791                hhat=__givapp(c[0:iter],s[0:iter],hhat);
792                for i in range(iter+1) : h[i,iter]=hhat[i]
793    
794            mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
795    
796            if mu!=0 :
797                c[iter]=h[iter,iter]/mu
798                s[iter]=-h[iter+1,iter]/mu
799                h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
800                h[iter+1,iter]=0.0
801                gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
802                g[iter]=gg[0]
803                g[iter+1]=gg[1]
804    
805            # Update the residual norm
806            rho=abs(g[iter+1])
807            iter+=1
808    
809       # At this point either iter > iter_max or rho < tol.
810       # It's time to compute x and leave.
811       if iter > 0 :
812         y=numpy.zeros(iter,numpy.float64)
813         y[iter-1] = g[iter-1] / h[iter-1,iter-1]
814         if iter > 1 :
815            i=iter-2
816            while i>=0 :
817              y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
818              i=i-1
819         xhat=v[iter-1]*y[iter-1]
820         for i in range(iter-1):
821        xhat += v[i]*y[i]
822       else :
823          xhat=v[0] * 0
824    
825       if iter<iter_restart-1:
826          stopped=iter
827       else:
828          stopped=-1
829    
830       return xhat,stopped
831    
832    def GMRES(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100, iter_restart=20, verbose=False,P_R=None):
833       """
834       Solver for
835    
836       *Ax=b*
837    
838       with a general operator A (more details required!).
839       It uses the generalized minimum residual method (GMRES).
840    
841       The iteration is terminated if
842    
843       *|r| <= atol+rtol*|r0|*
844    
845       where *r0* is the initial residual and *|.|* is the energy norm. In fact
846    
847       *|r| = sqrt( bilinearform(r,r))*
848    
849       :param r: initial residual *r=b-Ax*. ``r`` is altered.
850       :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
851       :param x: an initial guess for the solution
852       :type x: same like ``r``
853       :param Aprod: returns the value Ax
854       :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
855                    argument ``x``. The returned object needs to be of the same
856                    type like argument ``r``.
857       :param bilinearform: inner product ``<x,r>``
858       :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
859                           type like argument ``x`` and ``r``. The returned value is
860                           a ``float``.
861       :param atol: absolute tolerance
862       :type atol: non-negative ``float``
863       :param rtol: relative tolerance
864       :type rtol: non-negative ``float``
865       :param iter_max: maximum number of iteration steps
866       :type iter_max: ``int``
867       :param iter_restart: in order to save memory the orthogonalization process
868                            is terminated after ``iter_restart`` steps and the
869                            iteration is restarted.
870       :type iter_restart: ``int``
871       :return: the solution approximation and the corresponding residual
872       :rtype: ``tuple``
873       :warning: ``r`` and ``x`` are altered.
874       """
875       m=iter_restart
876       restarted=False
877       iter=0
878       if rtol>0:
879          r_dot_r = bilinearform(r, r)
880          if r_dot_r<0: raise NegativeNorm,"negative norm."
881          atol2=atol+rtol*math.sqrt(r_dot_r)
882          if verbose: print "GMRES: norm of right hand side = %e (absolute tolerance = %e)"%(math.sqrt(r_dot_r), atol2)
883       else:
884          atol2=atol
885          if verbose: print "GMRES: absolute tolerance = %e"%atol2
886       if atol2<=0:
887          raise ValueError,"Non-positive tolarance."
888    
889       while True:
890          if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached"%iter_max
891          if restarted:
892             r2 = r-Aprod(x-x2)
893          else:
894             r2=1*r
895          x2=x*1.
896          x,stopped=_GMRESm(r2, Aprod, x, bilinearform, atol2, iter_max=iter_max-iter, iter_restart=m, verbose=verbose,P_R=P_R)
897          iter+=iter_restart
898          if stopped: break
899          if verbose: print "GMRES: restart."
900          restarted=True
901       if verbose: print "GMRES: tolerance has been reached."
902       return x
903    
904    def _GMRESm(r, Aprod, x, bilinearform, atol, iter_max=100, iter_restart=20, verbose=False, P_R=None):
905       iter=0
906    
907       h=numpy.zeros((iter_restart+1,iter_restart),numpy.float64)
908       c=numpy.zeros(iter_restart,numpy.float64)
909       s=numpy.zeros(iter_restart,numpy.float64)
910       g=numpy.zeros(iter_restart+1,numpy.float64)
911       v=[]
912    
913       r_dot_r = bilinearform(r, r)
914       if r_dot_r<0: raise NegativeNorm,"negative norm."
915       rho=math.sqrt(r_dot_r)
916    
917       v.append(r/rho)
918       g[0]=rho
919    
920       if verbose: print "GMRES: initial residual %e (absolute tolerance = %e)"%(rho,atol)
921       while not (rho<=atol or iter==iter_restart):
922    
923        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
924    
925            if P_R!=None:
926                p=Aprod(P_R(v[iter]))
927            else:
928            p=Aprod(v[iter])
929        v.append(p)
930    
931        v_norm1=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
932    
933    # Modified Gram-Schmidt
934        for j in range(iter+1):
935          h[j,iter]=bilinearform(v[j],v[iter+1])
936          v[iter+1]-=h[j,iter]*v[j]
937    
938        h[iter+1,iter]=math.sqrt(bilinearform(v[iter+1],v[iter+1]))
939        v_norm2=h[iter+1,iter]
940    
941    # Reorthogonalize if needed
942        if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
943         for j in range(iter+1):
944            hr=bilinearform(v[j],v[iter+1])
945                h[j,iter]=h[j,iter]+hr
946                v[iter+1] -= hr*v[j]
947    
948         v_norm2=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
949         h[iter+1,iter]=v_norm2
950    
951    #   watch out for happy breakdown
952            if not v_norm2 == 0:
953             v[iter+1]=v[iter+1]/h[iter+1,iter]
954    
955    #   Form and store the information for the new Givens rotation
956        if iter > 0: h[:iter+1,iter]=__givapp(c[:iter],s[:iter],h[:iter+1,iter])
957        mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
958    
959        if mu!=0 :
960            c[iter]=h[iter,iter]/mu
961            s[iter]=-h[iter+1,iter]/mu
962            h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
963            h[iter+1,iter]=0.0
964                    gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
965                    g[iter]=gg[0]
966                    g[iter+1]=gg[1]
967    # Update the residual norm
968    
969            rho=abs(g[iter+1])
970            if verbose: print "GMRES: iteration step %s: residual %e"%(iter,rho)
971        iter+=1
972    
973    # At this point either iter > iter_max or rho < tol.
974    # It's time to compute x and leave.
975    
976       if verbose: print "GMRES: iteration stopped after %s step."%iter
977       if iter > 0 :
978         y=numpy.zeros(iter,numpy.float64)
979         y[iter-1] = g[iter-1] / h[iter-1,iter-1]
980         if iter > 1 :
981            i=iter-2
982            while i>=0 :
983              y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
984              i=i-1
985         xhat=v[iter-1]*y[iter-1]
986         for i in range(iter-1):
987        xhat += v[i]*y[i]
988       else:
989         xhat=v[0] * 0
990       if P_R!=None:
991          x += P_R(xhat)
992       else:
993          x += xhat
994       if iter<iter_restart-1:
995          stopped=True
996       else:
997          stopped=False
998    
999       return x,stopped
1000    
1001    def MINRES(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1002        """
1003        Solver for
1004    
1005        *Ax=b*
1006    
1007        with a symmetric and positive definite operator A (more details required!).
1008        It uses the minimum residual method (MINRES) with preconditioner M
1009        providing an approximation of A.
1010    
1011        The iteration is terminated if
1012    
1013        *|r| <= atol+rtol*|r0|*
1014    
1015        where *r0* is the initial residual and *|.|* is the energy norm. In fact
1016    
1017        *|r| = sqrt( bilinearform(Msolve(r),r))*
1018    
1019        For details on the preconditioned conjugate gradient method see the book:
1020    
1021        I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
1022        T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
1023        C. Romine, and H. van der Vorst}.
1024    
1025        :param r: initial residual *r=b-Ax*. ``r`` is altered.
1026        :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1027        :param x: an initial guess for the solution
1028        :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1029        :param Aprod: returns the value Ax
1030        :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
1031                     argument ``x``. The returned object needs to be of the same
1032                     type like argument ``r``.
1033        :param Msolve: solves Mx=r
1034        :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
1035                      argument ``r``. The returned object needs to be of the same
1036                      type like argument ``x``.
1037        :param bilinearform: inner product ``<x,r>``
1038        :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
1039                            type like argument ``x`` and ``r`` is. The returned value
1040                            is a ``float``.
1041        :param atol: absolute tolerance
1042        :type atol: non-negative ``float``
1043        :param rtol: relative tolerance
1044        :type rtol: non-negative ``float``
1045        :param iter_max: maximum number of iteration steps
1046        :type iter_max: ``int``
1047        :return: the solution approximation and the corresponding residual
1048        :rtype: ``tuple``
1049        :warning: ``r`` and ``x`` are altered.
1050        """
1051        #------------------------------------------------------------------
1052        # Set up y and v for the first Lanczos vector v1.
1053        # y  =  beta1 P' v1,  where  P = C**(-1).
1054        # v is really P' v1.
1055        #------------------------------------------------------------------
1056        r1    = r
1057        y = Msolve(r)
1058        beta1 = bilinearform(y,r)
1059    
1060        if beta1< 0: raise NegativeNorm,"negative norm."
1061    
1062        #  If r = 0 exactly, stop with x
1063        if beta1==0: return x
1064    
1065        if beta1> 0: beta1  = math.sqrt(beta1)
1066    
1067        #------------------------------------------------------------------
1068        # Initialize quantities.
1069        # ------------------------------------------------------------------
1070        iter   = 0
1071        Anorm = 0
1072        ynorm = 0
1073        oldb   = 0
1074        beta   = beta1
1075        dbar   = 0
1076        epsln  = 0
1077        phibar = beta1
1078        rhs1   = beta1
1079        rhs2   = 0
1080        rnorm  = phibar
1081        tnorm2 = 0
1082        ynorm2 = 0
1083        cs     = -1
1084        sn     = 0
1085        w      = r*0.
1086        w2     = r*0.
1087        r2     = r1
1088        eps    = 0.0001
1089    
1090        #---------------------------------------------------------------------
1091        # Main iteration loop.
1092        # --------------------------------------------------------------------
1093        while not rnorm<=atol+rtol*Anorm*ynorm:    #  checks ||r|| < (||A|| ||x||) * TOL
1094    
1095        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1096            iter    = iter  +  1
1097    
1098            #-----------------------------------------------------------------
1099            # Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
1100            # The general iteration is similar to the case k = 1 with v0 = 0:
1101            #
1102            #   p1      = Operator * v1  -  beta1 * v0,
1103            #   alpha1  = v1'p1,
1104            #   q2      = p2  -  alpha1 * v1,
1105            #   beta2^2 = q2'q2,
1106            #   v2      = (1/beta2) q2.
1107            #
1108            # Again, y = betak P vk,  where  P = C**(-1).
1109            #-----------------------------------------------------------------
1110            s = 1/beta                 # Normalize previous vector (in y).
1111            v = s*y                    # v = vk if P = I
1112    
1113            y      = Aprod(v)
1114    
1115            if iter >= 2:
1116              y = y - (beta/oldb)*r1
1117    
1118            alfa   = bilinearform(v,y)              # alphak
1119            y      += (- alfa/beta)*r2
1120            r1     = r2
1121            r2     = y
1122            y = Msolve(r2)
1123            oldb   = beta                           # oldb = betak
1124            beta   = bilinearform(y,r2)             # beta = betak+1^2
1125            if beta < 0: raise NegativeNorm,"negative norm."
1126    
1127            beta   = math.sqrt( beta )
1128            tnorm2 = tnorm2 + alfa*alfa + oldb*oldb + beta*beta
1129    
1130            if iter==1:                 # Initialize a few things.
1131              gmax   = abs( alfa )      # alpha1
1132              gmin   = gmax             # alpha1
1133    
1134            # Apply previous rotation Qk-1 to get
1135            #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
1136            #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].
1137    
1138            oldeps = epsln
1139            delta  = cs * dbar  +  sn * alfa  # delta1 = 0         deltak
1140            gbar   = sn * dbar  -  cs * alfa  # gbar 1 = alfa1     gbar k
1141            epsln  =               sn * beta  # epsln2 = 0         epslnk+1
1142            dbar   =            -  cs * beta  # dbar 2 = beta2     dbar k+1
1143    
1144            # Compute the next plane rotation Qk
1145    
1146            gamma  = math.sqrt(gbar*gbar+beta*beta)  # gammak
1147            gamma  = max(gamma,eps)
1148            cs     = gbar / gamma             # ck
1149            sn     = beta / gamma             # sk
1150            phi    = cs * phibar              # phik
1151            phibar = sn * phibar              # phibark+1
1152    
1153            # Update  x.
1154    
1155            denom = 1/gamma
1156            w1    = w2
1157            w2    = w
1158            w     = (v - oldeps*w1 - delta*w2) * denom
1159            x     +=  phi*w
1160    
1161            # Go round again.
1162    
1163            gmax   = max(gmax,gamma)
1164            gmin   = min(gmin,gamma)
1165            z      = rhs1 / gamma
1166            ynorm2 = z*z  +  ynorm2
1167            rhs1   = rhs2 -  delta*z
1168            rhs2   =      -  epsln*z
1169    
1170            # Estimate various norms and test for convergence.
1171    
1172            Anorm  = math.sqrt( tnorm2 )
1173            ynorm  = math.sqrt( ynorm2 )
1174    
1175            rnorm  = phibar
1176    
1177        return x
1178    
1179    def TFQMR(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1180      """
1181      Solver for
1182    
1183      *Ax=b*
1184    
1185      with a general operator A (more details required!).
1186      It uses the Transpose-Free Quasi-Minimal Residual method (TFQMR).
1187    
1188      The iteration is terminated if
1189    
1190      *|r| <= atol+rtol*|r0|*
1191    
1192      where *r0* is the initial residual and *|.|* is the energy norm. In fact
1193    
1194      *|r| = sqrt( bilinearform(r,r))*
1195    
1196      :param r: initial residual *r=b-Ax*. ``r`` is altered.
1197      :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1198      :param x: an initial guess for the solution
1199      :type x: same like ``r``
1200      :param Aprod: returns the value Ax
1201      :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
1202                   argument ``x``. The returned object needs to be of the same type
1203                   like argument ``r``.
1204      :param bilinearform: inner product ``<x,r>``
1205      :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
1206                          type like argument ``x`` and ``r``. The returned value is
1207                          a ``float``.
1208      :param atol: absolute tolerance
1209      :type atol: non-negative ``float``
1210      :param rtol: relative tolerance
1211      :type rtol: non-negative ``float``
1212      :param iter_max: maximum number of iteration steps
1213      :type iter_max: ``int``
1214      :rtype: ``tuple``
1215      :warning: ``r`` and ``x`` are altered.
1216      """
1217      u1=0
1218      u2=0
1219      y1=0
1220      y2=0
1221    
1222      w = r
1223      y1 = r
1224      iter = 0
1225      d = 0
1226      v = Aprod(y1)
1227      u1 = v
1228    
1229      theta = 0.0;
1230      eta = 0.0;
1231      rho=bilinearform(r,r)
1232      if rho < 0: raise NegativeNorm,"negative norm."
1233      tau = math.sqrt(rho)
1234      norm_r0=tau
1235      while tau>atol+rtol*norm_r0:
1236        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1237    
1238        sigma = bilinearform(r,v)
1239        if sigma == 0.0: raise IterationBreakDown,'TFQMR breakdown, sigma=0'
1240    
1241        alpha = rho / sigma
1242    
1243        for j in range(2):
1244    #
1245    #   Compute y2 and u2 only if you have to
1246    #
1247          if ( j == 1 ):
1248            y2 = y1 - alpha * v
1249            u2 = Aprod(y2)
1250    
1251          m = 2 * (iter+1) - 2 + (j+1)
1252          if j==0:
1253             w = w - alpha * u1
1254             d = y1 + ( theta * theta * eta / alpha ) * d
1255          if j==1:
1256             w = w - alpha * u2
1257             d = y2 + ( theta * theta * eta / alpha ) * d
1258    
1259          theta = math.sqrt(bilinearform(w,w))/ tau
1260          c = 1.0 / math.sqrt ( 1.0 + theta * theta )
1261          tau = tau * theta * c
1262          eta = c * c * alpha
1263          x = x + eta * d
1264    #
1265    #  Try to terminate the iteration at each pass through the loop
1266    #
1267        if rho == 0.0: raise IterationBreakDown,'TFQMR breakdown, rho=0'
1268    
1269        rhon = bilinearform(r,w)
1270        beta = rhon / rho;
1271        rho = rhon;
1272        y1 = w + beta * y2;
1273        u1 = Aprod(y1)
1274        v = u1 + beta * ( u2 + beta * v )
1275    
1276        iter += 1
1277    
1278      return x
1279    
1280    
1281    #############################################
1282    
1283    class ArithmeticTuple(object):
1284       """
1285       Tuple supporting inplace update x+=y and scaling x=a*y where ``x,y`` is an
1286       ArithmeticTuple and ``a`` is a float.
1287    
1288       Example of usage::
1289    
1290           from esys.escript import Data
1291           from numpy import array
1292           a=Data(...)
1293           b=array([1.,4.])
1294           x=ArithmeticTuple(a,b)
1295           y=5.*x
1296    
1297       """
1298       def __init__(self,*args):
1299           """
1300           Initializes object with elements ``args``.
1301    
1302           :param args: tuple of objects that support inplace add (x+=y) and
1303                        scaling (x=a*y)
1304           """
1305           self.__items=list(args)
1306    
1307       def __len__(self):
1308           """
1309           Returns the number of items.
1310    
1311           :return: number of items
1312           :rtype: ``int``
1313           """
1314           return len(self.__items)
1315    
1316       def __getitem__(self,index):
1317           """
1318           Returns item at specified position.
1319    
1320           :param index: index of item to be returned
1321           :type index: ``int``
1322           :return: item with index ``index``
1323           """
1324           return self.__items.__getitem__(index)
1325    
1326       def __mul__(self,other):
1327           """
1328           Scales by ``other`` from the right.
1329    
1330           :param other: scaling factor
1331           :type other: ``float``
1332           :return: itemwise self*other
1333           :rtype: `ArithmeticTuple`
1334           """
1335           out=[]
1336           try:
1337               l=len(other)
1338               if l!=len(self):
1339                   raise ValueError,"length of arguments don't match."
1340               for i in range(l): out.append(self[i]*other[i])
1341           except TypeError:
1342               for i in range(len(self)): out.append(self[i]*other)
1343           return ArithmeticTuple(*tuple(out))
1344    
1345       def __rmul__(self,other):
1346           """
1347           Scales by ``other`` from the left.
1348    
1349           :param other: scaling factor
1350           :type other: ``float``
1351           :return: itemwise other*self
1352           :rtype: `ArithmeticTuple`
1353           """
1354           out=[]
1355           try:
1356               l=len(other)
1357               if l!=len(self):
1358                   raise ValueError,"length of arguments don't match."
1359               for i in range(l): out.append(other[i]*self[i])
1360           except TypeError:
1361               for i in range(len(self)): out.append(other*self[i])
1362           return ArithmeticTuple(*tuple(out))
1363    
1364       def __div__(self,other):
1365           """
1366           Scales by (1/``other``) from the right.
1367    
1368           :param other: scaling factor
1369           :type other: ``float``
1370           :return: itemwise self/other
1371           :rtype: `ArithmeticTuple`
1372           """
1373           return self*(1/other)
1374    
1375       def __rdiv__(self,other):
1376           """
1377           Scales by (1/``other``) from the left.
1378    
1379           :param other: scaling factor
1380           :type other: ``float``
1381           :return: itemwise other/self
1382           :rtype: `ArithmeticTuple`
1383           """
1384           out=[]
1385           try:
1386               l=len(other)
1387               if l!=len(self):
1388                   raise ValueError,"length of arguments don't match."
1389               for i in range(l): out.append(other[i]/self[i])
1390           except TypeError:
1391               for i in range(len(self)): out.append(other/self[i])
1392           return ArithmeticTuple(*tuple(out))
1393    
1394       def __iadd__(self,other):
1395           """
1396           Inplace addition of ``other`` to self.
1397    
1398           :param other: increment
1399           :type other: ``ArithmeticTuple``
1400           """
1401           if len(self) != len(other):
1402               raise ValueError,"tuple lengths must match."
1403           for i in range(len(self)):
1404               self.__items[i]+=other[i]
1405           return self
1406    
1407       def __add__(self,other):
1408           """
1409           Adds ``other`` to self.
1410    
1411           :param other: increment
1412           :type other: ``ArithmeticTuple``
1413           """
1414           out=[]
1415           try:
1416               l=len(other)
1417               if l!=len(self):
1418                   raise ValueError,"length of arguments don't match."
1419               for i in range(l): out.append(self[i]+other[i])
1420           except TypeError:
1421               for i in range(len(self)): out.append(self[i]+other)
1422           return ArithmeticTuple(*tuple(out))
1423    
1424       def __sub__(self,other):
1425           """
1426           Subtracts ``other`` from self.
1427    
1428           :param other: decrement
1429           :type other: ``ArithmeticTuple``
1430           """
1431           out=[]
1432           try:
1433               l=len(other)
1434               if l!=len(self):
1435                   raise ValueError,"length of arguments don't match."
1436               for i in range(l): out.append(self[i]-other[i])
1437           except TypeError:
1438               for i in range(len(self)): out.append(self[i]-other)
1439           return ArithmeticTuple(*tuple(out))
1440    
1441       def __isub__(self,other):
1442           """
1443           Inplace subtraction of ``other`` from self.
1444    
1445           :param other: decrement
1446           :type other: ``ArithmeticTuple``
1447           """
1448           if len(self) != len(other):
1449               raise ValueError,"tuple length must match."
1450           for i in range(len(self)):
1451               self.__items[i]-=other[i]
1452           return self
1453    
1454       def __neg__(self):
1455           """
1456           Negates values.
1457           """
1458           out=[]
1459           for i in range(len(self)):
1460               out.append(-self[i])
1461           return ArithmeticTuple(*tuple(out))
1462    
1463    
1464    class HomogeneousSaddlePointProblem(object):
1465          """
1466          This class provides a framework for solving linear homogeneous saddle
1467          point problems of the form::
1468    
1469              *Av+B^*p=f*
1470              *Bv     =0*
1471    
1472          for the unknowns *v* and *p* and given operators *A* and *B* and
1473          given right hand side *f*. *B^** is the adjoint operator of *B*.
1474          """
1475          def __init__(self, adaptSubTolerance=True, **kwargs):
1476        """
1477        initializes the saddle point problem
1478        
1479        :param adaptSubTolerance: If True the tolerance for subproblem is set automatically.
1480        :type adaptSubTolerance: ``bool``
1481        """
1482            self.setTolerance()
1483            self.setAbsoluteTolerance()
1484        self.__adaptSubTolerance=adaptSubTolerance
1485          #=============================================================
1486          def initialize(self):
1487            """
1488            Initializes the problem (overwrite).
1489            """
1490            pass
1491    
1492          def inner_pBv(self,p,Bv):
1493             """
1494             Returns inner product of element p and Bv (overwrite).
1495    
1496             :param p: a pressure increment
1497             :param Bv: a residual
1498             :return: inner product of element p and Bv
1499             :rtype: ``float``
1500             :note: used if PCG is applied.
1501             """
1502             raise NotImplementedError,"no inner product for p and Bv implemented."
1503    
1504          def inner_p(self,p0,p1):
1505             """
1506             Returns inner product of p0 and p1 (overwrite).
1507    
1508             :param p0: a pressure
1509             :param p1: a pressure
1510             :return: inner product of p0 and p1
1511             :rtype: ``float``
1512             """
1513             raise NotImplementedError,"no inner product for p implemented."
1514      
1515          def norm_v(self,v):
1516             """
1517             Returns the norm of v (overwrite).
1518    
1519             :param v: a velovity
1520             :return: norm of v
1521             :rtype: non-negative ``float``
1522             """
1523             raise NotImplementedError,"no norm of v implemented."
1524          def getV(self, p, v0):
1525             """
1526             return the value for v for a given p (overwrite)
1527    
1528             :param p: a pressure
1529             :param v0: a initial guess for the value v to return.
1530             :return: v given as *v= A^{-1} (f-B^*p)*
1531             """
1532             raise NotImplementedError,"no v calculation implemented."
1533    
1534            
1535          def Bv(self,v):
1536            """
1537            Returns Bv (overwrite).
1538    
1539            :rtype: equal to the type of p
1540            :note: boundary conditions on p should be zero!
1541            """
1542            raise NotImplementedError, "no operator B implemented."
1543    
1544          def norm_Bv(self,Bv):
1545            """
1546            Returns the norm of Bv (overwrite).
1547    
1548            :rtype: equal to the type of p
1549            :note: boundary conditions on p should be zero!
1550            """
1551            raise NotImplementedError, "no norm of Bv implemented."
1552    
1553          def solve_AinvBt(self,p):
1554             """
1555             Solves *Av=B^*p* with accuracy `self.getSubProblemTolerance()`
1556             (overwrite).
1557    
1558             :param p: a pressure increment
1559             :return: the solution of *Av=B^*p*
1560             :note: boundary conditions on v should be zero!
1561             """
1562             raise NotImplementedError,"no operator A implemented."
1563    
1564          def solve_prec(self,Bv):
1565             """
1566             Provides a preconditioner for *(BA^{-1}B^ * )* applied to Bv with accuracy
1567             `self.getSubProblemTolerance()` (overwrite).
1568    
1569             :rtype: equal to the type of p
1570             :note: boundary conditions on p should be zero!
1571             """
1572             raise NotImplementedError,"no preconditioner for Schur complement implemented."
1573          def setSubProblemTolerance(self):
1574             """
1575         Updates the tolerance for subproblems
1576         :note: method is typically the method is overwritten.
1577             """
1578             pass
1579          #=============================================================
1580          def __Aprod_PCG(self,p):
1581              dv=self.solve_AinvBt(p)
1582              return ArithmeticTuple(dv,self.Bv(dv))
1583    
1584          def __inner_PCG(self,p,r):
1585             return self.inner_pBv(p,r[1])
1586    
1587          def __Msolve_PCG(self,r):
1588              return self.solve_prec(r[1])
1589          #=============================================================
1590          def __Aprod_GMRES(self,p):
1591              return self.solve_prec(self.Bv(self.solve_AinvBt(p)))
1592          def __inner_GMRES(self,p0,p1):
1593             return self.inner_p(p0,p1)
1594    
1595          #=============================================================
1596          def norm_p(self,p):
1597              """
1598              calculates the norm of ``p``
1599              
1600              :param p: a pressure
1601              :return: the norm of ``p`` using the inner product for pressure
1602              :rtype: ``float``
1603              """
1604              f=self.inner_p(p,p)
1605              if f<0: raise ValueError,"negative pressure norm."
1606              return math.sqrt(f)
1607          def adaptSubTolerance(self):
1608          """
1609          Returns True if tolerance adaption for subproblem is choosen.
1610          """
1611              return self.__adaptSubTolerance
1612          
1613          def solve(self,v,p,max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
1614             """
1615             Solves the saddle point problem using initial guesses v and p.
1616    
1617             :param v: initial guess for velocity
1618             :param p: initial guess for pressure
1619             :type v: `Data`
1620             :type p: `Data`
1621             :param usePCG: indicates the usage of the PCG rather than GMRES scheme.
1622             :param max_iter: maximum number of iteration steps per correction
1623                              attempt
1624             :param verbose: if True, shows information on the progress of the
1625                             saddlepoint problem solver.
1626             :param iter_restart: restart the iteration after ``iter_restart`` steps
1627                                  (only used if useUzaw=False)
1628             :type usePCG: ``bool``
1629             :type max_iter: ``int``
1630             :type verbose: ``bool``
1631             :type iter_restart: ``int``
1632             :rtype: ``tuple`` of `Data` objects
1633             """
1634             self.verbose=verbose
1635             rtol=self.getTolerance()
1636             atol=self.getAbsoluteTolerance()
1637         if self.adaptSubTolerance(): self.setSubProblemTolerance()
1638             correction_step=0
1639             converged=False
1640             while not converged:
1641                  # calculate velocity for current pressure:
1642                  v=self.getV(p,v)
1643                  Bv=self.Bv(v)
1644                  norm_v=self.norm_v(v)
1645                  norm_Bv=self.norm_Bv(Bv)
1646                  ATOL=norm_v*rtol+atol
1647                  if self.verbose: print "HomogeneousSaddlePointProblem: norm v= %e, norm_Bv= %e, tolerance = %e."%(norm_v, norm_Bv,ATOL)
1648                  if not ATOL>0: raise ValueError,"overall absolute tolerance needs to be positive."
1649                  if norm_Bv <= ATOL:
1650                     converged=True
1651                  else:
1652                     correction_step+=1
1653                     if correction_step>max_correction_steps:
1654                          raise CorrectionFailed,"Given up after %d correction steps."%correction_step
1655                     dp=self.solve_prec(Bv)
1656                     if usePCG:
1657                       norm2=self.inner_pBv(dp,Bv)
1658                       if norm2<0: raise ValueError,"negative PCG norm."
1659                       norm2=math.sqrt(norm2)
1660                     else:
1661                       norm2=self.norm_p(dp)
1662                     ATOL_ITER=ATOL/norm_Bv*norm2*0.5
1663                     if self.verbose: print "HomogeneousSaddlePointProblem: tolerance for solver: %e"%ATOL_ITER
1664                     if usePCG:
1665                           p,v0,a_norm=PCG(ArithmeticTuple(v,Bv),self.__Aprod_PCG,p,self.__Msolve_PCG,self.__inner_PCG,atol=ATOL_ITER, rtol=0.,iter_max=max_iter, verbose=self.verbose)
1666                     else:
1667                           p=GMRES(dp,self.__Aprod_GMRES, p, self.__inner_GMRES,atol=ATOL_ITER, rtol=0.,iter_max=max_iter, iter_restart=iter_restart, verbose=self.verbose)
1668             if self.verbose: print "HomogeneousSaddlePointProblem: tolerance reached."
1669         return v,p
1670    
1671          #========================================================================
1672          def setTolerance(self,tolerance=1.e-4):
1673             """
1674             Sets the relative tolerance for (v,p).
1675    
1676             :param tolerance: tolerance to be used
1677             :type tolerance: non-negative ``float``
1678             """
1679             if tolerance<0:
1680                 raise ValueError,"tolerance must be positive."
1681             self.__rtol=tolerance
1682    
1683          def getTolerance(self):
1684             """
1685             Returns the relative tolerance.
1686    
1687             :return: relative tolerance
1688             :rtype: ``float``
1689             """
1690             return self.__rtol
1691    
1692          def setAbsoluteTolerance(self,tolerance=0.):
1693             """
1694             Sets the absolute tolerance.
1695    
1696             :param tolerance: tolerance to be used
1697             :type tolerance: non-negative ``float``
1698             """
1699             if tolerance<0:
1700                 raise ValueError,"tolerance must be non-negative."
1701             self.__atol=tolerance
1702    
1703          def getAbsoluteTolerance(self):
1704             """
1705             Returns the absolute tolerance.
1706    
1707             :return: absolute tolerance
1708             :rtype: ``float``
1709             """
1710             return self.__atol
1711    
1712          def getSubProblemTolerance(self):
1713             """
1714             Sets the relative tolerance to solve the subproblem(s).
1715             """
1716             return max(200.*util.EPSILON,self.getTolerance()**2)
1717    
1718    def MaskFromBoundaryTag(domain,*tags):
1719       """
1720       Creates a mask on the Solution(domain) function space where the value is
1721       one for samples that touch the boundary tagged by tags.
1722    
1723       Usage: m=MaskFromBoundaryTag(domain, "left", "right")
1724    
1725       :param domain: domain to be used
1726       :type domain: `escript.Domain`
1727       :param tags: boundary tags
1728       :type tags: ``str``
1729       :return: a mask which marks samples that are touching the boundary tagged
1730                by any of the given tags
1731       :rtype: `escript.Data` of rank 0
1732       """
1733       pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1734       d=escript.Scalar(0.,escript.FunctionOnBoundary(domain))
1735       for t in tags: d.setTaggedValue(t,1.)
1736       pde.setValue(y=d)
1737       return util.whereNonZero(pde.getRightHandSide())
1738    
1739    def MaskFromTag(domain,*tags):
1740       """
1741       Creates a mask on the Solution(domain) function space where the value is
1742       one for samples that touch regions tagged by tags.
1743    
1744       Usage: m=MaskFromTag(domain, "ham")
1745    
1746       :param domain: domain to be used
1747       :type domain: `escript.Domain`
1748       :param tags: boundary tags
1749       :type tags: ``str``
1750       :return: a mask which marks samples that are touching the boundary tagged
1751                by any of the given tags
1752       :rtype: `escript.Data` of rank 0
1753       """
1754       pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1755       d=escript.Scalar(0.,escript.Function(domain))
1756       for t in tags: d.setTaggedValue(t,1.)
1757       pde.setValue(Y=d)
1758       return util.whereNonZero(pde.getRightHandSide())
1759    
1760    

Legend:
Removed from v.782  
changed lines
  Added in v.2683

  ViewVC Help
Powered by ViewVC 1.1.26