/[escript]/trunk/escript/py_src/pdetools.py
ViewVC logotype

Diff of /trunk/escript/py_src/pdetools.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 782 by bcumming, Tue Jul 18 00:47:47 2006 UTC revision 2517 by jfenwick, Fri Jul 3 05:27:45 2009 UTC
# Line 1  Line 1 
1  # $Id$  
2    ########################################################
3    #
4    # Copyright (c) 2003-2008 by University of Queensland
5    # Earth Systems Science Computational Center (ESSCC)
6    # http://www.uq.edu.au/esscc
7    #
8    # Primary Business: Queensland, Australia
9    # Licensed under the Open Software License version 3.0
10    # http://www.opensource.org/licenses/osl-3.0.php
11    #
12    ########################################################
13    
14    __copyright__="""Copyright (c) 2003-2008 by University of Queensland
15    Earth Systems Science Computational Center (ESSCC)
16    http://www.uq.edu.au/esscc
17    Primary Business: Queensland, Australia"""
18    __license__="""Licensed under the Open Software License version 3.0
19    http://www.opensource.org/licenses/osl-3.0.php"""
20    __url__="https://launchpad.net/escript-finley"
21    
22  """  """
23  Provides some tools related to PDEs.  Provides some tools related to PDEs.
24    
25  Currently includes:  Currently includes:
26      - Projector - to project a discontinuous      - Projector - to project a discontinuous function onto a continuous function
27      - Locator - to trace values in data objects at a certain location      - Locator - to trace values in data objects at a certain location
28      - TimeIntegrationManager - to handel extraplotion in time      - TimeIntegrationManager - to handle extrapolation in time
29        - SaddlePointProblem - solver for Saddle point problems using the inexact uszawa scheme
30    
31  @var __author__: name of author  @var __author__: name of author
32  @var __copyright__: copyrights  @var __copyright__: copyrights
# Line 17  Currently includes: Line 37  Currently includes:
37  """  """
38    
39  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
 __copyright__="""  Copyright (c) 2006 by ACcESS MNRF  
                     http://www.access.edu.au  
                 Primary Business: Queensland, Australia"""  
 __license__="""Licensed under the Open Software License version 3.0  
              http://www.opensource.org/licenses/osl-3.0.php"""  
 __url__="http://www.iservo.edu.au/esys"  
 __version__="$Revision$"  
 __date__="$Date$"  
40    
41    
42  import escript  import escript
43  import linearPDEs  import linearPDEs
44  import numarray  import numpy
45  import util  import util
46    import math
47    
48    ##### Added by Artak
49    # from Numeric import zeros,Int,Float64
50    ###################################
51    
52    
53  class TimeIntegrationManager:  class TimeIntegrationManager:
54    """    """
55    a simple mechanism to manage time dependend values.    A simple mechanism to manage time dependend values.
56    
57    typical usage is::    Typical usage is::
58    
59       dt=0.1 # time increment       dt=0.1 # time increment
60       tm=TimeIntegrationManager(inital_value,p=1)       tm=TimeIntegrationManager(inital_value,p=1)
# Line 50  class TimeIntegrationManager: Line 68  class TimeIntegrationManager:
68    """    """
69    def __init__(self,*inital_values,**kwargs):    def __init__(self,*inital_values,**kwargs):
70       """       """
71       sets up the value manager where inital_value is the initial value and p is order used for extrapolation       Sets up the value manager where C{inital_values} are the initial values
72         and p is the order used for extrapolation.
73       """       """
74       if kwargs.has_key("p"):       if kwargs.has_key("p"):
75              self.__p=kwargs["p"]              self.__p=kwargs["p"]
# Line 67  class TimeIntegrationManager: Line 86  class TimeIntegrationManager:
86    
87    def getTime(self):    def getTime(self):
88        return self.__t        return self.__t
89    
90    def getValue(self):    def getValue(self):
91        out=self.__v_mem[0]        out=self.__v_mem[0]
92        if len(out)==1:        if len(out)==1:
# Line 76  class TimeIntegrationManager: Line 96  class TimeIntegrationManager:
96    
97    def checkin(self,dt,*values):    def checkin(self,dt,*values):
98        """        """
99        adds new values to the manager. the p+1 last value get lost        Adds new values to the manager. The p+1 last values are lost.
100        """        """
101        o=min(self.__order+1,self.__p)        o=min(self.__order+1,self.__p)
102        self.__order=min(self.__order+1,self.__p)        self.__order=min(self.__order+1,self.__p)
# Line 93  class TimeIntegrationManager: Line 113  class TimeIntegrationManager:
113    
114    def extrapolate(self,dt):    def extrapolate(self,dt):
115        """        """
116        extrapolates to dt forward in time.        Extrapolates to C{dt} forward in time.
117        """        """
118        if self.__order==0:        if self.__order==0:
119           out=self.__v_mem[0]           out=self.__v_mem[0]
# Line 108  class TimeIntegrationManager: Line 128  class TimeIntegrationManager:
128           return out[0]           return out[0]
129        else:        else:
130           return out           return out
131    
132    
133  class Projector:  class Projector:
134    """    """
135    The Projector is a factory which projects a discontiuous function onto a    The Projector is a factory which projects a discontinuous function onto a
136    continuous function on the a given domain.    continuous function on a given domain.
137    """    """
138    def __init__(self, domain, reduce = True, fast=True):    def __init__(self, domain, reduce=True, fast=True):
139      """      """
140      Create a continuous function space projector for a domain.      Creates a continuous function space projector for a domain.
141    
142      @param domain: Domain of the projection.      @param domain: Domain of the projection.
143      @param reduce: Flag to reduce projection order (default is True)      @param reduce: Flag to reduce projection order
144      @param fast: Flag to use a fast method based on matrix lumping (default is true)      @param fast: Flag to use a fast method based on matrix lumping
145      """      """
146      self.__pde = linearPDEs.LinearPDE(domain)      self.__pde = linearPDEs.LinearPDE(domain)
147      if fast:      if fast:
148        self.__pde.setSolverMethod(linearPDEs.LinearPDE.LUMPING)          self.__pde.getSolverOptions().setSolverMethod(linearPDEs.SolverOptions.LUMPING)
149      self.__pde.setSymmetryOn()      self.__pde.setSymmetryOn()
150      self.__pde.setReducedOrderTo(reduce)      self.__pde.setReducedOrderTo(reduce)
151      self.__pde.setValue(D = 1.)      self.__pde.setValue(D = 1.)
152      return      return
153      def getSolverOptions(self):
154    def __del__(self):      """
155      return      Returns the solver options of the PDE solver.
156        
157        @rtype: L{linearPDEs.SolverOptions}
158        """
159    
160    def __call__(self, input_data):    def __call__(self, input_data):
161      """      """
162      Projects input_data onto a continuous function      Projects C{input_data} onto a continuous function.
163    
164      @param input_data: The input_data to be projected.      @param input_data: the data to be projected
165      """      """
166      out=escript.Data(0.,input_data.getShape(),self.__pde.getFunctionSpaceForSolution())      out=escript.Data(0.,input_data.getShape(),self.__pde.getFunctionSpaceForSolution())
167        self.__pde.setValue(Y = escript.Data(), Y_reduced = escript.Data())
168      if input_data.getRank()==0:      if input_data.getRank()==0:
169          self.__pde.setValue(Y = input_data)          self.__pde.setValue(Y = input_data)
170          out=self.__pde.getSolution()          out=self.__pde.getSolution()
# Line 170  class Projector: Line 194  class Projector:
194    
195  class NoPDE:  class NoPDE:
196       """       """
197       solves the following problem for u:       Solves the following problem for u:
198    
199       M{kronecker[i,j]*D[j]*u[j]=Y[i]}       M{kronecker[i,j]*D[j]*u[j]=Y[i]}
200    
201       with constraint       with constraint
202    
203       M{u[j]=r[j]}  where M{q[j]>0}       M{u[j]=r[j]}  where M{q[j]>0}
204    
205       where D, Y, r and q are given functions of rank 1.       where M{D}, M{Y}, M{r} and M{q} are given functions of rank 1.
206    
207       In the case of scalars this takes the form       In the case of scalars this takes the form
208    
209       M{D*u=Y}       M{D*u=Y}
210    
211       with constraint       with constraint
212    
213       M{u=r}  where M{q>0}       M{u=r} where M{q>0}
214    
215       where D, Y, r and q are given scalar functions.       where M{D}, M{Y}, M{r} and M{q} are given scalar functions.
216    
217       The constraint is overwriting any other condition.       The constraint overwrites any other condition.
218    
219       @note: This class is similar to the L{linearPDEs.LinearPDE} class with A=B=C=X=0 but has the intention       @note: This class is similar to the L{linearPDEs.LinearPDE} class with
220              that all input parameter are given in L{Solution} or L{ReducedSolution}. The whole              A=B=C=X=0 but has the intention that all input parameters are given
221              thing is a bit strange and I blame Robert.Woodcock@csiro.au for this.              in L{Solution} or L{ReducedSolution}.
222       """       """
223         # The whole thing is a bit strange and I blame Rob Woodcock (CSIRO) for
224         # this.
225       def __init__(self,domain,D=None,Y=None,q=None,r=None):       def __init__(self,domain,D=None,Y=None,q=None,r=None):
226           """           """
227           initialize the problem           Initializes the problem.
228    
229           @param domain: domain of the PDE.           @param domain: domain of the PDE
230           @type domain: L{Domain}           @type domain: L{Domain}
231           @param D: coefficient of the solution.           @param D: coefficient of the solution
232           @type D: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type D: C{float}, C{int}, C{numpy.ndarray}, L{Data}
233           @param Y: right hand side           @param Y: right hand side
234           @type Y: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type Y: C{float}, C{int}, C{numpy.ndarray}, L{Data}
235           @param q: location of constraints           @param q: location of constraints
236           @type q: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type q: C{float}, C{int}, C{numpy.ndarray}, L{Data}
237           @param r: value of solution at locations of constraints           @param r: value of solution at locations of constraints
238           @type r: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type r: C{float}, C{int}, C{numpy.ndarray}, L{Data}
239           """           """
240           self.__domain=domain           self.__domain=domain
241           self.__D=D           self.__D=D
# Line 218  class NoPDE: Line 244  class NoPDE:
244           self.__r=r           self.__r=r
245           self.__u=None           self.__u=None
246           self.__function_space=escript.Solution(self.__domain)           self.__function_space=escript.Solution(self.__domain)
247    
248       def setReducedOn(self):       def setReducedOn(self):
249           """           """
250           sets the L{FunctionSpace} of the solution to L{ReducedSolution}           Sets the L{FunctionSpace} of the solution to L{ReducedSolution}.
251           """           """
252           self.__function_space=escript.ReducedSolution(self.__domain)           self.__function_space=escript.ReducedSolution(self.__domain)
253           self.__u=None           self.__u=None
254    
255       def setReducedOff(self):       def setReducedOff(self):
256           """           """
257           sets the L{FunctionSpace} of the solution to L{Solution}           Sets the L{FunctionSpace} of the solution to L{Solution}.
258           """           """
259           self.__function_space=escript.Solution(self.__domain)           self.__function_space=escript.Solution(self.__domain)
260           self.__u=None           self.__u=None
261            
262       def setValue(self,D=None,Y=None,q=None,r=None):       def setValue(self,D=None,Y=None,q=None,r=None):
263           """           """
264           assigns values to the parameters.           Assigns values to the parameters.
265    
266           @param D: coefficient of the solution.           @param D: coefficient of the solution
267           @type D: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type D: C{float}, C{int}, C{numpy.ndarray}, L{Data}
268           @param Y: right hand side           @param Y: right hand side
269           @type Y: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type Y: C{float}, C{int}, C{numpy.ndarray}, L{Data}
270           @param q: location of constraints           @param q: location of constraints
271           @type q: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type q: C{float}, C{int}, C{numpy.ndarray}, L{Data}
272           @param r: value of solution at locations of constraints           @param r: value of solution at locations of constraints
273           @type r: C{float}, C{int}, L{numarray.NumArray}, L{Data}           @type r: C{float}, C{int}, C{numpy.ndarray}, L{Data}
274           """           """
275           if not D==None:           if not D==None:
276              self.__D=D              self.__D=D
# Line 260  class NoPDE: Line 287  class NoPDE:
287    
288       def getSolution(self):       def getSolution(self):
289           """           """
290           returns the solution           Returns the solution.
291            
292           @return: the solution of the problem           @return: the solution of the problem
293           @rtype: L{Data} object in the L{FunctionSpace} L{Solution} or L{ReducedSolution}.           @rtype: L{Data} object in the L{FunctionSpace} L{Solution} or
294                     L{ReducedSolution}
295           """           """
296           if self.__u==None:           if self.__u==None:
297              if self.__D==None:              if self.__D==None:
# Line 280  class NoPDE: Line 308  class NoPDE:
308                  self.__u*=(1.-q)                  self.__u*=(1.-q)
309                  if not self.__r==None: self.__u+=q*self.__r                  if not self.__r==None: self.__u+=q*self.__r
310           return self.__u           return self.__u
311                
312  class Locator:  class Locator:
313       """       """
314       Locator provides access to the values of data objects at a given       Locator provides access to the values of data objects at a given spatial
315       spatial coordinate x.         coordinate x.
316        
317       In fact, a Locator object finds the sample in the set of samples of a       In fact, a Locator object finds the sample in the set of samples of a
318       given function space or domain where which is closest to the given       given function space or domain which is closest to the given point x.
      point x.  
319       """       """
320    
321       def __init__(self,where,x=numarray.zeros((3,))):       def __init__(self,where,x=numpy.zeros((3,))):
322         """         """
323         Initializes a Locator to access values in Data objects on the Doamin         Initializes a Locator to access values in Data objects on the Doamin
324         or FunctionSpace where for the sample point which         or FunctionSpace for the sample point which is closest to the given
325         closest to the given point x.         point x.
326    
327           @param where: function space
328           @type where: L{escript.FunctionSpace}
329           @param x: location(s) of the Locator
330           @type x: C{numpy.ndarray} or C{list} of C{numpy.ndarray}
331         """         """
332         if isinstance(where,escript.FunctionSpace):         if isinstance(where,escript.FunctionSpace):
333            self.__function_space=where            self.__function_space=where
334         else:         else:
335            self.__function_space=escript.ContinuousFunction(where)            self.__function_space=escript.ContinuousFunction(where)
336         self.__id=util.length(self.__function_space.getX()-x[:self.__function_space.getDim()]).mindp()         iterative=False
337           if isinstance(x, list):
338               if len(x)==0:
339                  raise "ValueError", "At least one point must be given."
340               try:
341                 iter(x[0])
342                 iterative=True
343               except TypeError:
344                 iterative=False
345           if iterative:
346               self.__id=[]
347               for p in x:
348                  self.__id.append(util.length(self.__function_space.getX()-p[:self.__function_space.getDim()]).minGlobalDataPoint())
349           else:
350               self.__id=util.length(self.__function_space.getX()-x[:self.__function_space.getDim()]).minGlobalDataPoint()
351    
352       def __str__(self):       def __str__(self):
353         """         """
354         Returns the coordinates of the Locator as a string.         Returns the coordinates of the Locator as a string.
355         """         """
356         return "<Locator %s>"%str(self.getX())         x=self.getX()
357           if instance(x,list):
358              out="["
359              first=True
360              for xx in x:
361                if not first:
362                    out+=","
363                else:
364                    first=False
365                out+=str(xx)
366              out+="]>"
367           else:
368              out=str(x)
369           return out
370    
371         def getX(self):
372            """
373            Returns the exact coordinates of the Locator.
374            """
375            return self(self.getFunctionSpace().getX())
376    
377       def getFunctionSpace(self):       def getFunctionSpace(self):
378          """          """
379      Returns the function space of the Locator.          Returns the function space of the Locator.
380      """          """
381          return self.__function_space          return self.__function_space
382    
383       def getId(self):       def getId(self,item=None):
384          """          """
385      Returns the identifier of the location.          Returns the identifier of the location.
     """  
         return self.__id  
   
      def getX(self):  
386          """          """
387      Returns the exact coordinates of the Locator.          if item == None:
388      """             return self.__id
389          return self(self.getFunctionSpace().getX())          else:
390               if isinstance(self.__id,list):
391                  return self.__id[item]
392               else:
393                  return self.__id
394    
395    
396       def __call__(self,data):       def __call__(self,data):
397          """          """
398      Returns the value of data at the Locator of a Data object otherwise          Returns the value of data at the Locator of a Data object.
399      the object is returned.          """
     """  
400          return self.getValue(data)          return self.getValue(data)
401    
402       def getValue(self,data):       def getValue(self,data):
403          """          """
404      Returns the value of data at the Locator if data is a Data object          Returns the value of C{data} at the Locator if C{data} is a L{Data}
405      otherwise the object is returned.          object otherwise the object is returned.
406      """          """
407          if isinstance(data,escript.Data):          if isinstance(data,escript.Data):
408             if data.getFunctionSpace()==self.getFunctionSpace():             dat=util.interpolate(data,self.getFunctionSpace())
409               out=data.convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1])             id=self.getId()
410               #out=data.convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1],self.getId()[2])             r=data.getRank()
411               if isinstance(id,list):
412                   out=[]
413                   for i in id:
414                      o=numpy.array(dat.getTupleForGlobalDataPoint(*i))
415                      if data.getRank()==0:
416                         out.append(o[0])
417                      else:
418                         out.append(o)
419                   return out
420             else:             else:
421               out=data.interpolate(self.getFunctionSpace()).convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1])               out=numpy.array(dat.getTupleForGlobalDataPoint(*id))
422               #out=data.interpolate(self.getFunctionSpace()).convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1],self.getId()[2])               if data.getRank()==0:
423             if data.getRank()==0:                  return out[0]
424                return out[0]               else:
425             else:                  return out
               return out  
426          else:          else:
427             return data             return data
428    
429  # vim: expandtab shiftwidth=4:  class SolverSchemeException(Exception):
430       """
431       This is a generic exception thrown by solvers.
432       """
433       pass
434    
435    class IndefinitePreconditioner(SolverSchemeException):
436       """
437       Exception thrown if the preconditioner is not positive definite.
438       """
439       pass
440    
441    class MaxIterReached(SolverSchemeException):
442       """
443       Exception thrown if the maximum number of iteration steps is reached.
444       """
445       pass
446    
447    class CorrectionFailed(SolverSchemeException):
448       """
449       Exception thrown if no convergence has been achieved in the solution
450       correction scheme.
451       """
452       pass
453    
454    class IterationBreakDown(SolverSchemeException):
455       """
456       Exception thrown if the iteration scheme encountered an incurable breakdown.
457       """
458       pass
459    
460    class NegativeNorm(SolverSchemeException):
461       """
462       Exception thrown if a norm calculation returns a negative norm.
463       """
464       pass
465    
466    def PCG(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100, initial_guess=True, verbose=False):
467       """
468       Solver for
469    
470       M{Ax=b}
471    
472       with a symmetric and positive definite operator A (more details required!).
473       It uses the conjugate gradient method with preconditioner M providing an
474       approximation of A.
475    
476       The iteration is terminated if
477    
478       M{|r| <= atol+rtol*|r0|}
479    
480       where M{r0} is the initial residual and M{|.|} is the energy norm. In fact
481    
482       M{|r| = sqrt( bilinearform(Msolve(r),r))}
483    
484       For details on the preconditioned conjugate gradient method see the book:
485    
486       I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
487       T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
488       C. Romine, and H. van der Vorst}.
489    
490       @param r: initial residual M{r=b-Ax}. C{r} is altered.
491       @type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
492       @param x: an initial guess for the solution
493       @type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
494       @param Aprod: returns the value Ax
495       @type Aprod: function C{Aprod(x)} where C{x} is of the same object like
496                    argument C{x}. The returned object needs to be of the same type
497                    like argument C{r}.
498       @param Msolve: solves Mx=r
499       @type Msolve: function C{Msolve(r)} where C{r} is of the same type like
500                     argument C{r}. The returned object needs to be of the same
501                     type like argument C{x}.
502       @param bilinearform: inner product C{<x,r>}
503       @type bilinearform: function C{bilinearform(x,r)} where C{x} is of the same
504                           type like argument C{x} and C{r} is. The returned value
505                           is a C{float}.
506       @param atol: absolute tolerance
507       @type atol: non-negative C{float}
508       @param rtol: relative tolerance
509       @type rtol: non-negative C{float}
510       @param iter_max: maximum number of iteration steps
511       @type iter_max: C{int}
512       @return: the solution approximation and the corresponding residual
513       @rtype: C{tuple}
514       @warning: C{r} and C{x} are altered.
515       """
516       iter=0
517       rhat=Msolve(r)
518       d = rhat
519       rhat_dot_r = bilinearform(rhat, r)
520       if rhat_dot_r<0: raise NegativeNorm,"negative norm."
521       norm_r0=math.sqrt(rhat_dot_r)
522       atol2=atol+rtol*norm_r0
523       if atol2<=0:
524          raise ValueError,"Non-positive tolarance."
525       atol2=max(atol2, 100. * util.EPSILON * norm_r0)
526    
527       if verbose: print "PCG: initial residual norm = %e (absolute tolerance = %e)"%(norm_r0, atol2)
528    
529    
530       while not math.sqrt(rhat_dot_r) <= atol2:
531           iter+=1
532           if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
533    
534           q=Aprod(d)
535           alpha = rhat_dot_r / bilinearform(d, q)
536           x += alpha * d
537           if isinstance(q,ArithmeticTuple):
538           r += q * (-alpha)      # Doing it the other way calls the float64.__mul__ not AT.__rmul__
539           else:
540               r += (-alpha) * q
541           rhat=Msolve(r)
542           rhat_dot_r_new = bilinearform(rhat, r)
543           beta = rhat_dot_r_new / rhat_dot_r
544           rhat+=beta * d
545           d=rhat
546    
547           rhat_dot_r = rhat_dot_r_new
548           if rhat_dot_r<0: raise NegativeNorm,"negative norm."
549           if verbose: print "PCG: iteration step %s: residual norm = %e"%(iter, math.sqrt(rhat_dot_r))
550       if verbose: print "PCG: tolerance reached after %s steps."%iter
551       return x,r,math.sqrt(rhat_dot_r)
552    
553    class Defect(object):
554        """
555        Defines a non-linear defect F(x) of a variable x.
556        """
557        def __init__(self):
558            """
559            Initializes defect.
560            """
561            self.setDerivativeIncrementLength()
562    
563        def bilinearform(self, x0, x1):
564            """
565            Returns the inner product of x0 and x1
566    
567            @param x0: value for x0
568            @param x1: value for x1
569            @return: the inner product of x0 and x1
570            @rtype: C{float}
571            """
572            return 0
573    
574        def norm(self,x):
575            """
576            Returns the norm of argument C{x}.
577    
578            @param x: a value
579            @return: norm of argument x
580            @rtype: C{float}
581            @note: by default C{sqrt(self.bilinearform(x,x)} is returned.
582            """
583            s=self.bilinearform(x,x)
584            if s<0: raise NegativeNorm,"negative norm."
585            return math.sqrt(s)
586    
587        def eval(self,x):
588            """
589            Returns the value F of a given C{x}.
590    
591            @param x: value for which the defect C{F} is evaluated
592            @return: value of the defect at C{x}
593            """
594            return 0
595    
596        def __call__(self,x):
597            return self.eval(x)
598    
599        def setDerivativeIncrementLength(self,inc=math.sqrt(util.EPSILON)):
600            """
601            Sets the relative length of the increment used to approximate the
602            derivative of the defect. The increment is inc*norm(x)/norm(v)*v in the
603            direction of v with x as a starting point.
604    
605            @param inc: relative increment length
606            @type inc: positive C{float}
607            """
608            if inc<=0: raise ValueError,"positive increment required."
609            self.__inc=inc
610    
611        def getDerivativeIncrementLength(self):
612            """
613            Returns the relative increment length used to approximate the
614            derivative of the defect.
615            @return: value of the defect at C{x}
616            @rtype: positive C{float}
617            """
618            return self.__inc
619    
620        def derivative(self, F0, x0, v, v_is_normalised=True):
621            """
622            Returns the directional derivative at C{x0} in the direction of C{v}.
623    
624            @param F0: value of this defect at x0
625            @param x0: value at which derivative is calculated
626            @param v: direction
627            @param v_is_normalised: True to indicate that C{v} is nomalized
628                                    (self.norm(v)=0)
629            @return: derivative of this defect at x0 in the direction of C{v}
630            @note: by default numerical evaluation (self.eval(x0+eps*v)-F0)/eps is
631                   used but this method maybe overwritten to use exact evaluation.
632            """
633            normx=self.norm(x0)
634            if normx>0:
635                 epsnew = self.getDerivativeIncrementLength() * normx
636            else:
637                 epsnew = self.getDerivativeIncrementLength()
638            if not v_is_normalised:
639                normv=self.norm(v)
640                if normv<=0:
641                   return F0*0
642                else:
643                   epsnew /= normv
644            F1=self.eval(x0 + epsnew * v)
645            return (F1-F0)/epsnew
646    
647    ######################################
648    def NewtonGMRES(defect, x, iter_max=100, sub_iter_max=20, atol=0,rtol=1.e-4, sub_tol_max=0.5, gamma=0.9, verbose=False):
649       """
650       Solves a non-linear problem M{F(x)=0} for unknown M{x} using the stopping
651       criterion:
652    
653       M{norm(F(x) <= atol + rtol * norm(F(x0)}
654    
655       where M{x0} is the initial guess.
656    
657       @param defect: object defining the function M{F}. C{defect.norm} defines the
658                      M{norm} used in the stopping criterion.
659       @type defect: L{Defect}
660       @param x: initial guess for the solution, C{x} is altered.
661       @type x: any object type allowing basic operations such as
662                C{numpy.ndarray}, L{Data}
663       @param iter_max: maximum number of iteration steps
664       @type iter_max: positive C{int}
665       @param sub_iter_max: maximum number of inner iteration steps
666       @type sub_iter_max: positive C{int}
667       @param atol: absolute tolerance for the solution
668       @type atol: positive C{float}
669       @param rtol: relative tolerance for the solution
670       @type rtol: positive C{float}
671       @param gamma: tolerance safety factor for inner iteration
672       @type gamma: positive C{float}, less than 1
673       @param sub_tol_max: upper bound for inner tolerance
674       @type sub_tol_max: positive C{float}, less than 1
675       @return: an approximation of the solution with the desired accuracy
676       @rtype: same type as the initial guess
677       """
678       lmaxit=iter_max
679       if atol<0: raise ValueError,"atol needs to be non-negative."
680       if rtol<0: raise ValueError,"rtol needs to be non-negative."
681       if rtol+atol<=0: raise ValueError,"rtol or atol needs to be non-negative."
682       if gamma<=0 or gamma>=1: raise ValueError,"tolerance safety factor for inner iteration (gamma =%s) needs to be positive and less than 1."%gamma
683       if sub_tol_max<=0 or sub_tol_max>=1: raise ValueError,"upper bound for inner tolerance for inner iteration (sub_tol_max =%s) needs to be positive and less than 1."%sub_tol_max
684    
685       F=defect(x)
686       fnrm=defect.norm(F)
687       stop_tol=atol + rtol*fnrm
688       sub_tol=sub_tol_max
689       if verbose: print "NewtonGMRES: initial residual = %e."%fnrm
690       if verbose: print "             tolerance = %e."%sub_tol
691       iter=1
692       #
693       # main iteration loop
694       #
695       while not fnrm<=stop_tol:
696                if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
697                #
698            #   adjust sub_tol_
699            #
700                if iter > 1:
701               rat=fnrm/fnrmo
702                   sub_tol_old=sub_tol
703               sub_tol=gamma*rat**2
704               if gamma*sub_tol_old**2 > .1: sub_tol=max(sub_tol,gamma*sub_tol_old**2)
705               sub_tol=max(min(sub_tol,sub_tol_max), .5*stop_tol/fnrm)
706            #
707            # calculate newton increment xc
708                #     if iter_max in __FDGMRES is reached MaxIterReached is thrown
709                #     if iter_restart -1 is returned as sub_iter
710                #     if  atol is reached sub_iter returns the numer of steps performed to get there
711                #
712                #
713                if verbose: print "             subiteration (GMRES) is called with relative tolerance %e."%sub_tol
714                try:
715                   xc, sub_iter=__FDGMRES(F, defect, x, sub_tol*fnrm, iter_max=iter_max-iter, iter_restart=sub_iter_max)
716                except MaxIterReached:
717                   raise MaxIterReached,"maximum number of %s steps reached."%iter_max
718                if sub_iter<0:
719                   iter+=sub_iter_max
720                else:
721                   iter+=sub_iter
722                # ====
723            x+=xc
724                F=defect(x)
725            iter+=1
726                fnrmo, fnrm=fnrm, defect.norm(F)
727                if verbose: print "             step %s: residual %e."%(iter,fnrm)
728       if verbose: print "NewtonGMRES: completed after %s steps."%iter
729       return x
730    
731    def __givapp(c,s,vin):
732        """
733        Applies a sequence of Givens rotations (c,s) recursively to the vector
734        C{vin}
735    
736        @warning: C{vin} is altered.
737        """
738        vrot=vin
739        if isinstance(c,float):
740            vrot=[c*vrot[0]-s*vrot[1],s*vrot[0]+c*vrot[1]]
741        else:
742            for i in range(len(c)):
743                w1=c[i]*vrot[i]-s[i]*vrot[i+1]
744            w2=s[i]*vrot[i]+c[i]*vrot[i+1]
745                vrot[i]=w1
746                vrot[i+1]=w2
747        return vrot
748    
749    def __FDGMRES(F0, defect, x0, atol, iter_max=100, iter_restart=20):
750       h=numpy.zeros((iter_restart,iter_restart),numpy.float64)
751       c=numpy.zeros(iter_restart,numpy.float64)
752       s=numpy.zeros(iter_restart,numpy.float64)
753       g=numpy.zeros(iter_restart,numpy.float64)
754       v=[]
755    
756       rho=defect.norm(F0)
757       if rho<=0.: return x0*0
758    
759       v.append(-F0/rho)
760       g[0]=rho
761       iter=0
762       while rho > atol and iter<iter_restart-1:
763            if iter  >= iter_max:
764                raise MaxIterReached,"maximum number of %s steps reached."%iter_max
765    
766            p=defect.derivative(F0,x0,v[iter], v_is_normalised=True)
767            v.append(p)
768    
769            v_norm1=defect.norm(v[iter+1])
770    
771            # Modified Gram-Schmidt
772            for j in range(iter+1):
773                h[j,iter]=defect.bilinearform(v[j],v[iter+1])
774                v[iter+1]-=h[j,iter]*v[j]
775    
776            h[iter+1,iter]=defect.norm(v[iter+1])
777            v_norm2=h[iter+1,iter]
778    
779            # Reorthogonalize if needed
780            if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
781                for j in range(iter+1):
782                    hr=defect.bilinearform(v[j],v[iter+1])
783                    h[j,iter]=h[j,iter]+hr
784                    v[iter+1] -= hr*v[j]
785    
786                v_norm2=defect.norm(v[iter+1])
787                h[iter+1,iter]=v_norm2
788            #   watch out for happy breakdown
789            if not v_norm2 == 0:
790                v[iter+1]=v[iter+1]/h[iter+1,iter]
791    
792            #   Form and store the information for the new Givens rotation
793            if iter > 0 :
794                hhat=numpy.zeros(iter+1,numpy.float64)
795                for i in range(iter+1) : hhat[i]=h[i,iter]
796                hhat=__givapp(c[0:iter],s[0:iter],hhat);
797                for i in range(iter+1) : h[i,iter]=hhat[i]
798    
799            mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
800    
801            if mu!=0 :
802                c[iter]=h[iter,iter]/mu
803                s[iter]=-h[iter+1,iter]/mu
804                h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
805                h[iter+1,iter]=0.0
806                gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
807                g[iter]=gg[0]
808                g[iter+1]=gg[1]
809    
810            # Update the residual norm
811            rho=abs(g[iter+1])
812            iter+=1
813    
814       # At this point either iter > iter_max or rho < tol.
815       # It's time to compute x and leave.
816       if iter > 0 :
817         y=numpy.zeros(iter,numpy.float64)
818         y[iter-1] = g[iter-1] / h[iter-1,iter-1]
819         if iter > 1 :
820            i=iter-2
821            while i>=0 :
822              y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
823              i=i-1
824         xhat=v[iter-1]*y[iter-1]
825         for i in range(iter-1):
826        xhat += v[i]*y[i]
827       else :
828          xhat=v[0] * 0
829    
830       if iter<iter_restart-1:
831          stopped=iter
832       else:
833          stopped=-1
834    
835       return xhat,stopped
836    
837    def GMRES(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100, iter_restart=20, verbose=False,P_R=None):
838       """
839       Solver for
840    
841       M{Ax=b}
842    
843       with a general operator A (more details required!).
844       It uses the generalized minimum residual method (GMRES).
845    
846       The iteration is terminated if
847    
848       M{|r| <= atol+rtol*|r0|}
849    
850       where M{r0} is the initial residual and M{|.|} is the energy norm. In fact
851    
852       M{|r| = sqrt( bilinearform(r,r))}
853    
854       @param r: initial residual M{r=b-Ax}. C{r} is altered.
855       @type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
856       @param x: an initial guess for the solution
857       @type x: same like C{r}
858       @param Aprod: returns the value Ax
859       @type Aprod: function C{Aprod(x)} where C{x} is of the same object like
860                    argument C{x}. The returned object needs to be of the same
861                    type like argument C{r}.
862       @param bilinearform: inner product C{<x,r>}
863       @type bilinearform: function C{bilinearform(x,r)} where C{x} is of the same
864                           type like argument C{x} and C{r}. The returned value is
865                           a C{float}.
866       @param atol: absolute tolerance
867       @type atol: non-negative C{float}
868       @param rtol: relative tolerance
869       @type rtol: non-negative C{float}
870       @param iter_max: maximum number of iteration steps
871       @type iter_max: C{int}
872       @param iter_restart: in order to save memory the orthogonalization process
873                            is terminated after C{iter_restart} steps and the
874                            iteration is restarted.
875       @type iter_restart: C{int}
876       @return: the solution approximation and the corresponding residual
877       @rtype: C{tuple}
878       @warning: C{r} and C{x} are altered.
879       """
880       m=iter_restart
881       restarted=False
882       iter=0
883       if rtol>0:
884          r_dot_r = bilinearform(r, r)
885          if r_dot_r<0: raise NegativeNorm,"negative norm."
886          atol2=atol+rtol*math.sqrt(r_dot_r)
887          if verbose: print "GMRES: norm of right hand side = %e (absolute tolerance = %e)"%(math.sqrt(r_dot_r), atol2)
888       else:
889          atol2=atol
890          if verbose: print "GMRES: absolute tolerance = %e"%atol2
891       if atol2<=0:
892          raise ValueError,"Non-positive tolarance."
893    
894       while True:
895          if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached"%iter_max
896          if restarted:
897             r2 = r-Aprod(x-x2)
898          else:
899             r2=1*r
900          x2=x*1.
901          x,stopped=_GMRESm(r2, Aprod, x, bilinearform, atol2, iter_max=iter_max-iter, iter_restart=m, verbose=verbose,P_R=P_R)
902          iter+=iter_restart
903          if stopped: break
904          if verbose: print "GMRES: restart."
905          restarted=True
906       if verbose: print "GMRES: tolerance has been reached."
907       return x
908    
909    def _GMRESm(r, Aprod, x, bilinearform, atol, iter_max=100, iter_restart=20, verbose=False, P_R=None):
910       iter=0
911    
912       h=numpy.zeros((iter_restart+1,iter_restart),numpy.float64)
913       c=numpy.zeros(iter_restart,numpy.float64)
914       s=numpy.zeros(iter_restart,numpy.float64)
915       g=numpy.zeros(iter_restart+1,numpy.float64)
916       v=[]
917    
918       r_dot_r = bilinearform(r, r)
919       if r_dot_r<0: raise NegativeNorm,"negative norm."
920       rho=math.sqrt(r_dot_r)
921    
922       v.append(r/rho)
923       g[0]=rho
924    
925       if verbose: print "GMRES: initial residual %e (absolute tolerance = %e)"%(rho,atol)
926       while not (rho<=atol or iter==iter_restart):
927    
928        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
929    
930            if P_R!=None:
931                p=Aprod(P_R(v[iter]))
932            else:
933            p=Aprod(v[iter])
934        v.append(p)
935    
936        v_norm1=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
937    
938    # Modified Gram-Schmidt
939        for j in range(iter+1):
940          h[j,iter]=bilinearform(v[j],v[iter+1])
941          v[iter+1]-=h[j,iter]*v[j]
942    
943        h[iter+1,iter]=math.sqrt(bilinearform(v[iter+1],v[iter+1]))
944        v_norm2=h[iter+1,iter]
945    
946    # Reorthogonalize if needed
947        if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
948         for j in range(iter+1):
949            hr=bilinearform(v[j],v[iter+1])
950                h[j,iter]=h[j,iter]+hr
951                v[iter+1] -= hr*v[j]
952    
953         v_norm2=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
954         h[iter+1,iter]=v_norm2
955    
956    #   watch out for happy breakdown
957            if not v_norm2 == 0:
958             v[iter+1]=v[iter+1]/h[iter+1,iter]
959    
960    #   Form and store the information for the new Givens rotation
961        if iter > 0: h[:iter+1,iter]=__givapp(c[:iter],s[:iter],h[:iter+1,iter])
962        mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
963    
964        if mu!=0 :
965            c[iter]=h[iter,iter]/mu
966            s[iter]=-h[iter+1,iter]/mu
967            h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
968            h[iter+1,iter]=0.0
969                    gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
970                    g[iter]=gg[0]
971                    g[iter+1]=gg[1]
972    # Update the residual norm
973    
974            rho=abs(g[iter+1])
975            if verbose: print "GMRES: iteration step %s: residual %e"%(iter,rho)
976        iter+=1
977    
978    # At this point either iter > iter_max or rho < tol.
979    # It's time to compute x and leave.
980    
981       if verbose: print "GMRES: iteration stopped after %s step."%iter
982       if iter > 0 :
983         y=numpy.zeros(iter,numpy.float64)
984         y[iter-1] = g[iter-1] / h[iter-1,iter-1]
985         if iter > 1 :
986            i=iter-2
987            while i>=0 :
988              y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
989              i=i-1
990         xhat=v[iter-1]*y[iter-1]
991         for i in range(iter-1):
992        xhat += v[i]*y[i]
993       else:
994         xhat=v[0] * 0
995       if P_R!=None:
996          x += P_R(xhat)
997       else:
998          x += xhat
999       if iter<iter_restart-1:
1000          stopped=True
1001       else:
1002          stopped=False
1003    
1004       return x,stopped
1005    
1006    def MINRES(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1007        """
1008        Solver for
1009    
1010        M{Ax=b}
1011    
1012        with a symmetric and positive definite operator A (more details required!).
1013        It uses the minimum residual method (MINRES) with preconditioner M
1014        providing an approximation of A.
1015    
1016        The iteration is terminated if
1017    
1018        M{|r| <= atol+rtol*|r0|}
1019    
1020        where M{r0} is the initial residual and M{|.|} is the energy norm. In fact
1021    
1022        M{|r| = sqrt( bilinearform(Msolve(r),r))}
1023    
1024        For details on the preconditioned conjugate gradient method see the book:
1025    
1026        I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
1027        T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
1028        C. Romine, and H. van der Vorst}.
1029    
1030        @param r: initial residual M{r=b-Ax}. C{r} is altered.
1031        @type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1032        @param x: an initial guess for the solution
1033        @type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1034        @param Aprod: returns the value Ax
1035        @type Aprod: function C{Aprod(x)} where C{x} is of the same object like
1036                     argument C{x}. The returned object needs to be of the same
1037                     type like argument C{r}.
1038        @param Msolve: solves Mx=r
1039        @type Msolve: function C{Msolve(r)} where C{r} is of the same type like
1040                      argument C{r}. The returned object needs to be of the same
1041                      type like argument C{x}.
1042        @param bilinearform: inner product C{<x,r>}
1043        @type bilinearform: function C{bilinearform(x,r)} where C{x} is of the same
1044                            type like argument C{x} and C{r} is. The returned value
1045                            is a C{float}.
1046        @param atol: absolute tolerance
1047        @type atol: non-negative C{float}
1048        @param rtol: relative tolerance
1049        @type rtol: non-negative C{float}
1050        @param iter_max: maximum number of iteration steps
1051        @type iter_max: C{int}
1052        @return: the solution approximation and the corresponding residual
1053        @rtype: C{tuple}
1054        @warning: C{r} and C{x} are altered.
1055        """
1056        #------------------------------------------------------------------
1057        # Set up y and v for the first Lanczos vector v1.
1058        # y  =  beta1 P' v1,  where  P = C**(-1).
1059        # v is really P' v1.
1060        #------------------------------------------------------------------
1061        r1    = r
1062        y = Msolve(r)
1063        beta1 = bilinearform(y,r)
1064    
1065        if beta1< 0: raise NegativeNorm,"negative norm."
1066    
1067        #  If r = 0 exactly, stop with x
1068        if beta1==0: return x
1069    
1070        if beta1> 0: beta1  = math.sqrt(beta1)
1071    
1072        #------------------------------------------------------------------
1073        # Initialize quantities.
1074        # ------------------------------------------------------------------
1075        iter   = 0
1076        Anorm = 0
1077        ynorm = 0
1078        oldb   = 0
1079        beta   = beta1
1080        dbar   = 0
1081        epsln  = 0
1082        phibar = beta1
1083        rhs1   = beta1
1084        rhs2   = 0
1085        rnorm  = phibar
1086        tnorm2 = 0
1087        ynorm2 = 0
1088        cs     = -1
1089        sn     = 0
1090        w      = r*0.
1091        w2     = r*0.
1092        r2     = r1
1093        eps    = 0.0001
1094    
1095        #---------------------------------------------------------------------
1096        # Main iteration loop.
1097        # --------------------------------------------------------------------
1098        while not rnorm<=atol+rtol*Anorm*ynorm:    #  checks ||r|| < (||A|| ||x||) * TOL
1099    
1100        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1101            iter    = iter  +  1
1102    
1103            #-----------------------------------------------------------------
1104            # Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
1105            # The general iteration is similar to the case k = 1 with v0 = 0:
1106            #
1107            #   p1      = Operator * v1  -  beta1 * v0,
1108            #   alpha1  = v1'p1,
1109            #   q2      = p2  -  alpha1 * v1,
1110            #   beta2^2 = q2'q2,
1111            #   v2      = (1/beta2) q2.
1112            #
1113            # Again, y = betak P vk,  where  P = C**(-1).
1114            #-----------------------------------------------------------------
1115            s = 1/beta                 # Normalize previous vector (in y).
1116            v = s*y                    # v = vk if P = I
1117    
1118            y      = Aprod(v)
1119    
1120            if iter >= 2:
1121              y = y - (beta/oldb)*r1
1122    
1123            alfa   = bilinearform(v,y)              # alphak
1124            y      += (- alfa/beta)*r2
1125            r1     = r2
1126            r2     = y
1127            y = Msolve(r2)
1128            oldb   = beta                           # oldb = betak
1129            beta   = bilinearform(y,r2)             # beta = betak+1^2
1130            if beta < 0: raise NegativeNorm,"negative norm."
1131    
1132            beta   = math.sqrt( beta )
1133            tnorm2 = tnorm2 + alfa*alfa + oldb*oldb + beta*beta
1134    
1135            if iter==1:                 # Initialize a few things.
1136              gmax   = abs( alfa )      # alpha1
1137              gmin   = gmax             # alpha1
1138    
1139            # Apply previous rotation Qk-1 to get
1140            #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
1141            #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].
1142    
1143            oldeps = epsln
1144            delta  = cs * dbar  +  sn * alfa  # delta1 = 0         deltak
1145            gbar   = sn * dbar  -  cs * alfa  # gbar 1 = alfa1     gbar k
1146            epsln  =               sn * beta  # epsln2 = 0         epslnk+1
1147            dbar   =            -  cs * beta  # dbar 2 = beta2     dbar k+1
1148    
1149            # Compute the next plane rotation Qk
1150    
1151            gamma  = math.sqrt(gbar*gbar+beta*beta)  # gammak
1152            gamma  = max(gamma,eps)
1153            cs     = gbar / gamma             # ck
1154            sn     = beta / gamma             # sk
1155            phi    = cs * phibar              # phik
1156            phibar = sn * phibar              # phibark+1
1157    
1158            # Update  x.
1159    
1160            denom = 1/gamma
1161            w1    = w2
1162            w2    = w
1163            w     = (v - oldeps*w1 - delta*w2) * denom
1164            x     +=  phi*w
1165    
1166            # Go round again.
1167    
1168            gmax   = max(gmax,gamma)
1169            gmin   = min(gmin,gamma)
1170            z      = rhs1 / gamma
1171            ynorm2 = z*z  +  ynorm2
1172            rhs1   = rhs2 -  delta*z
1173            rhs2   =      -  epsln*z
1174    
1175            # Estimate various norms and test for convergence.
1176    
1177            Anorm  = math.sqrt( tnorm2 )
1178            ynorm  = math.sqrt( ynorm2 )
1179    
1180            rnorm  = phibar
1181    
1182        return x
1183    
1184    def TFQMR(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1185      """
1186      Solver for
1187    
1188      M{Ax=b}
1189    
1190      with a general operator A (more details required!).
1191      It uses the Transpose-Free Quasi-Minimal Residual method (TFQMR).
1192    
1193      The iteration is terminated if
1194    
1195      M{|r| <= atol+rtol*|r0|}
1196    
1197      where M{r0} is the initial residual and M{|.|} is the energy norm. In fact
1198    
1199      M{|r| = sqrt( bilinearform(r,r))}
1200    
1201      @param r: initial residual M{r=b-Ax}. C{r} is altered.
1202      @type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1203      @param x: an initial guess for the solution
1204      @type x: same like C{r}
1205      @param Aprod: returns the value Ax
1206      @type Aprod: function C{Aprod(x)} where C{x} is of the same object like
1207                   argument C{x}. The returned object needs to be of the same type
1208                   like argument C{r}.
1209      @param bilinearform: inner product C{<x,r>}
1210      @type bilinearform: function C{bilinearform(x,r)} where C{x} is of the same
1211                          type like argument C{x} and C{r}. The returned value is
1212                          a C{float}.
1213      @param atol: absolute tolerance
1214      @type atol: non-negative C{float}
1215      @param rtol: relative tolerance
1216      @type rtol: non-negative C{float}
1217      @param iter_max: maximum number of iteration steps
1218      @type iter_max: C{int}
1219      @rtype: C{tuple}
1220      @warning: C{r} and C{x} are altered.
1221      """
1222      u1=0
1223      u2=0
1224      y1=0
1225      y2=0
1226    
1227      w = r
1228      y1 = r
1229      iter = 0
1230      d = 0
1231      v = Aprod(y1)
1232      u1 = v
1233    
1234      theta = 0.0;
1235      eta = 0.0;
1236      rho=bilinearform(r,r)
1237      if rho < 0: raise NegativeNorm,"negative norm."
1238      tau = math.sqrt(rho)
1239      norm_r0=tau
1240      while tau>atol+rtol*norm_r0:
1241        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1242    
1243        sigma = bilinearform(r,v)
1244        if sigma == 0.0: raise IterationBreakDown,'TFQMR breakdown, sigma=0'
1245    
1246        alpha = rho / sigma
1247    
1248        for j in range(2):
1249    #
1250    #   Compute y2 and u2 only if you have to
1251    #
1252          if ( j == 1 ):
1253            y2 = y1 - alpha * v
1254            u2 = Aprod(y2)
1255    
1256          m = 2 * (iter+1) - 2 + (j+1)
1257          if j==0:
1258             w = w - alpha * u1
1259             d = y1 + ( theta * theta * eta / alpha ) * d
1260          if j==1:
1261             w = w - alpha * u2
1262             d = y2 + ( theta * theta * eta / alpha ) * d
1263    
1264          theta = math.sqrt(bilinearform(w,w))/ tau
1265          c = 1.0 / math.sqrt ( 1.0 + theta * theta )
1266          tau = tau * theta * c
1267          eta = c * c * alpha
1268          x = x + eta * d
1269    #
1270    #  Try to terminate the iteration at each pass through the loop
1271    #
1272        if rho == 0.0: raise IterationBreakDown,'TFQMR breakdown, rho=0'
1273    
1274        rhon = bilinearform(r,w)
1275        beta = rhon / rho;
1276        rho = rhon;
1277        y1 = w + beta * y2;
1278        u1 = Aprod(y1)
1279        v = u1 + beta * ( u2 + beta * v )
1280    
1281        iter += 1
1282    
1283      return x
1284    
1285    
1286    #############################################
1287    
1288    class ArithmeticTuple(object):
1289       """
1290       Tuple supporting inplace update x+=y and scaling x=a*y where C{x,y} is an
1291       ArithmeticTuple and C{a} is a float.
1292    
1293       Example of usage::
1294    
1295           from esys.escript import Data
1296           from numpy import array
1297           a=Data(...)
1298           b=array([1.,4.])
1299           x=ArithmeticTuple(a,b)
1300           y=5.*x
1301    
1302       """
1303       def __init__(self,*args):
1304           """
1305           Initializes object with elements C{args}.
1306    
1307           @param args: tuple of objects that support inplace add (x+=y) and
1308                        scaling (x=a*y)
1309           """
1310           self.__items=list(args)
1311    
1312       def __len__(self):
1313           """
1314           Returns the number of items.
1315    
1316           @return: number of items
1317           @rtype: C{int}
1318           """
1319           return len(self.__items)
1320    
1321       def __getitem__(self,index):
1322           """
1323           Returns item at specified position.
1324    
1325           @param index: index of item to be returned
1326           @type index: C{int}
1327           @return: item with index C{index}
1328           """
1329           return self.__items.__getitem__(index)
1330    
1331       def __mul__(self,other):
1332           """
1333           Scales by C{other} from the right.
1334    
1335           @param other: scaling factor
1336           @type other: C{float}
1337           @return: itemwise self*other
1338           @rtype: L{ArithmeticTuple}
1339           """
1340           out=[]
1341           try:
1342               l=len(other)
1343               if l!=len(self):
1344                   raise ValueError,"length of arguments don't match."
1345               for i in range(l): out.append(self[i]*other[i])
1346           except TypeError:
1347               for i in range(len(self)): out.append(self[i]*other)
1348           return ArithmeticTuple(*tuple(out))
1349    
1350       def __rmul__(self,other):
1351           """
1352           Scales by C{other} from the left.
1353    
1354           @param other: scaling factor
1355           @type other: C{float}
1356           @return: itemwise other*self
1357           @rtype: L{ArithmeticTuple}
1358           """
1359           out=[]
1360           try:
1361               l=len(other)
1362               if l!=len(self):
1363                   raise ValueError,"length of arguments don't match."
1364               for i in range(l): out.append(other[i]*self[i])
1365           except TypeError:
1366               for i in range(len(self)): out.append(other*self[i])
1367           return ArithmeticTuple(*tuple(out))
1368    
1369       def __div__(self,other):
1370           """
1371           Scales by (1/C{other}) from the right.
1372    
1373           @param other: scaling factor
1374           @type other: C{float}
1375           @return: itemwise self/other
1376           @rtype: L{ArithmeticTuple}
1377           """
1378           return self*(1/other)
1379    
1380       def __rdiv__(self,other):
1381           """
1382           Scales by (1/C{other}) from the left.
1383    
1384           @param other: scaling factor
1385           @type other: C{float}
1386           @return: itemwise other/self
1387           @rtype: L{ArithmeticTuple}
1388           """
1389           out=[]
1390           try:
1391               l=len(other)
1392               if l!=len(self):
1393                   raise ValueError,"length of arguments don't match."
1394               for i in range(l): out.append(other[i]/self[i])
1395           except TypeError:
1396               for i in range(len(self)): out.append(other/self[i])
1397           return ArithmeticTuple(*tuple(out))
1398    
1399       def __iadd__(self,other):
1400           """
1401           Inplace addition of C{other} to self.
1402    
1403           @param other: increment
1404           @type other: C{ArithmeticTuple}
1405           """
1406           if len(self) != len(other):
1407               raise ValueError,"tuple lengths must match."
1408           for i in range(len(self)):
1409               self.__items[i]+=other[i]
1410           return self
1411    
1412       def __add__(self,other):
1413           """
1414           Adds C{other} to self.
1415    
1416           @param other: increment
1417           @type other: C{ArithmeticTuple}
1418           """
1419           out=[]
1420           try:
1421               l=len(other)
1422               if l!=len(self):
1423                   raise ValueError,"length of arguments don't match."
1424               for i in range(l): out.append(self[i]+other[i])
1425           except TypeError:
1426               for i in range(len(self)): out.append(self[i]+other)
1427           return ArithmeticTuple(*tuple(out))
1428    
1429       def __sub__(self,other):
1430           """
1431           Subtracts C{other} from self.
1432    
1433           @param other: decrement
1434           @type other: C{ArithmeticTuple}
1435           """
1436           out=[]
1437           try:
1438               l=len(other)
1439               if l!=len(self):
1440                   raise ValueError,"length of arguments don't match."
1441               for i in range(l): out.append(self[i]-other[i])
1442           except TypeError:
1443               for i in range(len(self)): out.append(self[i]-other)
1444           return ArithmeticTuple(*tuple(out))
1445    
1446       def __isub__(self,other):
1447           """
1448           Inplace subtraction of C{other} from self.
1449    
1450           @param other: decrement
1451           @type other: C{ArithmeticTuple}
1452           """
1453           if len(self) != len(other):
1454               raise ValueError,"tuple length must match."
1455           for i in range(len(self)):
1456               self.__items[i]-=other[i]
1457           return self
1458    
1459       def __neg__(self):
1460           """
1461           Negates values.
1462           """
1463           out=[]
1464           for i in range(len(self)):
1465               out.append(-self[i])
1466           return ArithmeticTuple(*tuple(out))
1467    
1468    
1469    class HomogeneousSaddlePointProblem(object):
1470          """
1471          This class provides a framework for solving linear homogeneous saddle
1472          point problems of the form::
1473    
1474              M{Av+B^*p=f}
1475              M{Bv     =0}
1476    
1477          for the unknowns M{v} and M{p} and given operators M{A} and M{B} and
1478          given right hand side M{f}. M{B^*} is the adjoint operator of M{B}.
1479          """
1480          def __init__(self, adaptSubTolerance=True, **kwargs):
1481        """
1482        initializes the saddle point problem
1483        
1484        @param adaptSubTolerance: If True the tolerance for subproblem is set automatically.
1485        @type adaptSubTolerance: C{bool}
1486        """
1487            self.setTolerance()
1488            self.setAbsoluteTolerance()
1489        self.__adaptSubTolerance=adaptSubTolerance
1490          #=============================================================
1491          def initialize(self):
1492            """
1493            Initializes the problem (overwrite).
1494            """
1495            pass
1496    
1497          def inner_pBv(self,p,Bv):
1498             """
1499             Returns inner product of element p and Bv (overwrite).
1500    
1501             @param p: a pressure increment
1502             @param v: a residual
1503             @return: inner product of element p and Bv
1504             @rtype: C{float}
1505             @note: used if PCG is applied.
1506             """
1507             raise NotImplementedError,"no inner product for p and Bv implemented."
1508    
1509          def inner_p(self,p0,p1):
1510             """
1511             Returns inner product of p0 and p1 (overwrite).
1512    
1513             @param p0: a pressure
1514             @param p1: a pressure
1515             @return: inner product of p0 and p1
1516             @rtype: C{float}
1517             """
1518             raise NotImplementedError,"no inner product for p implemented."
1519      
1520          def norm_v(self,v):
1521             """
1522             Returns the norm of v (overwrite).
1523    
1524             @param v: a velovity
1525             @return: norm of v
1526             @rtype: non-negative C{float}
1527             """
1528             raise NotImplementedError,"no norm of v implemented."
1529          def getV(self, p, v0):
1530             """
1531             return the value for v for a given p (overwrite)
1532    
1533             @param p: a pressure
1534             @param v0: a initial guess for the value v to return.
1535             @return: v given as M{v= A^{-1} (f-B^*p)}
1536             """
1537             raise NotImplementedError,"no v calculation implemented."
1538    
1539            
1540          def Bv(self,v):
1541            """
1542            Returns Bv (overwrite).
1543    
1544            @rtype: equal to the type of p
1545            @note: boundary conditions on p should be zero!
1546            """
1547            raise NotImplementedError, "no operator B implemented."
1548    
1549          def norm_Bv(self,Bv):
1550            """
1551            Returns the norm of Bv (overwrite).
1552    
1553            @rtype: equal to the type of p
1554            @note: boundary conditions on p should be zero!
1555            """
1556            raise NotImplementedError, "no norm of Bv implemented."
1557    
1558          def solve_AinvBt(self,p):
1559             """
1560             Solves M{Av=B^*p} with accuracy L{self.getSubProblemTolerance()}
1561             (overwrite).
1562    
1563             @param p: a pressure increment
1564             @return: the solution of M{Av=B^*p}
1565             @note: boundary conditions on v should be zero!
1566             """
1567             raise NotImplementedError,"no operator A implemented."
1568    
1569          def solve_prec(self,Bv):
1570             """
1571             Provides a preconditioner for M{BA^{-1}B^*} applied to Bv with accuracy
1572             L{self.getSubProblemTolerance()} (overwrite).
1573    
1574             @rtype: equal to the type of p
1575             @note: boundary conditions on p should be zero!
1576             """
1577             raise NotImplementedError,"no preconditioner for Schur complement implemented."
1578          def setSubProblemTolerance(self):
1579             """
1580         Updates the tolerance for subproblems
1581         @note: method is typically the method is overwritten.
1582             """
1583             pass
1584          #=============================================================
1585          def __Aprod_PCG(self,p):
1586              dv=self.solve_AinvBt(p)
1587              return ArithmeticTuple(dv,self.Bv(dv))
1588    
1589          def __inner_PCG(self,p,r):
1590             return self.inner_pBv(p,r[1])
1591    
1592          def __Msolve_PCG(self,r):
1593              return self.solve_prec(r[1])
1594          #=============================================================
1595          def __Aprod_GMRES(self,p):
1596              return self.solve_prec(self.Bv(self.solve_AinvBt(p)))
1597          def __inner_GMRES(self,p0,p1):
1598             return self.inner_p(p0,p1)
1599    
1600          #=============================================================
1601          def norm_p(self,p):
1602              """
1603              calculates the norm of C{p}
1604              
1605              @param p: a pressure
1606              @return: the norm of C{p} using the inner product for pressure
1607              @rtype: C{float}
1608              """
1609              f=self.inner_p(p,p)
1610              if f<0: raise ValueError,"negative pressure norm."
1611              return math.sqrt(f)
1612          def adaptSubTolerance(self):
1613          """
1614          Returns True if tolerance adaption for subproblem is choosen.
1615          """
1616              self.__adaptSubTolerance
1617          
1618          def solve(self,v,p,max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
1619             """
1620             Solves the saddle point problem using initial guesses v and p.
1621    
1622             @param v: initial guess for velocity
1623             @param p: initial guess for pressure
1624             @type v: L{Data}
1625             @type p: L{Data}
1626             @param usePCG: indicates the usage of the PCG rather than GMRES scheme.
1627             @param max_iter: maximum number of iteration steps per correction
1628                              attempt
1629             @param verbose: if True, shows information on the progress of the
1630                             saddlepoint problem solver.
1631             @param iter_restart: restart the iteration after C{iter_restart} steps
1632                                  (only used if useUzaw=False)
1633             @type usePCG: C{bool}
1634             @type max_iter: C{int}
1635             @type verbose: C{bool}
1636             @type iter_restart: C{int}
1637             @rtype: C{tuple} of L{Data} objects
1638             """
1639             self.verbose=verbose
1640             rtol=self.getTolerance()
1641             atol=self.getAbsoluteTolerance()
1642         if self.adaptSubTolerance(): self.setSubProblemTolerance()
1643             correction_step=0
1644             converged=False
1645             while not converged:
1646                  # calculate velocity for current pressure:
1647                  v=self.getV(p,v)
1648                  Bv=self.Bv(v)
1649                  norm_v=self.norm_v(v)
1650                  norm_Bv=self.norm_Bv(Bv)
1651                  ATOL=norm_v*rtol+atol
1652                  if self.verbose: print "HomogeneousSaddlePointProblem: norm v= %e, norm_Bv= %e, tolerance = %e."%(norm_v, norm_Bv,ATOL)
1653                  if not ATOL>0: raise ValueError,"overall absolute tolerance needs to be positive."
1654                  if norm_Bv <= ATOL:
1655                     converged=True
1656                  else:
1657                     correction_step+=1
1658                     if correction_step>max_correction_steps:
1659                          raise CorrectionFailed,"Given up after %d correction steps."%correction_step
1660                     dp=self.solve_prec(Bv)
1661                     if usePCG:
1662                       norm2=self.inner_pBv(dp,Bv)
1663                       if norm2<0: raise ValueError,"negative PCG norm."
1664                       norm2=math.sqrt(norm2)
1665                     else:
1666                       norm2=self.norm_p(dp)
1667                     ATOL_ITER=ATOL/norm_Bv*norm2*0.5
1668                     if self.verbose: print "HomogeneousSaddlePointProblem: tolerance for solver: %e"%ATOL_ITER
1669                     if usePCG:
1670                           p,v0,a_norm=PCG(ArithmeticTuple(v,Bv),self.__Aprod_PCG,p,self.__Msolve_PCG,self.__inner_PCG,atol=ATOL_ITER, rtol=0.,iter_max=max_iter, verbose=self.verbose)
1671                     else:
1672                           p=GMRES(dp,self.__Aprod_GMRES, p, self.__inner_GMRES,atol=ATOL_ITER, rtol=0.,iter_max=max_iter, iter_restart=iter_restart, verbose=self.verbose)
1673             if self.verbose: print "HomogeneousSaddlePointProblem: tolerance reached."
1674         return v,p
1675    
1676          #========================================================================
1677          def setTolerance(self,tolerance=1.e-4):
1678             """
1679             Sets the relative tolerance for (v,p).
1680    
1681             @param tolerance: tolerance to be used
1682             @type tolerance: non-negative C{float}
1683             """
1684             if tolerance<0:
1685                 raise ValueError,"tolerance must be positive."
1686             self.__rtol=tolerance
1687    
1688          def getTolerance(self):
1689             """
1690             Returns the relative tolerance.
1691    
1692             @return: relative tolerance
1693             @rtype: C{float}
1694             """
1695             return self.__rtol
1696    
1697          def setAbsoluteTolerance(self,tolerance=0.):
1698             """
1699             Sets the absolute tolerance.
1700    
1701             @param tolerance: tolerance to be used
1702             @type tolerance: non-negative C{float}
1703             """
1704             if tolerance<0:
1705                 raise ValueError,"tolerance must be non-negative."
1706             self.__atol=tolerance
1707    
1708          def getAbsoluteTolerance(self):
1709             """
1710             Returns the absolute tolerance.
1711    
1712             @return: absolute tolerance
1713             @rtype: C{float}
1714             """
1715             return self.__atol
1716    
1717          def getSubProblemTolerance(self):
1718             """
1719             Sets the relative tolerance to solve the subproblem(s).
1720    
1721             @param rtol: relative tolerence
1722             @type rtol: positive C{float}
1723             """
1724             return max(200.*util.EPSILON,self.getTolerance()**2)
1725    
1726    def MaskFromBoundaryTag(domain,*tags):
1727       """
1728       Creates a mask on the Solution(domain) function space where the value is
1729       one for samples that touch the boundary tagged by tags.
1730    
1731       Usage: m=MaskFromBoundaryTag(domain, "left", "right")
1732    
1733       @param domain: domain to be used
1734       @type domain: L{escript.Domain}
1735       @param tags: boundary tags
1736       @type tags: C{str}
1737       @return: a mask which marks samples that are touching the boundary tagged
1738                by any of the given tags
1739       @rtype: L{escript.Data} of rank 0
1740       """
1741       pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1742       d=escript.Scalar(0.,escript.FunctionOnBoundary(domain))
1743       for t in tags: d.setTaggedValue(t,1.)
1744       pde.setValue(y=d)
1745       return util.whereNonZero(pde.getRightHandSide())
1746    
1747    def MaskFromTag(domain,*tags):
1748       """
1749       Creates a mask on the Solution(domain) function space where the value is
1750       one for samples that touch regions tagged by tags.
1751    
1752       Usage: m=MaskFromTag(domain, "ham")
1753    
1754       @param domain: domain to be used
1755       @type domain: L{escript.Domain}
1756       @param tags: boundary tags
1757       @type tags: C{str}
1758       @return: a mask which marks samples that are touching the boundary tagged
1759                by any of the given tags
1760       @rtype: L{escript.Data} of rank 0
1761       """
1762       pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1763       d=escript.Scalar(0.,escript.Function(domain))
1764       for t in tags: d.setTaggedValue(t,1.)
1765       pde.setValue(Y=d)
1766       return util.whereNonZero(pde.getRightHandSide())
1767    
1768    

Legend:
Removed from v.782  
changed lines
  Added in v.2517

  ViewVC Help
Powered by ViewVC 1.1.26