/[escript]/trunk/escript/py_src/pdetools.py
ViewVC logotype

Diff of /trunk/escript/py_src/pdetools.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 1956 by gross, Mon Nov 3 05:08:42 2008 UTC revision 2745 by jfenwick, Tue Nov 17 04:23:02 2009 UTC
# Line 1  Line 1 
1    
2  ########################################################  ########################################################
3  #  #
4  # Copyright (c) 2003-2008 by University of Queensland  # Copyright (c) 2003-2009 by University of Queensland
5  # Earth Systems Science Computational Center (ESSCC)  # Earth Systems Science Computational Center (ESSCC)
6  # http://www.uq.edu.au/esscc  # http://www.uq.edu.au/esscc
7  #  #
# Line 11  Line 11 
11  #  #
12  ########################################################  ########################################################
13    
14  __copyright__="""Copyright (c) 2003-2008 by University of Queensland  __copyright__="""Copyright (c) 2003-2009 by University of Queensland
15  Earth Systems Science Computational Center (ESSCC)  Earth Systems Science Computational Center (ESSCC)
16  http://www.uq.edu.au/esscc  http://www.uq.edu.au/esscc
17  Primary Business: Queensland, Australia"""  Primary Business: Queensland, Australia"""
18  __license__="""Licensed under the Open Software License version 3.0  __license__="""Licensed under the Open Software License version 3.0
19  http://www.opensource.org/licenses/osl-3.0.php"""  http://www.opensource.org/licenses/osl-3.0.php"""
20  __url__="http://www.uq.edu.au/esscc/escript-finley"  __url__="https://launchpad.net/escript-finley"
21    
22  """  """
23  Provides some tools related to PDEs.  Provides some tools related to PDEs.
24    
25  Currently includes:  Currently includes:
26      - Projector - to project a discontinuous      - Projector - to project a discontinuous function onto a continuous function
27      - Locator - to trace values in data objects at a certain location      - Locator - to trace values in data objects at a certain location
28      - TimeIntegrationManager - to handel extraplotion in time      - TimeIntegrationManager - to handle extrapolation in time
29      - SaddlePointProblem - solver for Saddle point problems using the inexact uszawa scheme      - SaddlePointProblem - solver for Saddle point problems using the inexact uszawa scheme
30    
31  @var __author__: name of author  :var __author__: name of author
32  @var __copyright__: copyrights  :var __copyright__: copyrights
33  @var __license__: licence agreement  :var __license__: licence agreement
34  @var __url__: url entry point on documentation  :var __url__: url entry point on documentation
35  @var __version__: version  :var __version__: version
36  @var __date__: date of the version  :var __date__: date of the version
37  """  """
38    
39  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
# Line 41  __author__="Lutz Gross, l.gross@uq.edu.a Line 41  __author__="Lutz Gross, l.gross@uq.edu.a
41    
42  import escript  import escript
43  import linearPDEs  import linearPDEs
44  import numarray  import numpy
45  import util  import util
46  import math  import math
47    
 ##### Added by Artak  
 # from Numeric import zeros,Int,Float64  
 ###################################  
   
   
48  class TimeIntegrationManager:  class TimeIntegrationManager:
49    """    """
50    a simple mechanism to manage time dependend values.    A simple mechanism to manage time dependend values.
51    
52    typical usage is::    Typical usage is::
53    
54       dt=0.1 # time increment       dt=0.1 # time increment
55       tm=TimeIntegrationManager(inital_value,p=1)       tm=TimeIntegrationManager(inital_value,p=1)
# Line 64  class TimeIntegrationManager: Line 59  class TimeIntegrationManager:
59           tm.checkin(dt,v)           tm.checkin(dt,v)
60           t+=dt           t+=dt
61    
62    @note: currently only p=1 is supported.    :note: currently only p=1 is supported.
63    """    """
64    def __init__(self,*inital_values,**kwargs):    def __init__(self,*inital_values,**kwargs):
65       """       """
66       sets up the value manager where inital_value is the initial value and p is order used for extrapolation       Sets up the value manager where ``inital_values`` are the initial values
67         and p is the order used for extrapolation.
68       """       """
69       if kwargs.has_key("p"):       if kwargs.has_key("p"):
70              self.__p=kwargs["p"]              self.__p=kwargs["p"]
# Line 85  class TimeIntegrationManager: Line 81  class TimeIntegrationManager:
81    
82    def getTime(self):    def getTime(self):
83        return self.__t        return self.__t
84    
85    def getValue(self):    def getValue(self):
86        out=self.__v_mem[0]        out=self.__v_mem[0]
87        if len(out)==1:        if len(out)==1:
# Line 94  class TimeIntegrationManager: Line 91  class TimeIntegrationManager:
91    
92    def checkin(self,dt,*values):    def checkin(self,dt,*values):
93        """        """
94        adds new values to the manager. the p+1 last value get lost        Adds new values to the manager. The p+1 last values are lost.
95        """        """
96        o=min(self.__order+1,self.__p)        o=min(self.__order+1,self.__p)
97        self.__order=min(self.__order+1,self.__p)        self.__order=min(self.__order+1,self.__p)
# Line 111  class TimeIntegrationManager: Line 108  class TimeIntegrationManager:
108    
109    def extrapolate(self,dt):    def extrapolate(self,dt):
110        """        """
111        extrapolates to dt forward in time.        Extrapolates to ``dt`` forward in time.
112        """        """
113        if self.__order==0:        if self.__order==0:
114           out=self.__v_mem[0]           out=self.__v_mem[0]
# Line 126  class TimeIntegrationManager: Line 123  class TimeIntegrationManager:
123           return out[0]           return out[0]
124        else:        else:
125           return out           return out
126    
127    
128  class Projector:  class Projector:
129    """    """
130    The Projector is a factory which projects a discontiuous function onto a    The Projector is a factory which projects a discontinuous function onto a
131    continuous function on the a given domain.    continuous function on a given domain.
132    """    """
133    def __init__(self, domain, reduce = True, fast=True):    def __init__(self, domain, reduce=True, fast=True):
134      """      """
135      Create a continuous function space projector for a domain.      Creates a continuous function space projector for a domain.
136    
137      @param domain: Domain of the projection.      :param domain: Domain of the projection.
138      @param reduce: Flag to reduce projection order (default is True)      :param reduce: Flag to reduce projection order
139      @param fast: Flag to use a fast method based on matrix lumping (default is true)      :param fast: Flag to use a fast method based on matrix lumping
140      """      """
141      self.__pde = linearPDEs.LinearPDE(domain)      self.__pde = linearPDEs.LinearPDE(domain)
142      if fast:      if fast:
143        self.__pde.setSolverMethod(linearPDEs.LinearPDE.LUMPING)          self.__pde.getSolverOptions().setSolverMethod(linearPDEs.SolverOptions.LUMPING)
144      self.__pde.setSymmetryOn()      self.__pde.setSymmetryOn()
145      self.__pde.setReducedOrderTo(reduce)      self.__pde.setReducedOrderTo(reduce)
146      self.__pde.setValue(D = 1.)      self.__pde.setValue(D = 1.)
147      return      return
148      def getSolverOptions(self):
149        """
150        Returns the solver options of the PDE solver.
151        
152        :rtype: `linearPDEs.SolverOptions`
153        """
154        return self.__pde.getSolverOptions()
155    
156    def __call__(self, input_data):    def __call__(self, input_data):
157      """      """
158      Projects input_data onto a continuous function      Projects ``input_data`` onto a continuous function.
159    
160      @param input_data: The input_data to be projected.      :param input_data: the data to be projected
161      """      """
162      out=escript.Data(0.,input_data.getShape(),self.__pde.getFunctionSpaceForSolution())      out=escript.Data(0.,input_data.getShape(),self.__pde.getFunctionSpaceForSolution())
163      self.__pde.setValue(Y = escript.Data(), Y_reduced = escript.Data())      self.__pde.setValue(Y = escript.Data(), Y_reduced = escript.Data())
# Line 186  class Projector: Line 190  class Projector:
190    
191  class NoPDE:  class NoPDE:
192       """       """
193       solves the following problem for u:       Solves the following problem for u:
194    
195       M{kronecker[i,j]*D[j]*u[j]=Y[i]}       *kronecker[i,j]*D[j]*u[j]=Y[i]*
196    
197       with constraint       with constraint
198    
199       M{u[j]=r[j]}  where M{q[j]>0}       *u[j]=r[j]*  where *q[j]>0*
200    
201       where D, Y, r and q are given functions of rank 1.       where *D*, *Y*, *r* and *q* are given functions of rank 1.
202    
203       In the case of scalars this takes the form       In the case of scalars this takes the form
204    
205       M{D*u=Y}       *D*u=Y*
206    
207       with constraint       with constraint
208    
209       M{u=r}  where M{q>0}       *u=r* where *q>0*
210    
211       where D, Y, r and q are given scalar functions.       where *D*, *Y*, *r* and *q* are given scalar functions.
212    
213       The constraint is overwriting any other condition.       The constraint overwrites any other condition.
214    
215       @note: This class is similar to the L{linearPDEs.LinearPDE} class with A=B=C=X=0 but has the intention       :note: This class is similar to the `linearPDEs.LinearPDE` class with
216              that all input parameter are given in L{Solution} or L{ReducedSolution}. The whole              A=B=C=X=0 but has the intention that all input parameters are given
217              thing is a bit strange and I blame Robert.Woodcock@csiro.au for this.              in `Solution` or `ReducedSolution`.
218       """       """
219         # The whole thing is a bit strange and I blame Rob Woodcock (CSIRO) for
220         # this.
221       def __init__(self,domain,D=None,Y=None,q=None,r=None):       def __init__(self,domain,D=None,Y=None,q=None,r=None):
222           """           """
223           initialize the problem           Initializes the problem.
224    
225           @param domain: domain of the PDE.           :param domain: domain of the PDE
226           @type domain: L{Domain}           :type domain: `Domain`
227           @param D: coefficient of the solution.           :param D: coefficient of the solution
228           @type D: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
229           @param Y: right hand side           :param Y: right hand side
230           @type Y: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
231           @param q: location of constraints           :param q: location of constraints
232           @type q: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
233           @param r: value of solution at locations of constraints           :param r: value of solution at locations of constraints
234           @type r: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
235           """           """
236           self.__domain=domain           self.__domain=domain
237           self.__D=D           self.__D=D
# Line 234  class NoPDE: Line 240  class NoPDE:
240           self.__r=r           self.__r=r
241           self.__u=None           self.__u=None
242           self.__function_space=escript.Solution(self.__domain)           self.__function_space=escript.Solution(self.__domain)
243    
244       def setReducedOn(self):       def setReducedOn(self):
245           """           """
246           sets the L{FunctionSpace} of the solution to L{ReducedSolution}           Sets the `FunctionSpace` of the solution to `ReducedSolution`.
247           """           """
248           self.__function_space=escript.ReducedSolution(self.__domain)           self.__function_space=escript.ReducedSolution(self.__domain)
249           self.__u=None           self.__u=None
250    
251       def setReducedOff(self):       def setReducedOff(self):
252           """           """
253           sets the L{FunctionSpace} of the solution to L{Solution}           Sets the `FunctionSpace` of the solution to `Solution`.
254           """           """
255           self.__function_space=escript.Solution(self.__domain)           self.__function_space=escript.Solution(self.__domain)
256           self.__u=None           self.__u=None
257            
258       def setValue(self,D=None,Y=None,q=None,r=None):       def setValue(self,D=None,Y=None,q=None,r=None):
259           """           """
260           assigns values to the parameters.           Assigns values to the parameters.
261    
262           @param D: coefficient of the solution.           :param D: coefficient of the solution
263           @type D: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
264           @param Y: right hand side           :param Y: right hand side
265           @type Y: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
266           @param q: location of constraints           :param q: location of constraints
267           @type q: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
268           @param r: value of solution at locations of constraints           :param r: value of solution at locations of constraints
269           @type r: C{float}, C{int}, L{numarray.NumArray}, L{Data}           :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
270           """           """
271           if not D==None:           if not D==None:
272              self.__D=D              self.__D=D
# Line 276  class NoPDE: Line 283  class NoPDE:
283    
284       def getSolution(self):       def getSolution(self):
285           """           """
286           returns the solution           Returns the solution.
287            
288           @return: the solution of the problem           :return: the solution of the problem
289           @rtype: L{Data} object in the L{FunctionSpace} L{Solution} or L{ReducedSolution}.           :rtype: `Data` object in the `FunctionSpace` `Solution` or
290                     `ReducedSolution`
291           """           """
292           if self.__u==None:           if self.__u==None:
293              if self.__D==None:              if self.__D==None:
# Line 296  class NoPDE: Line 304  class NoPDE:
304                  self.__u*=(1.-q)                  self.__u*=(1.-q)
305                  if not self.__r==None: self.__u+=q*self.__r                  if not self.__r==None: self.__u+=q*self.__r
306           return self.__u           return self.__u
307                
308  class Locator:  class Locator:
309       """       """
310       Locator provides access to the values of data objects at a given       Locator provides access to the values of data objects at a given spatial
311       spatial coordinate x.         coordinate x.
312        
313       In fact, a Locator object finds the sample in the set of samples of a       In fact, a Locator object finds the sample in the set of samples of a
314       given function space or domain where which is closest to the given       given function space or domain which is closest to the given point x.
      point x.  
315       """       """
316    
317       def __init__(self,where,x=numarray.zeros((3,))):       def __init__(self,where,x=numpy.zeros((3,))):
318         """         """
319         Initializes a Locator to access values in Data objects on the Doamin         Initializes a Locator to access values in Data objects on the Doamin
320         or FunctionSpace where for the sample point which         or FunctionSpace for the sample point which is closest to the given
321         closest to the given point x.         point x.
322    
323         @param where: function space         :param where: function space
324         @type where: L{escript.FunctionSpace}         :type where: `escript.FunctionSpace`
325         @param x: coefficient of the solution.         :param x: location(s) of the Locator
326         @type x: L{numarray.NumArray} or C{list} of L{numarray.NumArray}         :type x: ``numpy.ndarray`` or ``list`` of ``numpy.ndarray``
327         """         """
328         if isinstance(where,escript.FunctionSpace):         if isinstance(where,escript.FunctionSpace):
329            self.__function_space=where            self.__function_space=where
330         else:         else:
331            self.__function_space=escript.ContinuousFunction(where)            self.__function_space=escript.ContinuousFunction(where)
332           iterative=False
333         if isinstance(x, list):         if isinstance(x, list):
334               if len(x)==0:
335                  raise "ValueError", "At least one point must be given."
336               try:
337                 iter(x[0])
338                 iterative=True
339               except TypeError:
340                 iterative=False
341           if iterative:
342             self.__id=[]             self.__id=[]
343             for p in x:             for p in x:
344                self.__id.append(util.length(self.__function_space.getX()-p[:self.__function_space.getDim()]).minGlobalDataPoint())                self.__id.append(util.length(self.__function_space.getX()-p[:self.__function_space.getDim()]).minGlobalDataPoint())
# Line 334  class Locator: Line 350  class Locator:
350         Returns the coordinates of the Locator as a string.         Returns the coordinates of the Locator as a string.
351         """         """
352         x=self.getX()         x=self.getX()
353         if instance(x,list):         if isinstance(x,list):
354            out="["            out="["
355            first=True            first=True
356            for xx in x:            for xx in x:
# Line 350  class Locator: Line 366  class Locator:
366    
367       def getX(self):       def getX(self):
368          """          """
369      Returns the exact coordinates of the Locator.          Returns the exact coordinates of the Locator.
370      """          """
371          return self(self.getFunctionSpace().getX())          return self(self.getFunctionSpace().getX())
372    
373       def getFunctionSpace(self):       def getFunctionSpace(self):
374          """          """
375      Returns the function space of the Locator.          Returns the function space of the Locator.
376      """          """
377          return self.__function_space          return self.__function_space
378    
379       def getId(self,item=None):       def getId(self,item=None):
380          """          """
381      Returns the identifier of the location.          Returns the identifier of the location.
382      """          """
383          if item == None:          if item == None:
384             return self.__id             return self.__id
385          else:          else:
# Line 375  class Locator: Line 391  class Locator:
391    
392       def __call__(self,data):       def __call__(self,data):
393          """          """
394      Returns the value of data at the Locator of a Data object otherwise          Returns the value of data at the Locator of a Data object.
395      the object is returned.          """
     """  
396          return self.getValue(data)          return self.getValue(data)
397    
398       def getValue(self,data):       def getValue(self,data):
399          """          """
400      Returns the value of data at the Locator if data is a Data object          Returns the value of ``data`` at the Locator if ``data`` is a `Data`
401      otherwise the object is returned.          object otherwise the object is returned.
402      """          """
403          if isinstance(data,escript.Data):          if isinstance(data,escript.Data):
404             if data.getFunctionSpace()==self.getFunctionSpace():             dat=util.interpolate(data,self.getFunctionSpace())
              dat=data  
            else:  
              dat=data.interpolate(self.getFunctionSpace())  
405             id=self.getId()             id=self.getId()
406             r=data.getRank()             r=data.getRank()
407             if isinstance(id,list):             if isinstance(id,list):
408                 out=[]                 out=[]
409                 for i in id:                 for i in id:
410                    o=data.getValueOfGlobalDataPoint(*i)                    o=numpy.array(dat.getTupleForGlobalDataPoint(*i))
411                    if data.getRank()==0:                    if data.getRank()==0:
412                       out.append(o[0])                       out.append(o[0])
413                    else:                    else:
414                       out.append(o)                       out.append(o)
415                 return out                 return out
416             else:             else:
417               out=data.getValueOfGlobalDataPoint(*id)               out=numpy.array(dat.getTupleForGlobalDataPoint(*id))
418               if data.getRank()==0:               if data.getRank()==0:
419                  return out[0]                  return out[0]
420               else:               else:
# Line 410  class Locator: Line 422  class Locator:
422          else:          else:
423             return data             return data
424    
425    
426    def getInfLocator(arg):
427        """
428        Return a Locator for a point with the inf value over all arg.
429        """
430        if not isinstance(arg, escript.Data):
431        raise TypeError,"getInfLocator: Unknown argument type."
432        a_inf=util.inf(arg)
433        loc=util.length(arg-a_inf).minGlobalDataPoint() # This gives us the location but not coords
434        x=arg.getFunctionSpace().getX()
435        x_min=x.getTupleForGlobalDataPoint(*loc)
436        return Locator(arg.getFunctionSpace(),x_min)
437    
438    def getSupLocator(arg):
439        """
440        Return a Locator for a point with the sup value over all arg.
441        """
442        if not isinstance(arg, escript.Data):
443        raise TypeError,"getInfLocator: Unknown argument type."
444        a_inf=util.sup(arg)
445        loc=util.length(arg-a_inf).minGlobalDataPoint() # This gives us the location but not coords
446        x=arg.getFunctionSpace().getX()
447        x_min=x.getTupleForGlobalDataPoint(*loc)
448        return Locator(arg.getFunctionSpace(),x_min)
449        
450    
451  class SolverSchemeException(Exception):  class SolverSchemeException(Exception):
452     """     """
453     exceptions thrown by solvers     This is a generic exception thrown by solvers.
454     """     """
455     pass     pass
456    
457  class IndefinitePreconditioner(SolverSchemeException):  class IndefinitePreconditioner(SolverSchemeException):
458     """     """
459     the preconditioner is not positive definite.     Exception thrown if the preconditioner is not positive definite.
460     """     """
461     pass     pass
462    
463  class MaxIterReached(SolverSchemeException):  class MaxIterReached(SolverSchemeException):
464     """     """
465     maxium number of iteration steps is reached.     Exception thrown if the maximum number of iteration steps is reached.
466     """     """
467     pass     pass
468  class IterationBreakDown(SolverSchemeException):  
469    class CorrectionFailed(SolverSchemeException):
470     """     """
471     iteration scheme econouters an incurable breakdown.     Exception thrown if no convergence has been achieved in the solution
472       correction scheme.
473     """     """
474     pass     pass
475  class NegativeNorm(SolverSchemeException):  
476    class IterationBreakDown(SolverSchemeException):
477     """     """
478     a norm calculation returns a negative norm.     Exception thrown if the iteration scheme encountered an incurable breakdown.
479     """     """
480     pass     pass
481    
482  class IterationHistory(object):  class NegativeNorm(SolverSchemeException):
483     """     """
484     The IterationHistory class is used to define a stopping criterium. It keeps track of the     Exception thrown if a norm calculation returns a negative norm.
    residual norms. The stoppingcriterium indicates termination if the residual norm has been reduced by  
    a given tolerance.  
485     """     """
486     def __init__(self,tolerance=math.sqrt(util.EPSILON),verbose=False):     pass
       """  
       Initialization  
   
       @param tolerance: tolerance  
       @type tolerance: positive C{float}  
       @param verbose: switches on the printing out some information  
       @type verbose: C{bool}  
       """  
       if not tolerance>0.:  
           raise ValueError,"tolerance needs to be positive."  
       self.tolerance=tolerance  
       self.verbose=verbose  
       self.history=[]  
    def stoppingcriterium(self,norm_r,r,x):  
        """  
        returns True if the C{norm_r} is C{tolerance}*C{norm_r[0]} where C{norm_r[0]}  is the residual norm at the first call.  
487    
488          def PCG(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100, initial_guess=True, verbose=False):
489         @param norm_r: current residual norm     """
490         @type norm_r: non-negative C{float}     Solver for
        @param r: current residual (not used)  
        @param x: current solution approximation (not used)  
        @return: C{True} is the stopping criterium is fullfilled. Otherwise C{False} is returned.  
        @rtype: C{bool}  
491    
492         """     *Ax=b*
        self.history.append(norm_r)  
        if self.verbose: print "iter: %s:  inner(rhat,r) = %e"#(len(self.history)-1, self.history[-1])  
        return self.history[-1]<=self.tolerance * self.history[0]  
493    
494     def stoppingcriterium2(self,norm_r,norm_b,solver="GMRES",TOL=None):     with a symmetric and positive definite operator A (more details required!).
495         """     It uses the conjugate gradient method with preconditioner M providing an
496         returns True if the C{norm_r} is C{tolerance}*C{norm_b}     approximation of A.
497    
498             The iteration is terminated if
        @param norm_r: current residual norm  
        @type norm_r: non-negative C{float}  
        @param norm_b: norm of right hand side  
        @type norm_b: non-negative C{float}  
        @return: C{True} is the stopping criterium is fullfilled. Otherwise C{False} is returned.  
        @rtype: C{bool}  
499    
500         """     *|r| <= atol+rtol*|r0|*
        if TOL==None:  
           TOL=self.tolerance  
        self.history.append(norm_r)  
        if self.verbose: print "iter: %s:  norm(r) = %e"#(len(self.history)-1, self.history[-1])  
        return self.history[-1]<=TOL * norm_b  
   
 def PCG(b, Aprod, Msolve, bilinearform, stoppingcriterium, x=None, iter_max=100):  
    """  
    Solver for  
501    
502     M{Ax=b}     where *r0* is the initial residual and *|.|* is the energy norm. In fact
503    
504     with a symmetric and positive definite operator A (more details required!).     *|r| = sqrt( bilinearform(Msolve(r),r))*
    It uses the conjugate gradient method with preconditioner M providing an approximation of A.  
   
    The iteration is terminated if the C{stoppingcriterium} function return C{True}.  
505    
506     For details on the preconditioned conjugate gradient method see the book:     For details on the preconditioned conjugate gradient method see the book:
507    
508     Templates for the Solution of Linear Systems by R. Barrett, M. Berry,     I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
509     T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,     T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
510     C. Romine, and H. van der Vorst.     C. Romine, and H. van der Vorst}.
511    
512     @param b: the right hand side of the liner system. C{b} is altered.     :param r: initial residual *r=b-Ax*. ``r`` is altered.
513     @type b: any object supporting inplace add (x+=y) and scaling (x=scalar*y)     :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
514     @param Aprod: returns the value Ax     :param x: an initial guess for the solution
515     @type Aprod: function C{Aprod(x)} where C{x} is of the same object like argument C{x}. The returned object needs to be of the same type like argument C{b}.     :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
516     @param Msolve: solves Mx=r     :param Aprod: returns the value Ax
517     @type Msolve: function C{Msolve(r)} where C{r} is of the same type like argument C{b}. The returned object needs to be of the same     :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
518  type like argument C{x}.                  argument ``x``. The returned object needs to be of the same type
519     @param bilinearform: inner product C{<x,r>}                  like argument ``r``.
520     @type bilinearform: function C{bilinearform(x,r)} where C{x} is of the same type like argument C{x} and C{r} is . The returned value is a C{float}.     :param Msolve: solves Mx=r
521     @param stoppingcriterium: function which returns True if a stopping criterium is meet. C{stoppingcriterium} has the arguments C{norm_r}, C{r} and C{x} giving the current norm of the residual (=C{sqrt(bilinearform(Msolve(r),r)}), the current residual and the current solution approximation. C{stoppingcriterium} is called in each iteration step.     :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
522     @type stoppingcriterium: function that returns C{True} or C{False}                   argument ``r``. The returned object needs to be of the same
523     @param x: an initial guess for the solution. If no C{x} is given 0*b is used.                   type like argument ``x``.
524     @type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)     :param bilinearform: inner product ``<x,r>``
525     @param iter_max: maximum number of iteration steps.     :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
526     @type iter_max: C{int}                         type like argument ``x`` and ``r`` is. The returned value
527     @return: the solution approximation and the corresponding residual                         is a ``float``.
528     @rtype: C{tuple}     :param atol: absolute tolerance
529     @warning: C{b} and C{x} are altered.     :type atol: non-negative ``float``
530       :param rtol: relative tolerance
531       :type rtol: non-negative ``float``
532       :param iter_max: maximum number of iteration steps
533       :type iter_max: ``int``
534       :return: the solution approximation and the corresponding residual
535       :rtype: ``tuple``
536       :warning: ``r`` and ``x`` are altered.
537     """     """
538     iter=0     iter=0
    if x==None:  
       x=0*b  
    else:  
       b += (-1)*Aprod(x)  
    r=b  
539     rhat=Msolve(r)     rhat=Msolve(r)
540     d = rhat     d = rhat
541     rhat_dot_r = bilinearform(rhat, r)     rhat_dot_r = bilinearform(rhat, r)
542     if rhat_dot_r<0: raise NegativeNorm,"negative norm."     if rhat_dot_r<0: raise NegativeNorm,"negative norm."
543       norm_r0=math.sqrt(rhat_dot_r)
544       atol2=atol+rtol*norm_r0
545       if atol2<=0:
546          raise ValueError,"Non-positive tolarance."
547       atol2=max(atol2, 100. * util.EPSILON * norm_r0)
548    
549       if verbose: print "PCG: initial residual norm = %e (absolute tolerance = %e)"%(norm_r0, atol2)
550    
551     while not stoppingcriterium(math.sqrt(rhat_dot_r),r,x):  
552       while not math.sqrt(rhat_dot_r) <= atol2:
553         iter+=1         iter+=1
554         if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max         if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
555    
556         q=Aprod(d)         q=Aprod(d)
557         alpha = rhat_dot_r / bilinearform(d, q)         alpha = rhat_dot_r / bilinearform(d, q)
558         x += alpha * d         x += alpha * d
559         r += (-alpha) * q         if isinstance(q,ArithmeticTuple):
560           r += q * (-alpha)      # Doing it the other way calls the float64.__mul__ not AT.__rmul__
561           else:
562               r += (-alpha) * q
563         rhat=Msolve(r)         rhat=Msolve(r)
564         rhat_dot_r_new = bilinearform(rhat, r)         rhat_dot_r_new = bilinearform(rhat, r)
565         beta = rhat_dot_r_new / rhat_dot_r         beta = rhat_dot_r_new / rhat_dot_r
# Line 557  type like argument C{x}. Line 568  type like argument C{x}.
568    
569         rhat_dot_r = rhat_dot_r_new         rhat_dot_r = rhat_dot_r_new
570         if rhat_dot_r<0: raise NegativeNorm,"negative norm."         if rhat_dot_r<0: raise NegativeNorm,"negative norm."
571           if verbose: print "PCG: iteration step %s: residual norm = %e"%(iter, math.sqrt(rhat_dot_r))
572     return x,r     if verbose: print "PCG: tolerance reached after %s steps."%iter
573       return x,r,math.sqrt(rhat_dot_r)
574    
575  class Defect(object):  class Defect(object):
576      """      """
577      defines a non-linear defect F(x) of a variable x      Defines a non-linear defect F(x) of a variable x.
578      """      """
579      def __init__(self):      def __init__(self):
580          """          """
581          initialize defect          Initializes defect.
582          """          """
583          self.setDerivativeIncrementLength()          self.setDerivativeIncrementLength()
584    
585      def bilinearform(self, x0, x1):      def bilinearform(self, x0, x1):
586          """          """
587          returns the inner product of x0 and x1          Returns the inner product of x0 and x1
588          @param x0: a value for x  
589          @param x1: a value for x          :param x0: value for x0
590          @return: the inner product of x0 and x1          :param x1: value for x1
591          @rtype: C{float}          :return: the inner product of x0 and x1
592            :rtype: ``float``
593          """          """
594          return 0          return 0
595          
596      def norm(self,x):      def norm(self,x):
597          """          """
598          the norm of argument C{x}          Returns the norm of argument ``x``.
599    
600          @param x: a value for x          :param x: a value
601          @return: norm of argument x          :return: norm of argument x
602          @rtype: C{float}          :rtype: ``float``
603          @note: by default C{sqrt(self.bilinearform(x,x)} is retrurned.          :note: by default ``sqrt(self.bilinearform(x,x)`` is returned.
604          """          """
605          s=self.bilinearform(x,x)          s=self.bilinearform(x,x)
606          if s<0: raise NegativeNorm,"negative norm."          if s<0: raise NegativeNorm,"negative norm."
607          return math.sqrt(s)          return math.sqrt(s)
608    
   
609      def eval(self,x):      def eval(self,x):
610          """          """
611          returns the value F of a given x          Returns the value F of a given ``x``.
612    
613          @param x: value for which the defect C{F} is evalulated.          :param x: value for which the defect ``F`` is evaluated
614          @return: value of the defect at C{x}          :return: value of the defect at ``x``
615          """          """
616          return 0          return 0
617    
618      def __call__(self,x):      def __call__(self,x):
619          return self.eval(x)          return self.eval(x)
620    
621      def setDerivativeIncrementLength(self,inc=math.sqrt(util.EPSILON)):      def setDerivativeIncrementLength(self,inc=1000.*math.sqrt(util.EPSILON)):
622          """          """
623          sets the relative length of the increment used to approximate the derivative of the defect          Sets the relative length of the increment used to approximate the
624          the increment is inc*norm(x)/norm(v)*v in the direction of v with x as a staring point.          derivative of the defect. The increment is inc*norm(x)/norm(v)*v in the
625            direction of v with x as a starting point.
626    
627          @param inc: relative increment length          :param inc: relative increment length
628          @type inc: positive C{float}          :type inc: positive ``float``
629          """          """
630          if inc<=0: raise ValueError,"positive increment required."          if inc<=0: raise ValueError,"positive increment required."
631          self.__inc=inc          self.__inc=inc
632    
633      def getDerivativeIncrementLength(self):      def getDerivativeIncrementLength(self):
634          """          """
635          returns the relative increment length used to approximate the derivative of the defect          Returns the relative increment length used to approximate the
636          @return: value of the defect at C{x}          derivative of the defect.
637          @rtype: positive C{float}          :return: value of the defect at ``x``
638            :rtype: positive ``float``
639          """          """
640          return self.__inc          return self.__inc
641    
642      def derivative(self, F0, x0, v, v_is_normalised=True):      def derivative(self, F0, x0, v, v_is_normalised=True):
643          """          """
644          returns the directional derivative at x0 in the direction of v          Returns the directional derivative at ``x0`` in the direction of ``v``.
645    
646          @param F0: value of this defect at x0          :param F0: value of this defect at x0
647          @param x0: value at which derivative is calculated.          :param x0: value at which derivative is calculated
648          @param v: direction          :param v: direction
649          @param v_is_normalised: is true to indicate that C{v} is nomalized (self.norm(v)=0)          :param v_is_normalised: True to indicate that ``v`` is nomalized
650          @return: derivative of this defect at x0 in the direction of C{v}                                  (self.norm(v)=0)
651          @note: by default numerical evaluation (self.eval(x0+eps*v)-F0)/eps is used but this method          :return: derivative of this defect at x0 in the direction of ``v``
652          maybe oepsnew verwritten to use exact evalution.          :note: by default numerical evaluation (self.eval(x0+eps*v)-F0)/eps is
653                   used but this method maybe overwritten to use exact evaluation.
654          """          """
655          normx=self.norm(x0)          normx=self.norm(x0)
656          if normx>0:          if normx>0:
# Line 651  class Defect(object): Line 666  class Defect(object):
666          F1=self.eval(x0 + epsnew * v)          F1=self.eval(x0 + epsnew * v)
667          return (F1-F0)/epsnew          return (F1-F0)/epsnew
668    
669  ######################################      ######################################
670  def NewtonGMRES(defect, x, iter_max=100, sub_iter_max=20, atol=0,rtol=1.e-4, sub_tol_max=0.5, gamma=0.9, verbose=False):  def NewtonGMRES(defect, x, iter_max=100, sub_iter_max=20, atol=0,rtol=1.e-4, subtol_max=0.5, gamma=0.9, verbose=False):
671     """     """
672     solves a non-linear problem M{F(x)=0} for unknown M{x} using the stopping criterion:     Solves a non-linear problem *F(x)=0* for unknown *x* using the stopping
673       criterion:
674    
675       *norm(F(x) <= atol + rtol * norm(F(x0)*
676    
677     M{norm(F(x) <= atol + rtol * norm(F(x0)}     where *x0* is the initial guess.
678      
679     where M{x0} is the initial guess.     :param defect: object defining the function *F*. ``defect.norm`` defines the
680                      *norm* used in the stopping criterion.
681     @param defect: object defining the the function M{F}, C{defect.norm} defines the M{norm} used in the stopping criterion.     :type defect: `Defect`
682     @type defect: L{Defect}     :param x: initial guess for the solution, ``x`` is altered.
683     @param x: initial guess for the solution, C{x} is altered.     :type x: any object type allowing basic operations such as
684     @type x: any object type allowing basic operations such as  L{numarray.NumArray}, L{Data}              ``numpy.ndarray``, `Data`
685     @param iter_max: maximum number of iteration steps     :param iter_max: maximum number of iteration steps
686     @type iter_max: positive C{int}     :type iter_max: positive ``int``
687     @param sub_iter_max:     :param sub_iter_max: maximum number of inner iteration steps
688     @type sub_iter_max:     :type sub_iter_max: positive ``int``
689     @param atol: absolute tolerance for the solution     :param atol: absolute tolerance for the solution
690     @type atol: positive C{float}     :type atol: positive ``float``
691     @param rtol: relative tolerance for the solution     :param rtol: relative tolerance for the solution
692     @type rtol: positive C{float}     :type rtol: positive ``float``
693     @param gamma: tolerance safety factor for inner iteration     :param gamma: tolerance safety factor for inner iteration
694     @type gamma: positive C{float}, less than 1     :type gamma: positive ``float``, less than 1
695     @param sub_tol_max: upper bound for inner tolerance.     :param subtol_max: upper bound for inner tolerance
696     @type sub_tol_max: positive C{float}, less than 1     :type subtol_max: positive ``float``, less than 1
697     @return: an approximation of the solution with the desired accuracy     :return: an approximation of the solution with the desired accuracy
698     @rtype: same type as the initial guess.     :rtype: same type as the initial guess
699     """     """
700     lmaxit=iter_max     lmaxit=iter_max
701     if atol<0: raise ValueError,"atol needs to be non-negative."     if atol<0: raise ValueError,"atol needs to be non-negative."
702     if rtol<0: raise ValueError,"rtol needs to be non-negative."     if rtol<0: raise ValueError,"rtol needs to be non-negative."
703     if rtol+atol<=0: raise ValueError,"rtol or atol needs to be non-negative."     if rtol+atol<=0: raise ValueError,"rtol or atol needs to be non-negative."
704     if gamma<=0 or gamma>=1: raise ValueError,"tolerance safety factor for inner iteration (gamma =%s) needs to be positive and less than 1."%gamma     if gamma<=0 or gamma>=1: raise ValueError,"tolerance safety factor for inner iteration (gamma =%s) needs to be positive and less than 1."%gamma
705     if sub_tol_max<=0 or sub_tol_max>=1: raise ValueError,"upper bound for inner tolerance for inner iteration (sub_tol_max =%s) needs to be positive and less than 1."%sub_tol_max     if subtol_max<=0 or subtol_max>=1: raise ValueError,"upper bound for inner tolerance for inner iteration (subtol_max =%s) needs to be positive and less than 1."%subtol_max
706    
707     F=defect(x)     F=defect(x)
708     fnrm=defect.norm(F)     fnrm=defect.norm(F)
709     stop_tol=atol + rtol*fnrm     stop_tol=atol + rtol*fnrm
710     sub_tol=sub_tol_max     subtol=subtol_max
711     if verbose: print "NewtonGMRES: initial residual = %e."%fnrm     if verbose: print "NewtonGMRES: initial residual = %e."%fnrm
712     if verbose: print "             tolerance = %e."%sub_tol     if verbose: print "             tolerance = %e."%subtol
713     iter=1     iter=1
714     #     #
715     # main iteration loop     # main iteration loop
716     #     #
717     while not fnrm<=stop_tol:     while not fnrm<=stop_tol:
718              if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max              if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
719              #              #
720          #   adjust sub_tol_          #   adjust subtol_
721          #          #
722              if iter > 1:              if iter > 1:
723             rat=fnrm/fnrmo             rat=fnrm/fnrmo
724                 sub_tol_old=sub_tol                 subtol_old=subtol
725             sub_tol=gamma*rat**2             subtol=gamma*rat**2
726             if gamma*sub_tol_old**2 > .1: sub_tol=max(sub_tol,gamma*sub_tol_old**2)             if gamma*subtol_old**2 > .1: subtol=max(subtol,gamma*subtol_old**2)
727             sub_tol=max(min(sub_tol,sub_tol_max), .5*stop_tol/fnrm)             subtol=max(min(subtol,subtol_max), .5*stop_tol/fnrm)
728          #          #
729          # calculate newton increment xc          # calculate newton increment xc
730              #     if iter_max in __FDGMRES is reached MaxIterReached is thrown              #     if iter_max in __FDGMRES is reached MaxIterReached is thrown
731              #     if iter_restart -1 is returned as sub_iter              #     if iter_restart -1 is returned as sub_iter
732              #     if  atol is reached sub_iter returns the numer of steps performed to get there              #     if  atol is reached sub_iter returns the numer of steps performed to get there
733              #              #
734              #                #
735              if verbose: print "             subiteration (GMRES) is called with relative tolerance %e."%sub_tol              if verbose: print "             subiteration (GMRES) is called with relative tolerance %e."%subtol
736              try:              try:
737                 xc, sub_iter=__FDGMRES(F, defect, x, sub_tol*fnrm, iter_max=iter_max-iter, iter_restart=sub_iter_max)                 xc, sub_iter=__FDGMRES(F, defect, x, subtol*fnrm, iter_max=iter_max-iter, iter_restart=sub_iter_max)
738              except MaxIterReached:              except MaxIterReached:
739                 raise MaxIterReached,"maximum number of %s steps reached."%iter_max                 raise MaxIterReached,"maximum number of %s steps reached."%iter_max
740              if sub_iter<0:              if sub_iter<0:
# Line 734  def NewtonGMRES(defect, x, iter_max=100, Line 752  def NewtonGMRES(defect, x, iter_max=100,
752    
753  def __givapp(c,s,vin):  def __givapp(c,s,vin):
754      """      """
755      apply a sequence of Givens rotations (c,s) to the recuirsively to the vector vin      Applies a sequence of Givens rotations (c,s) recursively to the vector
756      @warning: C{vin} is altered.      ``vin``
757    
758        :warning: ``vin`` is altered.
759      """      """
760      vrot=vin      vrot=vin
761      if isinstance(c,float):      if isinstance(c,float):
762          vrot=[c*vrot[0]-s*vrot[1],s*vrot[0]+c*vrot[1]]          vrot=[c*vrot[0]-s*vrot[1],s*vrot[0]+c*vrot[1]]
763      else:      else:
764          for i in range(len(c)):          for i in range(len(c)):
765              w1=c[i]*vrot[i]-s[i]*vrot[i+1]              w1=c[i]*vrot[i]-s[i]*vrot[i+1]
766          w2=s[i]*vrot[i]+c[i]*vrot[i+1]          w2=s[i]*vrot[i]+c[i]*vrot[i+1]
767              vrot[i:i+2]=w1,w2              vrot[i]=w1
768                vrot[i+1]=w2
769      return vrot      return vrot
770    
771  def __FDGMRES(F0, defect, x0, atol, iter_max=100, iter_restart=20):  def __FDGMRES(F0, defect, x0, atol, iter_max=100, iter_restart=20):
772     h=numarray.zeros((iter_restart,iter_restart),numarray.Float64)     h=numpy.zeros((iter_restart,iter_restart),numpy.float64)
773     c=numarray.zeros(iter_restart,numarray.Float64)     c=numpy.zeros(iter_restart,numpy.float64)
774     s=numarray.zeros(iter_restart,numarray.Float64)     s=numpy.zeros(iter_restart,numpy.float64)
775     g=numarray.zeros(iter_restart,numarray.Float64)     g=numpy.zeros(iter_restart,numpy.float64)
776     v=[]     v=[]
777    
778     rho=defect.norm(F0)     rho=defect.norm(F0)
779     if rho<=0.: return x0*0     if rho<=0.: return x0*0
780      
781     v.append(-F0/rho)     v.append(-F0/rho)
782     g[0]=rho     g[0]=rho
783     iter=0     iter=0
784     while rho > atol and iter<iter_restart-1:     while rho > atol and iter<iter_restart-1:
785            if iter  >= iter_max:
786      if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max              raise MaxIterReached,"maximum number of %s steps reached."%iter_max
787    
788          p=defect.derivative(F0,x0,v[iter], v_is_normalised=True)          p=defect.derivative(F0,x0,v[iter], v_is_normalised=True)
789      v.append(p)          v.append(p)
790    
791      v_norm1=defect.norm(v[iter+1])          v_norm1=defect.norm(v[iter+1])
792    
793          # Modified Gram-Schmidt          # Modified Gram-Schmidt
794      for j in range(iter+1):          for j in range(iter+1):
795           h[j][iter]=defect.bilinearform(v[j],v[iter+1])                h[j,iter]=defect.bilinearform(v[j],v[iter+1])
796           v[iter+1]-=h[j][iter]*v[j]              v[iter+1]-=h[j,iter]*v[j]
797          
798      h[iter+1][iter]=defect.norm(v[iter+1])          h[iter+1,iter]=defect.norm(v[iter+1])
799      v_norm2=h[iter+1][iter]          v_norm2=h[iter+1,iter]
800    
801          # Reorthogonalize if needed          # Reorthogonalize if needed
802      if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)          if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
803          for j in range(iter+1):                for j in range(iter+1):
804             hr=defect.bilinearform(v[j],v[iter+1])                  hr=defect.bilinearform(v[j],v[iter+1])
805                 h[j][iter]=h[j][iter]+hr                  h[j,iter]=h[j,iter]+hr
806                 v[iter+1] -= hr*v[j]                  v[iter+1] -= hr*v[j]
807    
808          v_norm2=defect.norm(v[iter+1])              v_norm2=defect.norm(v[iter+1])
809          h[iter+1][iter]=v_norm2              h[iter+1,iter]=v_norm2
810          #   watch out for happy breakdown          #   watch out for happy breakdown
811          if not v_norm2 == 0:          if not v_norm2 == 0:
812                  v[iter+1]=v[iter+1]/h[iter+1][iter]              v[iter+1]=v[iter+1]/h[iter+1,iter]
813    
814          #   Form and store the information for the new Givens rotation          #   Form and store the information for the new Givens rotation
815      if iter > 0 :          if iter > 0 :
816          hhat=numarray.zeros(iter+1,numarray.Float64)              hhat=numpy.zeros(iter+1,numpy.float64)
817          for i in range(iter+1) : hhat[i]=h[i][iter]              for i in range(iter+1) : hhat[i]=h[i,iter]
818          hhat=__givapp(c[0:iter],s[0:iter],hhat);              hhat=__givapp(c[0:iter],s[0:iter],hhat);
819              for i in range(iter+1) : h[i][iter]=hhat[i]              for i in range(iter+1) : h[i,iter]=hhat[i]
820    
821      mu=math.sqrt(h[iter][iter]*h[iter][iter]+h[iter+1][iter]*h[iter+1][iter])          mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
822    
823      if mu!=0 :          if mu!=0 :
824          c[iter]=h[iter][iter]/mu              c[iter]=h[iter,iter]/mu
825          s[iter]=-h[iter+1][iter]/mu              s[iter]=-h[iter+1,iter]/mu
826          h[iter][iter]=c[iter]*h[iter][iter]-s[iter]*h[iter+1][iter]              h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
827          h[iter+1][iter]=0.0              h[iter+1,iter]=0.0
828          g[iter:iter+2]=__givapp(c[iter],s[iter],g[iter:iter+2])              gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
829                g[iter]=gg[0]
830                g[iter+1]=gg[1]
831    
832          # Update the residual norm          # Update the residual norm
833          rho=abs(g[iter+1])          rho=abs(g[iter+1])
834      iter+=1          iter+=1
835    
836     # At this point either iter > iter_max or rho < tol.     # At this point either iter > iter_max or rho < tol.
837     # It's time to compute x and leave.             # It's time to compute x and leave.
838     if iter > 0 :     if iter > 0 :
839       y=numarray.zeros(iter,numarray.Float64)           y=numpy.zeros(iter,numpy.float64)
840       y[iter-1] = g[iter-1] / h[iter-1][iter-1]       y[iter-1] = g[iter-1] / h[iter-1,iter-1]
841       if iter > 1 :         if iter > 1 :
842          i=iter-2            i=iter-2
843          while i>=0 :          while i>=0 :
844            y[i] = ( g[i] - numarray.dot(h[i][i+1:iter], y[i+1:iter])) / h[i][i]            y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
845            i=i-1            i=i-1
846       xhat=v[iter-1]*y[iter-1]       xhat=v[iter-1]*y[iter-1]
847       for i in range(iter-1):       for i in range(iter-1):
848      xhat += v[i]*y[i]      xhat += v[i]*y[i]
849     else :     else :
850        xhat=v[0] * 0        xhat=v[0] * 0
851    
852     if iter<iter_restart-1:     if iter<iter_restart-1:
853        stopped=iter        stopped=iter
854     else:     else:
855        stopped=-1        stopped=-1
856    
857     return xhat,stopped     return xhat,stopped
858    
859  ##############################################  def GMRES(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100, iter_restart=20, verbose=False,P_R=None):
860  def GMRES(b, Aprod, Msolve, bilinearform, stoppingcriterium, x=None, iter_max=100, iter_restart=20):     """
861  ################################################     Solver for
862    
863       *Ax=b*
864    
865       with a general operator A (more details required!).
866       It uses the generalized minimum residual method (GMRES).
867    
868       The iteration is terminated if
869    
870       *|r| <= atol+rtol*|r0|*
871    
872       where *r0* is the initial residual and *|.|* is the energy norm. In fact
873    
874       *|r| = sqrt( bilinearform(r,r))*
875    
876       :param r: initial residual *r=b-Ax*. ``r`` is altered.
877       :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
878       :param x: an initial guess for the solution
879       :type x: same like ``r``
880       :param Aprod: returns the value Ax
881       :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
882                    argument ``x``. The returned object needs to be of the same
883                    type like argument ``r``.
884       :param bilinearform: inner product ``<x,r>``
885       :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
886                           type like argument ``x`` and ``r``. The returned value is
887                           a ``float``.
888       :param atol: absolute tolerance
889       :type atol: non-negative ``float``
890       :param rtol: relative tolerance
891       :type rtol: non-negative ``float``
892       :param iter_max: maximum number of iteration steps
893       :type iter_max: ``int``
894       :param iter_restart: in order to save memory the orthogonalization process
895                            is terminated after ``iter_restart`` steps and the
896                            iteration is restarted.
897       :type iter_restart: ``int``
898       :return: the solution approximation and the corresponding residual
899       :rtype: ``tuple``
900       :warning: ``r`` and ``x`` are altered.
901       """
902     m=iter_restart     m=iter_restart
903       restarted=False
904     iter=0     iter=0
905     xc=x     if rtol>0:
906          r_dot_r = bilinearform(r, r)
907          if r_dot_r<0: raise NegativeNorm,"negative norm."
908          atol2=atol+rtol*math.sqrt(r_dot_r)
909          if verbose: print "GMRES: norm of right hand side = %e (absolute tolerance = %e)"%(math.sqrt(r_dot_r), atol2)
910       else:
911          atol2=atol
912          if verbose: print "GMRES: absolute tolerance = %e"%atol2
913       if atol2<=0:
914          raise ValueError,"Non-positive tolarance."
915    
916     while True:     while True:
917        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached"%iter_max        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached"%iter_max
918        xc,stopped=__GMRESm(b*1, Aprod, Msolve, bilinearform, stoppingcriterium, x=xc*1, iter_max=iter_max-iter, iter_restart=m)        if restarted:
919             r2 = r-Aprod(x-x2)
920          else:
921             r2=1*r
922          x2=x*1.
923          x,stopped=_GMRESm(r2, Aprod, x, bilinearform, atol2, iter_max=iter_max-iter, iter_restart=m, verbose=verbose,P_R=P_R)
924          iter+=iter_restart
925        if stopped: break        if stopped: break
926        iter+=iter_restart            if verbose: print "GMRES: restart."
927     return xc        restarted=True
928       if verbose: print "GMRES: tolerance has been reached."
929       return x
930    
931  def __GMRESm(b, Aprod, Msolve, bilinearform, stoppingcriterium, x=None, iter_max=100, iter_restart=20):  def _GMRESm(r, Aprod, x, bilinearform, atol, iter_max=100, iter_restart=20, verbose=False, P_R=None):
932     iter=0     iter=0
    r=Msolve(b)  
    r_dot_r = bilinearform(r, r)  
    if r_dot_r<0: raise NegativeNorm,"negative norm."  
    norm_b=math.sqrt(r_dot_r)  
933    
934     if x==None:     h=numpy.zeros((iter_restart+1,iter_restart),numpy.float64)
935        x=0*b     c=numpy.zeros(iter_restart,numpy.float64)
936     else:     s=numpy.zeros(iter_restart,numpy.float64)
937        r=Msolve(b-Aprod(x))     g=numpy.zeros(iter_restart+1,numpy.float64)
       r_dot_r = bilinearform(r, r)  
       if r_dot_r<0: raise NegativeNorm,"negative norm."  
     
    h=numarray.zeros((iter_restart,iter_restart),numarray.Float64)  
    c=numarray.zeros(iter_restart,numarray.Float64)  
    s=numarray.zeros(iter_restart,numarray.Float64)  
    g=numarray.zeros(iter_restart,numarray.Float64)  
938     v=[]     v=[]
939    
940       r_dot_r = bilinearform(r, r)
941       if r_dot_r<0: raise NegativeNorm,"negative norm."
942     rho=math.sqrt(r_dot_r)     rho=math.sqrt(r_dot_r)
943      
944     v.append(r/rho)     v.append(r/rho)
945     g[0]=rho     g[0]=rho
946    
947     while not (stoppingcriterium(rho,norm_b) or iter==iter_restart-1):     if verbose: print "GMRES: initial residual %e (absolute tolerance = %e)"%(rho,atol)
948       while not (rho<=atol or iter==iter_restart):
949    
950      if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max      if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
951    
952      p=Msolve(Aprod(v[iter]))          if P_R!=None:
953                p=Aprod(P_R(v[iter]))
954            else:
955            p=Aprod(v[iter])
956      v.append(p)      v.append(p)
957    
958      v_norm1=math.sqrt(bilinearform(v[iter+1], v[iter+1]))        v_norm1=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
959    
960    # Modified Gram-Schmidt
961        for j in range(iter+1):
962          h[j,iter]=bilinearform(v[j],v[iter+1])
963          v[iter+1]-=h[j,iter]*v[j]
964    
965  # Modified Gram-Schmidt      h[iter+1,iter]=math.sqrt(bilinearform(v[iter+1],v[iter+1]))
966      for j in range(iter+1):      v_norm2=h[iter+1,iter]
       h[j][iter]=bilinearform(v[j],v[iter+1])    
       v[iter+1]-=h[j][iter]*v[j]  
         
     h[iter+1][iter]=math.sqrt(bilinearform(v[iter+1],v[iter+1]))  
     v_norm2=h[iter+1][iter]  
967    
968  # Reorthogonalize if needed  # Reorthogonalize if needed
969      if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)      if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
970       for j in range(iter+1):         for j in range(iter+1):
971          hr=bilinearform(v[j],v[iter+1])          hr=bilinearform(v[j],v[iter+1])
972              h[j][iter]=h[j][iter]+hr              h[j,iter]=h[j,iter]+hr
973              v[iter+1] -= hr*v[j]              v[iter+1] -= hr*v[j]
974    
975       v_norm2=math.sqrt(bilinearform(v[iter+1], v[iter+1]))         v_norm2=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
976       h[iter+1][iter]=v_norm2       h[iter+1,iter]=v_norm2
977    
978  #   watch out for happy breakdown  #   watch out for happy breakdown
979          if not v_norm2 == 0:          if not v_norm2 == 0:
980           v[iter+1]=v[iter+1]/h[iter+1][iter]           v[iter+1]=v[iter+1]/h[iter+1,iter]
981    
982  #   Form and store the information for the new Givens rotation  #   Form and store the information for the new Givens rotation
983      if iter > 0 :      if iter > 0: h[:iter+1,iter]=__givapp(c[:iter],s[:iter],h[:iter+1,iter])
984          hhat=numarray.zeros(iter+1,numarray.Float64)      mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
         for i in range(iter+1) : hhat[i]=h[i][iter]  
         hhat=__givapp(c[0:iter],s[0:iter],hhat);  
             for i in range(iter+1) : h[i][iter]=hhat[i]  
   
     mu=math.sqrt(h[iter][iter]*h[iter][iter]+h[iter+1][iter]*h[iter+1][iter])  
985    
986      if mu!=0 :      if mu!=0 :
987          c[iter]=h[iter][iter]/mu          c[iter]=h[iter,iter]/mu
988          s[iter]=-h[iter+1][iter]/mu          s[iter]=-h[iter+1,iter]/mu
989          h[iter][iter]=c[iter]*h[iter][iter]-s[iter]*h[iter+1][iter]          h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
990          h[iter+1][iter]=0.0          h[iter+1,iter]=0.0
991          g[iter:iter+2]=__givapp(c[iter],s[iter],g[iter:iter+2])                  gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
992                    g[iter]=gg[0]
993                    g[iter+1]=gg[1]
994  # Update the residual norm  # Update the residual norm
995                  
996          rho=abs(g[iter+1])          rho=abs(g[iter+1])
997            if verbose: print "GMRES: iteration step %s: residual %e"%(iter,rho)
998      iter+=1      iter+=1
999    
1000  # At this point either iter > iter_max or rho < tol.  # At this point either iter > iter_max or rho < tol.
1001  # It's time to compute x and leave.          # It's time to compute x and leave.
1002    
1003     if iter > 0 :     if verbose: print "GMRES: iteration stopped after %s step."%iter
1004       y=numarray.zeros(iter,numarray.Float64)         if iter > 0 :
1005       y[iter-1] = g[iter-1] / h[iter-1][iter-1]       y=numpy.zeros(iter,numpy.float64)
1006       if iter > 1 :         y[iter-1] = g[iter-1] / h[iter-1,iter-1]
1007          i=iter-2         if iter > 1 :
1008            i=iter-2
1009          while i>=0 :          while i>=0 :
1010            y[i] = ( g[i] - numarray.dot(h[i][i+1:iter], y[i+1:iter])) / h[i][i]            y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
1011            i=i-1            i=i-1
1012       xhat=v[iter-1]*y[iter-1]       xhat=v[iter-1]*y[iter-1]
1013       for i in range(iter-1):       for i in range(iter-1):
1014      xhat += v[i]*y[i]      xhat += v[i]*y[i]
1015     else : xhat=v[0]     else:
1016         xhat=v[0] * 0
1017     x += xhat     if P_R!=None:
1018     if iter<iter_restart-1:        x += P_R(xhat)
1019        stopped=True     else:
1020     else:        x += xhat
1021       if iter<iter_restart-1:
1022          stopped=True
1023       else:
1024        stopped=False        stopped=False
1025    
1026     return x,stopped     return x,stopped
1027    
1028  #################################################  def MINRES(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1029  def MINRES(b, Aprod, Msolve, bilinearform, stoppingcriterium, x=None, iter_max=100):      """
1030  #################################################      Solver for
1031      #  
1032      #  minres solves the system of linear equations Ax = b      *Ax=b*
1033      #  where A is a symmetric matrix (possibly indefinite or singular)  
1034      #  and b is a given vector.      with a symmetric and positive definite operator A (more details required!).
1035      #        It uses the minimum residual method (MINRES) with preconditioner M
1036      #  "A" may be a dense or sparse matrix (preferably sparse!)      providing an approximation of A.
1037      #  or the name of a function such that  
1038      #               y = A(x)      The iteration is terminated if
1039      #  returns the product y = Ax for any given vector x.  
1040      #      *|r| <= atol+rtol*|r0|*
1041      #  "M" defines a positive-definite preconditioner M = C C'.  
1042      #  "M" may be a dense or sparse matrix (preferably sparse!)      where *r0* is the initial residual and *|.|* is the energy norm. In fact
1043      #  or the name of a function such that  
1044      #  solves the system My = x for any given vector x.      *|r| = sqrt( bilinearform(Msolve(r),r))*
1045      #  
1046      #      For details on the preconditioned conjugate gradient method see the book:
1047        
1048        I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
1049        T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
1050        C. Romine, and H. van der Vorst}.
1051    
1052        :param r: initial residual *r=b-Ax*. ``r`` is altered.
1053        :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1054        :param x: an initial guess for the solution
1055        :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1056        :param Aprod: returns the value Ax
1057        :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
1058                     argument ``x``. The returned object needs to be of the same
1059                     type like argument ``r``.
1060        :param Msolve: solves Mx=r
1061        :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
1062                      argument ``r``. The returned object needs to be of the same
1063                      type like argument ``x``.
1064        :param bilinearform: inner product ``<x,r>``
1065        :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
1066                            type like argument ``x`` and ``r`` is. The returned value
1067                            is a ``float``.
1068        :param atol: absolute tolerance
1069        :type atol: non-negative ``float``
1070        :param rtol: relative tolerance
1071        :type rtol: non-negative ``float``
1072        :param iter_max: maximum number of iteration steps
1073        :type iter_max: ``int``
1074        :return: the solution approximation and the corresponding residual
1075        :rtype: ``tuple``
1076        :warning: ``r`` and ``x`` are altered.
1077        """
1078      #------------------------------------------------------------------      #------------------------------------------------------------------
1079      # Set up y and v for the first Lanczos vector v1.      # Set up y and v for the first Lanczos vector v1.
1080      # y  =  beta1 P' v1,  where  P = C**(-1).      # y  =  beta1 P' v1,  where  P = C**(-1).
1081      # v is really P' v1.      # v is really P' v1.
1082      #------------------------------------------------------------------      #------------------------------------------------------------------
1083      if x==None:      r1    = r
1084        x=0*b      y = Msolve(r)
1085      else:      beta1 = bilinearform(y,r)
       b += (-1)*Aprod(x)  
1086    
     r1    = b  
     y = Msolve(b)  
     beta1 = bilinearform(y,b)  
   
1087      if beta1< 0: raise NegativeNorm,"negative norm."      if beta1< 0: raise NegativeNorm,"negative norm."
1088    
1089      #  If b = 0 exactly, stop with x = 0.      #  If r = 0 exactly, stop with x
1090      if beta1==0: return x*0.      if beta1==0: return x
1091    
1092      if beta1> 0:      if beta1> 0: beta1  = math.sqrt(beta1)
       beta1  = math.sqrt(beta1)        
1093    
1094      #------------------------------------------------------------------      #------------------------------------------------------------------
1095      # Initialize quantities.      # Initialize quantities.
# Line 1008  def MINRES(b, Aprod, Msolve, bilinearfor Line 1109  def MINRES(b, Aprod, Msolve, bilinearfor
1109      ynorm2 = 0      ynorm2 = 0
1110      cs     = -1      cs     = -1
1111      sn     = 0      sn     = 0
1112      w      = b*0.      w      = r*0.
1113      w2     = b*0.      w2     = r*0.
1114      r2     = r1      r2     = r1
1115      eps    = 0.0001      eps    = 0.0001
1116    
1117      #---------------------------------------------------------------------      #---------------------------------------------------------------------
1118      # Main iteration loop.      # Main iteration loop.
1119      # --------------------------------------------------------------------      # --------------------------------------------------------------------
1120      while not stoppingcriterium(rnorm,Anorm*ynorm,'MINRES'):    #  checks ||r|| < (||A|| ||x||) * TOL      while not rnorm<=atol+rtol*Anorm*ynorm:    #  checks ||r|| < (||A|| ||x||) * TOL
1121    
1122      if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max      if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1123          iter    = iter  +  1          iter    = iter  +  1
# Line 1035  def MINRES(b, Aprod, Msolve, bilinearfor Line 1136  def MINRES(b, Aprod, Msolve, bilinearfor
1136          #-----------------------------------------------------------------          #-----------------------------------------------------------------
1137          s = 1/beta                 # Normalize previous vector (in y).          s = 1/beta                 # Normalize previous vector (in y).
1138          v = s*y                    # v = vk if P = I          v = s*y                    # v = vk if P = I
1139        
1140          y      = Aprod(v)          y      = Aprod(v)
1141        
1142          if iter >= 2:          if iter >= 2:
1143            y = y - (beta/oldb)*r1            y = y - (beta/oldb)*r1
1144    
1145          alfa   = bilinearform(v,y)              # alphak          alfa   = bilinearform(v,y)              # alphak
1146          y      += (- alfa/beta)*r2          y      += (- alfa/beta)*r2
1147          r1     = r2          r1     = r2
1148          r2     = y          r2     = y
1149          y = Msolve(r2)          y = Msolve(r2)
# Line 1052  def MINRES(b, Aprod, Msolve, bilinearfor Line 1153  def MINRES(b, Aprod, Msolve, bilinearfor
1153    
1154          beta   = math.sqrt( beta )          beta   = math.sqrt( beta )
1155          tnorm2 = tnorm2 + alfa*alfa + oldb*oldb + beta*beta          tnorm2 = tnorm2 + alfa*alfa + oldb*oldb + beta*beta
1156            
1157          if iter==1:                 # Initialize a few things.          if iter==1:                 # Initialize a few things.
1158            gmax   = abs( alfa )      # alpha1            gmax   = abs( alfa )      # alpha1
1159            gmin   = gmax             # alpha1            gmin   = gmax             # alpha1
# Line 1060  def MINRES(b, Aprod, Msolve, bilinearfor Line 1161  def MINRES(b, Aprod, Msolve, bilinearfor
1161          # Apply previous rotation Qk-1 to get          # Apply previous rotation Qk-1 to get
1162          #   [deltak epslnk+1] = [cs  sn][dbark    0   ]          #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
1163          #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].          #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].
1164        
1165          oldeps = epsln          oldeps = epsln
1166          delta  = cs * dbar  +  sn * alfa  # delta1 = 0         deltak          delta  = cs * dbar  +  sn * alfa  # delta1 = 0         deltak
1167          gbar   = sn * dbar  -  cs * alfa  # gbar 1 = alfa1     gbar k          gbar   = sn * dbar  -  cs * alfa  # gbar 1 = alfa1     gbar k
# Line 1070  def MINRES(b, Aprod, Msolve, bilinearfor Line 1171  def MINRES(b, Aprod, Msolve, bilinearfor
1171          # Compute the next plane rotation Qk          # Compute the next plane rotation Qk
1172    
1173          gamma  = math.sqrt(gbar*gbar+beta*beta)  # gammak          gamma  = math.sqrt(gbar*gbar+beta*beta)  # gammak
1174          gamma  = max(gamma,eps)          gamma  = max(gamma,eps)
1175          cs     = gbar / gamma             # ck          cs     = gbar / gamma             # ck
1176          sn     = beta / gamma             # sk          sn     = beta / gamma             # sk
1177          phi    = cs * phibar              # phik          phi    = cs * phibar              # phik
# Line 1078  def MINRES(b, Aprod, Msolve, bilinearfor Line 1179  def MINRES(b, Aprod, Msolve, bilinearfor
1179    
1180          # Update  x.          # Update  x.
1181    
1182          denom = 1/gamma          denom = 1/gamma
1183          w1    = w2          w1    = w2
1184          w2    = w          w2    = w
1185          w     = (v - oldeps*w1 - delta*w2) * denom          w     = (v - oldeps*w1 - delta*w2) * denom
1186          x     +=  phi*w          x     +=  phi*w
1187    
# Line 1095  def MINRES(b, Aprod, Msolve, bilinearfor Line 1196  def MINRES(b, Aprod, Msolve, bilinearfor
1196    
1197          # Estimate various norms and test for convergence.          # Estimate various norms and test for convergence.
1198    
1199          Anorm  = math.sqrt( tnorm2 )          Anorm  = math.sqrt( tnorm2 )
1200          ynorm  = math.sqrt( ynorm2 )          ynorm  = math.sqrt( ynorm2 )
1201    
1202          rnorm  = phibar          rnorm  = phibar
1203    
1204      return x      return x
1205    
1206  def TFQMR(b, Aprod, Msolve, bilinearform, stoppingcriterium, x=None, iter_max=100):  def TFQMR(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1207      """
1208      Solver for
1209    
1210  # TFQMR solver for linear systems    *Ax=b*
1211  #  
1212  #    with a general operator A (more details required!).
1213  # initialization    It uses the Transpose-Free Quasi-Minimal Residual method (TFQMR).
 #  
   errtol = math.sqrt(bilinearform(b,b))  
   norm_b=errtol  
   kmax  = iter_max  
   error = []  
   
   if math.sqrt(bilinearform(x,x)) != 0.0:  
     r = b - Aprod(x)  
   else:  
     r = b  
1214    
1215    r=Msolve(r)    The iteration is terminated if
1216    
1217      *|r| <= atol+rtol*|r0|*
1218    
1219      where *r0* is the initial residual and *|.|* is the energy norm. In fact
1220    
1221      *|r| = sqrt( bilinearform(r,r))*
1222    
1223      :param r: initial residual *r=b-Ax*. ``r`` is altered.
1224      :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1225      :param x: an initial guess for the solution
1226      :type x: same like ``r``
1227      :param Aprod: returns the value Ax
1228      :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
1229                   argument ``x``. The returned object needs to be of the same type
1230                   like argument ``r``.
1231      :param bilinearform: inner product ``<x,r>``
1232      :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
1233                          type like argument ``x`` and ``r``. The returned value is
1234                          a ``float``.
1235      :param atol: absolute tolerance
1236      :type atol: non-negative ``float``
1237      :param rtol: relative tolerance
1238      :type rtol: non-negative ``float``
1239      :param iter_max: maximum number of iteration steps
1240      :type iter_max: ``int``
1241      :rtype: ``tuple``
1242      :warning: ``r`` and ``x`` are altered.
1243      """
1244    u1=0    u1=0
1245    u2=0    u2=0
1246    y1=0    y1=0
1247    y2=0    y2=0
1248    
1249    w = r    w = r
1250    y1 = r    y1 = r
1251    iter = 0    iter = 0
1252    d = 0    d = 0
1253        v = Aprod(y1)
   v = Msolve(Aprod(y1))  
1254    u1 = v    u1 = v
1255      
1256    theta = 0.0;    theta = 0.0;
1257    eta = 0.0;    eta = 0.0;
1258    tau = math.sqrt(bilinearform(r,r))    rho=bilinearform(r,r)
1259    error = [ error, tau ]    if rho < 0: raise NegativeNorm,"negative norm."
1260    rho = tau * tau    tau = math.sqrt(rho)
1261    m=1    norm_r0=tau
1262  #    while tau>atol+rtol*norm_r0:
 #  TFQMR iteration  
 #  
 #  while ( iter < kmax-1 ):  
     
   while not stoppingcriterium(tau*math.sqrt ( m + 1 ),norm_b,'TFQMR'):  
1263      if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max      if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1264    
1265      sigma = bilinearform(r,v)      sigma = bilinearform(r,v)
1266        if sigma == 0.0: raise IterationBreakDown,'TFQMR breakdown, sigma=0'
     if ( sigma == 0.0 ):  
       raise 'TFQMR breakdown, sigma=0'  
       
1267    
1268      alpha = rho / sigma      alpha = rho / sigma
1269    
# Line 1162  def TFQMR(b, Aprod, Msolve, bilinearform Line 1273  def TFQMR(b, Aprod, Msolve, bilinearform
1273  #  #
1274        if ( j == 1 ):        if ( j == 1 ):
1275          y2 = y1 - alpha * v          y2 = y1 - alpha * v
1276          u2 = Msolve(Aprod(y2))          u2 = Aprod(y2)
1277    
1278        m = 2 * (iter+1) - 2 + (j+1)        m = 2 * (iter+1) - 2 + (j+1)
1279        if j==0:        if j==0:
1280           w = w - alpha * u1           w = w - alpha * u1
1281           d = y1 + ( theta * theta * eta / alpha ) * d           d = y1 + ( theta * theta * eta / alpha ) * d
1282        if j==1:        if j==1:
# Line 1180  def TFQMR(b, Aprod, Msolve, bilinearform Line 1291  def TFQMR(b, Aprod, Msolve, bilinearform
1291  #  #
1292  #  Try to terminate the iteration at each pass through the loop  #  Try to terminate the iteration at each pass through the loop
1293  #  #
1294       # if ( tau * math.sqrt ( m + 1 ) <= errtol ):      if rho == 0.0: raise IterationBreakDown,'TFQMR breakdown, rho=0'
      #   error = [ error, tau ]  
      #   total_iters = iter  
      #   break  
         
   
     if ( rho == 0.0 ):  
       raise 'TFQMR breakdown, rho=0'  
       
1295    
1296      rhon = bilinearform(r,w)      rhon = bilinearform(r,w)
1297      beta = rhon / rho;      beta = rhon / rho;
1298      rho = rhon;      rho = rhon;
1299      y1 = w + beta * y2;      y1 = w + beta * y2;
1300      u1 = Msolve(Aprod(y1))      u1 = Aprod(y1)
1301      v = u1 + beta * ( u2 + beta * v )      v = u1 + beta * ( u2 + beta * v )
1302      error = [ error, tau ]  
1303      total_iters = iter      iter += 1
       
     iter = iter + 1  
1304    
1305    return x    return x
1306    
# Line 1208  def TFQMR(b, Aprod, Msolve, bilinearform Line 1309  def TFQMR(b, Aprod, Msolve, bilinearform
1309    
1310  class ArithmeticTuple(object):  class ArithmeticTuple(object):
1311     """     """
1312     tuple supporting inplace update x+=y and scaling x=a*y where x,y is an ArithmeticTuple and a is a float.     Tuple supporting inplace update x+=y and scaling x=a*y where ``x,y`` is an
1313       ArithmeticTuple and ``a`` is a float.
1314    
1315     example of usage:     Example of usage::
1316    
1317     from esys.escript import Data         from esys.escript import Data
1318     from numarray import array         from numpy import array
1319     a=Data(...)         a=Data(...)
1320     b=array([1.,4.])         b=array([1.,4.])
1321     x=ArithmeticTuple(a,b)         x=ArithmeticTuple(a,b)
1322     y=5.*x         y=5.*x
1323    
1324     """     """
1325     def __init__(self,*args):     def __init__(self,*args):
1326         """         """
1327         initialize object with elements args.         Initializes object with elements ``args``.
1328    
1329         @param args: tuple of object that support implace add (x+=y) and scaling (x=a*y)         :param args: tuple of objects that support inplace add (x+=y) and
1330                        scaling (x=a*y)
1331         """         """
1332         self.__items=list(args)         self.__items=list(args)
1333    
1334     def __len__(self):     def __len__(self):
1335         """         """
1336         number of items         Returns the number of items.
1337    
1338         @return: number of items         :return: number of items
1339         @rtype: C{int}         :rtype: ``int``
1340         """         """
1341         return len(self.__items)         return len(self.__items)
1342    
1343     def __getitem__(self,index):     def __getitem__(self,index):
1344         """         """
1345         get an item         Returns item at specified position.
1346    
1347         @param index: item to be returned         :param index: index of item to be returned
1348         @type index: C{int}         :type index: ``int``
1349         @return: item with index C{index}         :return: item with index ``index``
1350         """         """
1351         return self.__items.__getitem__(index)         return self.__items.__getitem__(index)
1352    
1353     def __mul__(self,other):     def __mul__(self,other):
1354         """         """
1355         scaling from the right         Scales by ``other`` from the right.
1356    
1357         @param other: scaling factor         :param other: scaling factor
1358         @type other: C{float}         :type other: ``float``
1359         @return: itemwise self*other         :return: itemwise self*other
1360         @rtype: L{ArithmeticTuple}         :rtype: `ArithmeticTuple`
1361         """         """
1362         out=[]         out=[]
1363         try:           try:
1364             l=len(other)             l=len(other)
1365             if l!=len(self):             if l!=len(self):
1366                 raise ValueError,"length of of arguments don't match."                 raise ValueError,"length of arguments don't match."
1367             for i in range(l): out.append(self[i]*other[i])             for i in range(l): out.append(self[i]*other[i])
1368         except TypeError:         except TypeError:
1369         for i in range(len(self)): out.append(self[i]*other)             for i in range(len(self)): out.append(self[i]*other)
1370         return ArithmeticTuple(*tuple(out))         return ArithmeticTuple(*tuple(out))
1371    
1372     def __rmul__(self,other):     def __rmul__(self,other):
1373         """         """
1374         scaling from the left         Scales by ``other`` from the left.
1375    
1376         @param other: scaling factor         :param other: scaling factor
1377         @type other: C{float}         :type other: ``float``
1378         @return: itemwise other*self         :return: itemwise other*self
1379         @rtype: L{ArithmeticTuple}         :rtype: `ArithmeticTuple`
1380         """         """
1381         out=[]         out=[]
1382         try:           try:
1383             l=len(other)             l=len(other)
1384             if l!=len(self):             if l!=len(self):
1385                 raise ValueError,"length of of arguments don't match."                 raise ValueError,"length of arguments don't match."
1386             for i in range(l): out.append(other[i]*self[i])             for i in range(l): out.append(other[i]*self[i])
1387         except TypeError:         except TypeError:
1388         for i in range(len(self)): out.append(other*self[i])             for i in range(len(self)): out.append(other*self[i])
1389         return ArithmeticTuple(*tuple(out))         return ArithmeticTuple(*tuple(out))
1390    
1391     def __div__(self,other):     def __div__(self,other):
1392         """         """
1393         dividing from the right         Scales by (1/``other``) from the right.
1394    
1395         @param other: scaling factor         :param other: scaling factor
1396         @type other: C{float}         :type other: ``float``
1397         @return: itemwise self/other         :return: itemwise self/other
1398         @rtype: L{ArithmeticTuple}         :rtype: `ArithmeticTuple`
1399         """         """
1400         return self*(1/other)         return self*(1/other)
1401    
1402     def __rdiv__(self,other):     def __rdiv__(self,other):
1403         """         """
1404         dividing from the left         Scales by (1/``other``) from the left.
1405    
1406         @param other: scaling factor         :param other: scaling factor
1407         @type other: C{float}         :type other: ``float``
1408         @return: itemwise other/self         :return: itemwise other/self
1409         @rtype: L{ArithmeticTuple}         :rtype: `ArithmeticTuple`
1410         """         """
1411         out=[]         out=[]
1412         try:           try:
1413             l=len(other)             l=len(other)
1414             if l!=len(self):             if l!=len(self):
1415                 raise ValueError,"length of of arguments don't match."                 raise ValueError,"length of arguments don't match."
1416             for i in range(l): out.append(other[i]/self[i])             for i in range(l): out.append(other[i]/self[i])
1417         except TypeError:         except TypeError:
1418         for i in range(len(self)): out.append(other/self[i])             for i in range(len(self)): out.append(other/self[i])
1419         return ArithmeticTuple(*tuple(out))         return ArithmeticTuple(*tuple(out))
1420      
1421     def __iadd__(self,other):     def __iadd__(self,other):
1422         """         """
1423         in-place add of other to self         Inplace addition of ``other`` to self.
1424    
1425         @param other: increment         :param other: increment
1426         @type other: C{ArithmeticTuple}         :type other: ``ArithmeticTuple``
1427         """         """
1428         if len(self) != len(other):         if len(self) != len(other):
1429             raise ValueError,"tuple length must match."             raise ValueError,"tuple lengths must match."
1430         for i in range(len(self)):         for i in range(len(self)):
1431             self.__items[i]+=other[i]             self.__items[i]+=other[i]
1432         return self         return self
1433    
1434     def __add__(self,other):     def __add__(self,other):
1435         """         """
1436         add other to self         Adds ``other`` to self.
1437    
1438         @param other: increment         :param other: increment
1439         @type other: C{ArithmeticTuple}         :type other: ``ArithmeticTuple``
1440         """         """
1441         out=[]         out=[]
1442         try:           try:
1443             l=len(other)             l=len(other)
1444             if l!=len(self):             if l!=len(self):
1445                 raise ValueError,"length of of arguments don't match."                 raise ValueError,"length of arguments don't match."
1446             for i in range(l): out.append(self[i]+other[i])             for i in range(l): out.append(self[i]+other[i])
1447         except TypeError:         except TypeError:
1448         for i in range(len(self)): out.append(self[i]+other)             for i in range(len(self)): out.append(self[i]+other)
1449         return ArithmeticTuple(*tuple(out))         return ArithmeticTuple(*tuple(out))
1450    
1451     def __sub__(self,other):     def __sub__(self,other):
1452         """         """
1453         subtract other from self         Subtracts ``other`` from self.
1454    
1455         @param other: increment         :param other: decrement
1456         @type other: C{ArithmeticTuple}         :type other: ``ArithmeticTuple``
1457         """         """
1458         out=[]         out=[]
1459         try:           try:
1460             l=len(other)             l=len(other)
1461             if l!=len(self):             if l!=len(self):
1462                 raise ValueError,"length of of arguments don't match."                 raise ValueError,"length of arguments don't match."
1463             for i in range(l): out.append(self[i]-other[i])             for i in range(l): out.append(self[i]-other[i])
1464         except TypeError:         except TypeError:
1465         for i in range(len(self)): out.append(self[i]-other)             for i in range(len(self)): out.append(self[i]-other)
1466         return ArithmeticTuple(*tuple(out))         return ArithmeticTuple(*tuple(out))
1467      
1468     def __isub__(self,other):     def __isub__(self,other):
1469         """         """
1470         subtract other from self         Inplace subtraction of ``other`` from self.
1471    
1472         @param other: increment         :param other: decrement
1473         @type other: C{ArithmeticTuple}         :type other: ``ArithmeticTuple``
1474         """         """
1475         if len(self) != len(other):         if len(self) != len(other):
1476             raise ValueError,"tuple length must match."             raise ValueError,"tuple length must match."
# Line 1377  class ArithmeticTuple(object): Line 1480  class ArithmeticTuple(object):
1480    
1481     def __neg__(self):     def __neg__(self):
1482         """         """
1483         negate         Negates values.
   
1484         """         """
1485         out=[]         out=[]
1486         for i in range(len(self)):         for i in range(len(self)):
# Line 1388  class ArithmeticTuple(object): Line 1490  class ArithmeticTuple(object):
1490    
1491  class HomogeneousSaddlePointProblem(object):  class HomogeneousSaddlePointProblem(object):
1492        """        """
1493        This provides a framwork for solving linear homogeneous saddle point problem of the form        This class provides a framework for solving linear homogeneous saddle
1494          point problems of the form::
              Av+B^*p=f  
              Bv    =0  
1495    
1496        for the unknowns v and p and given operators A and B and given right hand side f.            *Av+B^*p=f*
1497        B^* is the adjoint operator of B is the given inner product.            *Bv     =0*
1498    
1499          for the unknowns *v* and *p* and given operators *A* and *B* and
1500          given right hand side *f*. *B^** is the adjoint operator of *B*.
1501          *A* may depend weakly on *v* and *p*.
1502        """        """
1503        def __init__(self,**kwargs):        def __init__(self, **kwargs):
1504        """
1505        initializes the saddle point problem
1506        """
1507            self.resetControlParameters()
1508          self.setTolerance()          self.setTolerance()
1509          self.setToleranceReductionFactor()          self.setAbsoluteTolerance()
1510          def resetControlParameters(self,gamma=0.85, gamma_min=1.e-8,chi_max=0.1, omega_div=0.2, omega_conv=1.1, rtol_min=1.e-7, rtol_max=0.1, chi=1., C_p=1., C_v=1., safety_factor=0.3):
1511             """
1512             sets a control parameter
1513    
1514             :param gamma: ``1/(1-gamma)`` controls the perturbation of the converegence rate due to termination errors in the subproblems.
1515             :type gamma: ``float``
1516             :param gamma_min: minimum value for ``gamma``.
1517             :type gamma_min: ``float``
1518             :param chi_max: maximum tolerable converegence rate.
1519             :type chi_max: ``float``
1520             :param omega_div: reduction fact for ``gamma`` if no convergence is detected.
1521             :type omega_div: ``float``
1522             :param omega_conv: raise fact for ``gamma`` if convergence is detected.
1523             :type omega_conv: ``float``
1524             :param rtol_min: minimum relative tolerance used to calculate presssure and velocity increment.
1525             :type rtol_min: ``float``
1526             :param rtol_max: maximuim relative tolerance used to calculate presssure and velocity increment.
1527             :type rtol_max: ``float``
1528             :param chi: initial convergence measure.
1529             :type chi: ``float``
1530             :param C_p: initial value for constant to adjust pressure tolerance
1531             :type C_p: ``float``
1532             :param C_v: initial value for constant to adjust velocity tolerance
1533             :type C_v: ``float``
1534             :param safety_factor: safety factor for addjustment of pressure and velocity tolerance from stopping criteria
1535             :type safety_factor: ``float``
1536             """
1537             self.setControlParameter(gamma, gamma_min ,chi_max , omega_div , omega_conv, rtol_min , rtol_max, chi,C_p, C_v,safety_factor)
1538    
1539          def setControlParameter(self,gamma=None, gamma_min=None ,chi_max=None, omega_div=None, omega_conv=None, rtol_min=None, rtol_max=None, chi=None, C_p=None, C_v=None, safety_factor=None):
1540             """
1541             sets a control parameter
1542    
1543             :param gamma: ``1/(1-gamma)`` controls the perturbation of the converegence rate due to termination errors in the subproblems.
1544             :type gamma: ``float``
1545             :param gamma_min: minimum value for ``gamma``.
1546             :type gamma_min: ``float``
1547             :param chi_max: maximum tolerable converegence rate.
1548             :type chi_max: ``float``
1549             :param omega_div: reduction fact for ``gamma`` if no convergence is detected.
1550             :type omega_div: ``float``
1551             :param omega_conv: raise fact for ``gamma`` if convergence is detected.
1552             :type omega_conv: ``float``
1553             :param rtol_min: minimum relative tolerance used to calculate presssure and velocity increment.
1554             :type rtol_min: ``float``
1555             :param rtol_max: maximuim relative tolerance used to calculate presssure and velocity increment.
1556             :type rtol_max: ``float``
1557             :param chi: initial convergence measure.
1558             :type chi: ``float``
1559             :param C_p: initial value for constant to adjust pressure tolerance
1560             :type C_p: ``float``
1561             :param C_v: initial value for constant to adjust velocity tolerance
1562             :type C_v: ``float``
1563             :param safety_factor: safety factor for addjustment of pressure and velocity tolerance from stopping criteria
1564             :type safety_factor: ``float``
1565             """
1566             if not gamma == None:
1567                if gamma<=0 or gamma>=1:
1568                   raise ValueError,"Initial gamma needs to be positive and less than 1."
1569             else:
1570                gamma=self.__gamma
1571    
1572             if not gamma_min == None:
1573                if gamma_min<=0 or gamma_min>=1:
1574                   raise ValueError,"gamma_min needs to be positive and less than 1."
1575             else:
1576                gamma_min = self.__gamma_min
1577    
1578             if not chi_max == None:
1579                if chi_max<=0 or chi_max>=1:
1580                   raise ValueError,"chi_max needs to be positive and less than 1."
1581             else:
1582                chi_max = self.__chi_max
1583    
1584             if not omega_div == None:
1585                if omega_div<=0 or omega_div >=1:
1586                   raise ValueError,"omega_div needs to be positive and less than 1."
1587             else:
1588                omega_div=self.__omega_div
1589    
1590             if not omega_conv == None:
1591                if omega_conv<1:
1592                   raise ValueError,"omega_conv needs to be greater or equal to 1."
1593             else:
1594                omega_conv=self.__omega_conv
1595    
1596             if not rtol_min == None:
1597                if rtol_min<=0 or rtol_min>=1:
1598                   raise ValueError,"rtol_min needs to be positive and less than 1."
1599             else:
1600                rtol_min=self.__rtol_min
1601    
1602             if not rtol_max == None:
1603                if rtol_max<=0 or rtol_max>=1:
1604                   raise ValueError,"rtol_max needs to be positive and less than 1."
1605             else:
1606                rtol_max=self.__rtol_max
1607    
1608             if not chi == None:
1609                if chi<=0 or chi>1:
1610                   raise ValueError,"chi needs to be positive and less than 1."
1611             else:
1612                chi=self.__chi
1613    
1614             if not C_p == None:
1615                if C_p<1:
1616                   raise ValueError,"C_p need to be greater or equal to 1."
1617             else:
1618                C_p=self.__C_p
1619    
1620             if not C_v == None:
1621                if C_v<1:
1622                   raise ValueError,"C_v need to be greater or equal to 1."
1623             else:
1624                C_v=self.__C_v
1625    
1626             if not safety_factor == None:
1627                if safety_factor<=0 or safety_factor>1:
1628                   raise ValueError,"safety_factor need to be between zero and one."
1629             else:
1630                safety_factor=self.__safety_factor
1631    
1632             if gamma<gamma_min:
1633                   raise ValueError,"gamma = %e needs to be greater or equal gamma_min = %e."%(gamma,gamma_min)
1634             if rtol_max<=rtol_min:
1635                   raise ValueError,"rtol_max = %e needs to be greater rtol_min = %e."%(rtol_max,rtol_min)
1636                
1637             self.__gamma = gamma
1638             self.__gamma_min = gamma_min
1639             self.__chi_max = chi_max
1640             self.__omega_div = omega_div
1641             self.__omega_conv = omega_conv
1642             self.__rtol_min = rtol_min
1643             self.__rtol_max = rtol_max
1644             self.__chi = chi
1645             self.__C_p = C_p
1646             self.__C_v = C_v
1647             self.__safety_factor = safety_factor
1648    
1649          #=============================================================
1650        def initialize(self):        def initialize(self):
1651          """          """
1652          initialize the problem (overwrite)          Initializes the problem (overwrite).
1653          """          """
1654          pass          pass
       def B(self,v):  
          """  
          returns Bv (overwrite)  
          @rtype: equal to the type of p  
1655    
1656           @note: boundary conditions on p should be zero!        def inner_pBv(self,p,Bv):
1657           """           """
1658           pass           Returns inner product of element p and Bv (overwrite).
1659    
1660        def inner(self,p0,p1):           :param p: a pressure increment
1661             :param Bv: a residual
1662             :return: inner product of element p and Bv
1663             :rtype: ``float``
1664             :note: used if PCG is applied.
1665           """           """
1666           returns inner product of two element p0 and p1  (overwrite)           raise NotImplementedError,"no inner product for p and Bv implemented."
           
          @type p0: equal to the type of p  
          @type p1: equal to the type of p  
          @rtype: C{float}  
1667    
1668           @rtype: equal to the type of p        def inner_p(self,p0,p1):
1669           """           """
1670           pass           Returns inner product of p0 and p1 (overwrite).
1671    
1672        def solve_A(self,u,p):           :param p0: a pressure
1673             :param p1: a pressure
1674             :return: inner product of p0 and p1
1675             :rtype: ``float``
1676           """           """
1677           solves Av=f-Au-B^*p with accuracy self.getReducedTolerance() (overwrite)           raise NotImplementedError,"no inner product for p implemented."
1678      
1679          def norm_v(self,v):
1680             """
1681             Returns the norm of v (overwrite).
1682    
1683           @rtype: equal to the type of v           :param v: a velovity
1684           @note: boundary conditions on v should be zero!           :return: norm of v
1685             :rtype: non-negative ``float``
1686           """           """
1687           pass           raise NotImplementedError,"no norm of v implemented."
1688          def getDV(self, p, v, tol):
1689             """
1690             return a correction to the value for a given v and a given p with accuracy `tol` (overwrite)
1691    
1692        def solve_prec(self,p):           :param p: pressure
1693             :param v: pressure
1694             :return: dv given as *dv= A^{-1} (f-A v-B^*p)*
1695             :note: Only *A* may depend on *v* and *p*
1696           """           """
1697           provides a preconditioner for BA^{-1}B^* with accuracy self.getReducedTolerance() (overwrite)           raise NotImplementedError,"no dv calculation implemented."
1698    
1699           @rtype: equal to the type of p          
1700          def Bv(self,v, tol):
1701            """
1702            Returns Bv with accuracy `tol` (overwrite)
1703    
1704            :rtype: equal to the type of p
1705            :note: boundary conditions on p should be zero!
1706            """
1707            raise NotImplementedError, "no operator B implemented."
1708    
1709          def norm_Bv(self,Bv):
1710            """
1711            Returns the norm of Bv (overwrite).
1712    
1713            :rtype: equal to the type of p
1714            :note: boundary conditions on p should be zero!
1715            """
1716            raise NotImplementedError, "no norm of Bv implemented."
1717    
1718          def solve_AinvBt(self,dp, tol):
1719           """           """
1720           pass           Solves *A dv=B^*dp* with accuracy `tol`
1721    
1722        def stoppingcriterium(self,Bv,v,p):           :param dp: a pressure increment
1723             :return: the solution of *A dv=B^*dp*
1724             :note: boundary conditions on dv should be zero! *A* is the operator used in ``getDV`` and must not be altered.
1725           """           """
1726           returns a True if iteration is terminated. (overwrite)           raise NotImplementedError,"no operator A implemented."
1727    
1728           @rtype: C{bool}        def solve_prec(self,Bv, tol):
1729           """           """
1730           pass           Provides a preconditioner for *(BA^{-1}B^ * )* applied to Bv with accuracy `tol`
               
       def __inner(self,p,r):  
          return self.inner(p,r[1])  
1731    
1732        def __inner_p(self,p1,p2):           :rtype: equal to the type of p
1733           return self.inner(p1,p2)           :note: boundary conditions on p should be zero!
1734                   """
1735        def __inner_a(self,a1,a2):           raise NotImplementedError,"no preconditioner for Schur complement implemented."
1736           return self.inner_a(a1,a2)        #=============================================================
1737          def __Aprod_PCG(self,dp):
1738              dv=self.solve_AinvBt(dp, self.__subtol)
1739              return ArithmeticTuple(dv,self.Bv(dv, self.__subtol))
1740    
1741          def __inner_PCG(self,p,r):
1742             return self.inner_pBv(p,r[1])
1743    
1744          def __Msolve_PCG(self,r):
1745              return self.solve_prec(r[1], self.__subtol)
1746          #=============================================================
1747          def __Aprod_GMRES(self,p):
1748              return self.solve_prec(self.Bv(self.solve_AinvBt(p, self.__subtol), self.__subtol), self.__subtol)
1749          def __inner_GMRES(self,p0,p1):
1750             return self.inner_p(p0,p1)
1751    
1752          #=============================================================
1753          def norm_p(self,p):
1754              """
1755              calculates the norm of ``p``
1756              
1757              :param p: a pressure
1758              :return: the norm of ``p`` using the inner product for pressure
1759              :rtype: ``float``
1760              """
1761              f=self.inner_p(p,p)
1762              if f<0: raise ValueError,"negative pressure norm."
1763              return math.sqrt(f)
1764          
1765          def solve(self,v,p,max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
1766             """
1767             Solves the saddle point problem using initial guesses v and p.
1768    
1769        def __inner_a1(self,a1,a2):           :param v: initial guess for velocity
1770           return self.inner(a1[1],a2[1])           :param p: initial guess for pressure
1771             :type v: `Data`
1772             :type p: `Data`
1773             :param usePCG: indicates the usage of the PCG rather than GMRES scheme.
1774             :param max_iter: maximum number of iteration steps per correction
1775                              attempt
1776             :param verbose: if True, shows information on the progress of the
1777                             saddlepoint problem solver.
1778             :param iter_restart: restart the iteration after ``iter_restart`` steps
1779                                  (only used if useUzaw=False)
1780             :type usePCG: ``bool``
1781             :type max_iter: ``int``
1782             :type verbose: ``bool``
1783             :type iter_restart: ``int``
1784             :rtype: ``tuple`` of `Data` objects
1785             """
1786             self.verbose=verbose
1787             rtol=self.getTolerance()
1788             atol=self.getAbsoluteTolerance()
1789             correction_step=0
1790             converged=False
1791             error=None
1792             chi=None
1793             gamma=self.__gamma
1794             C_p=self.__C_p
1795             C_v=self.__C_v
1796             while not converged:
1797                  if error== None or chi == None:
1798                      gamma_new=gamma/self.__omega_conv
1799                  else:
1800                     if chi < self.__chi_max:
1801                        gamma_new=min(max(gamma*self.__omega_conv,1-chi*error/(self.__safety_factor*ATOL)), 1-chi/self.__chi_max)
1802                     else:
1803                        gamma_new=gamma*self.__omega_div
1804                  gamma=max(gamma_new, self.__gamma_min)
1805                  # calculate velocity for current pressure:
1806                  rtol_v=min(max(gamma/(1.+gamma)/C_v,self.__rtol_min), self.__rtol_max)
1807                  rtol_p=min(max(gamma/C_p, self.__rtol_min), self.__rtol_max)
1808                  self.__subtol=rtol_p**2
1809                  if self.verbose: print "HomogeneousSaddlePointProblem: step %s: gamma = %e, rtol_v= %e, rtol_p=%e"%(correction_step,gamma,rtol_v,rtol_p)
1810                  if self.verbose: print "HomogeneousSaddlePointProblem: subtolerance: %e"%self.__subtol
1811                  # calculate velocity for current pressure: A*dv= F-A*v-B*p
1812                  dv1=self.getDV(p,v,rtol_v)
1813                  v1=v+dv1
1814                  Bv1=self.Bv(v1, self.__subtol)
1815                  norm_Bv1=self.norm_Bv(Bv1)
1816                  norm_dv1=self.norm_v(dv1)
1817                  norm_v1=self.norm_v(v1)
1818                  ATOL=norm_v1*rtol+atol
1819                  if self.verbose: print "HomogeneousSaddlePointProblem: step %s: Bv = %e, dv = %e, v=%e"%(correction_step,norm_Bv1, norm_dv1, norm_v1)
1820                  if not ATOL>0: raise ValueError,"overall absolute tolerance needs to be positive."
1821                  if max(norm_Bv1, norm_dv1) <= ATOL:
1822                      converged=True
1823                      v=v1
1824                  else:
1825                      # now we solve for the pressure increment dp from B*A^{-1}B^* dp = Bv1
1826                      if usePCG:
1827                        dp,r,a_norm=PCG(ArithmeticTuple(v1,Bv1),self.__Aprod_PCG,0*p,self.__Msolve_PCG,self.__inner_PCG,atol=0, rtol=rtol_p,iter_max=max_iter, verbose=self.verbose)
1828                        v2=r[0]
1829                        Bv2=r[1]
1830                      else:
1831                        dp=GMRES(self.solve_prec(Bv1,self.__subtol),self.__Aprod_GMRES, 0*p, self.__inner_GMRES,atol=0, rtol=rtol_p,iter_max=max_iter, iter_restart=iter_restart, verbose=self.verbose)
1832                        dv2=self.solve_AinvBt(dp, self.__subtol)
1833                        v2=v1-dv2
1834                        Bv2=self.Bv(v2, self.__subtol)
1835                      #
1836                      # convergence indicators:
1837                      #
1838                      norm_v2=self.norm_v(v2)
1839                      norm_dv2=self.norm_v(v2-v)
1840                      error_new=max(norm_dv2, norm_Bv1)
1841                      norm_Bv2=self.norm_Bv(Bv2)
1842                      if self.verbose: print "HomogeneousSaddlePointProblem: step %s (part 2): Bv = %e, dv = %e, v=%e"%(correction_step,norm_Bv2, norm_dv2, norm_v2)
1843                      if error !=None:
1844                          chi_new=error_new/error
1845                          if self.verbose: print "HomogeneousSaddlePointProblem: step %s: convergence rate = %e"%(correction_step,chi_new)
1846                          if chi != None:
1847                              gamma0=max(gamma, 1-chi/chi_new)
1848                              C_p*=gamma0/gamma
1849                              C_v*=gamma0/gamma*(1+gamma)/(1+gamma0)
1850                          chi=chi_new
1851                      error = error_new
1852                      correction_step+=1
1853                      if correction_step>max_correction_steps:
1854                            raise CorrectionFailed,"Given up after %d correction steps."%correction_step
1855                      v,p=v2,p+dp
1856             if self.verbose: print "HomogeneousSaddlePointProblem: tolerance reached after %s steps."%correction_step
1857         return v,p
1858    
1859        def __stoppingcriterium(self,norm_r,r,p):        #========================================================================
1860            return self.stoppingcriterium(r[1],r[0],p)        def setTolerance(self,tolerance=1.e-4):
1861             """
1862             Sets the relative tolerance for (v,p).
1863    
1864        def __stoppingcriterium2(self,norm_r,norm_b,solver='GMRES',TOL=None):           :param tolerance: tolerance to be used
1865            return self.stoppingcriterium2(norm_r,norm_b,solver,TOL)           :type tolerance: non-negative ``float``
1866             """
1867             if tolerance<0:
1868                 raise ValueError,"tolerance must be positive."
1869             self.__rtol=tolerance
1870    
       def setTolerance(self,tolerance=1.e-8):  
               self.__tol=tolerance  
1871        def getTolerance(self):        def getTolerance(self):
1872                return self.__tol           """
1873        def setToleranceReductionFactor(self,reduction=0.01):           Returns the relative tolerance.
1874                self.__reduction=reduction  
1875        def getSubProblemTolerance(self):           :return: relative tolerance
1876                return self.__reduction*self.getTolerance()           :rtype: ``float``
1877             """
1878        def solve(self,v,p,max_iter=20, verbose=False, show_details=False, solver='PCG',iter_restart=20):           return self.__rtol
1879                """  
1880                solves the saddle point problem using initial guesses v and p.        def setAbsoluteTolerance(self,tolerance=0.):
1881             """
1882                @param max_iter: maximum number of iteration steps.           Sets the absolute tolerance.
1883                """  
1884                self.verbose=verbose           :param tolerance: tolerance to be used
1885                self.show_details=show_details and self.verbose           :type tolerance: non-negative ``float``
1886             """
1887                # assume p is known: then v=A^-1(f-B^*p)           if tolerance<0:
1888                # which leads to BA^-1B^*p = BA^-1f                 raise ValueError,"tolerance must be non-negative."
1889             self.__atol=tolerance
1890            # Az=f is solved as A(z-v)=f-Av (z-v = 0 on fixed_u_mask)        
1891            self.__z=v+self.solve_A(v,p*0)        def getAbsoluteTolerance(self):
1892                Bz=self.B(self.__z)           """
1893                #           Returns the absolute tolerance.
1894            #   solve BA^-1B^*p = Bz  
1895                #           :return: absolute tolerance
1896                #           :rtype: ``float``
1897                #           """
1898                self.iter=0           return self.__atol
           if solver=='GMRES':        
                 if self.verbose: print "enter GMRES method (iter_max=%s)"%max_iter  
                 p=GMRES(Bz,self.__Aprod2,self.__Msolve2,self.__inner_p,self.__stoppingcriterium2,iter_max=max_iter, x=p*1.,iter_restart=iter_restart)  
                 # solve Au=f-B^*p  
                 #       A(u-v)=f-B^*p-Av  
                 #       u=v+(u-v)  
         u=v+self.solve_A(v,p)  
   
           if solver=='TFQMR':        
                 if self.verbose: print "enter TFQMR method (iter_max=%s)"%max_iter  
                 p=TFQMR(Bz,self.__Aprod2,self.__Msolve2,self.__inner_p,self.__stoppingcriterium2,iter_max=max_iter, x=p*1.)  
                 # solve Au=f-B^*p  
                 #       A(u-v)=f-B^*p-Av  
                 #       u=v+(u-v)  
         u=v+self.solve_A(v,p)  
   
           if solver=='MINRES':        
                 if self.verbose: print "enter MINRES method (iter_max=%s)"%max_iter  
                 p=MINRES(Bz,self.__Aprod2,self.__Msolve2,self.__inner_p,self.__stoppingcriterium2,iter_max=max_iter, x=p*1.)  
                 # solve Au=f-B^*p  
                 #       A(u-v)=f-B^*p-Av  
                 #       u=v+(u-v)  
         u=v+self.solve_A(v,p)  
                 
           if solver=='GMRESC':        
                 if self.verbose: print "enter GMRES coupled method (iter_max=%s)"%max_iter  
                 p0=self.solve_prec1(Bz)  
             #alfa=(1/self.vol)*util.integrate(util.interpolate(p,escript.Function(self.domain)))  
                 #p-=alfa  
                 x=GMRES(ArithmeticTuple(self.__z*1.,p0*1),self.__Anext,self.__Mempty,self.__inner_a,self.__stoppingcriterium2,iter_max=max_iter, x=ArithmeticTuple(v*1,p*1),iter_restart=20)  
                 #x=NewtonGMRES(ArithmeticTuple(self.__z*1.,p0*1),self.__Aprod_Newton2,self.__Mempty,self.__inner_a,self.__stoppingcriterium2,iter_max=max_iter, x=ArithmeticTuple(v*1,p*1),atol=0,rtol=self.getTolerance())  
   
                 # solve Au=f-B^*p  
                 #       A(u-v)=f-B^*p-Av  
                 #       u=v+(u-v)  
             p=x[1]  
         u=v+self.solve_A(v,p)        
         #p=x[1]  
         #u=x[0]  
   
               if solver=='PCG':  
                 #   note that the residual r=Bz-BA^-1B^*p = B(z-A^-1B^*p) = Bv  
                 #  
                 #   with                    Av=Az-B^*p = f - B^*p (v=z on fixed_u_mask)  
                 #                           A(v-z)= f -Az - B^*p (v-z=0 on fixed_u_mask)  
                 if self.verbose: print "enter PCG method (iter_max=%s)"%max_iter  
                 p,r=PCG(ArithmeticTuple(self.__z*1.,Bz),self.__Aprod,self.__Msolve,self.__inner,self.__stoppingcriterium,iter_max=max_iter, x=p)  
             u=r[0]    
                 # print "DIFF=",util.integrate(p)  
   
               # print "RESULT div(u)=",util.Lsup(self.B(u)),util.Lsup(u)  
   
           return u,p  
   
       def __Msolve(self,r):  
           return self.solve_prec(r[1])  
   
       def __Msolve2(self,r):  
           return self.solve_prec(r*1)  
   
       def __Mempty(self,r):  
           return r  
   
   
       def __Aprod(self,p):  
           # return BA^-1B*p  
           #solve Av =B^*p as Av =f-Az-B^*(-p)  
           v=self.solve_A(self.__z,-p)  
           return ArithmeticTuple(v, self.B(v))  
   
       def __Anext(self,x):  
           # return next v,p  
           #solve Av =-B^*p as Av =f-Az-B^*p  
   
       pc=x[1]  
           v=self.solve_A(self.__z,-pc)  
       p=self.solve_prec1(self.B(v))  
   
           return ArithmeticTuple(v,p)  
   
   
       def __Aprod2(self,p):  
           # return BA^-1B*p  
           #solve Av =B^*p as Av =f-Az-B^*(-p)  
       v=self.solve_A(self.__z,-p)  
           return self.B(v)  
   
       def __Aprod_Newton(self,p):  
           # return BA^-1B*p - Bz  
           #solve Av =-B^*p as Av =f-Az-B^*p  
       v=self.solve_A(self.__z,-p)  
           return self.B(v-self.__z)  
   
       def __Aprod_Newton2(self,x):  
           # return BA^-1B*p - Bz  
           #solve Av =-B^*p as Av =f-Az-B^*p  
           pc=x[1]  
       v=self.solve_A(self.__z,-pc)  
           p=self.solve_prec1(self.B(v-self.__z))  
           return ArithmeticTuple(v,p)  
1899    
1900    
1901  def MaskFromBoundaryTag(domain,*tags):  def MaskFromBoundaryTag(domain,*tags):
1902     """     """
1903     creates a mask on the Solution(domain) function space which one for samples     Creates a mask on the Solution(domain) function space where the value is
1904     that touch the boundary tagged by tags.     one for samples that touch the boundary tagged by tags.
1905    
1906     usage: m=MaskFromBoundaryTag(domain,"left", "right")     Usage: m=MaskFromBoundaryTag(domain, "left", "right")
1907    
1908     @param domain: a given domain     :param domain: domain to be used
1909     @type domain: L{escript.Domain}     :type domain: `escript.Domain`
1910     @param tags: boundray tags     :param tags: boundary tags
1911     @type tags: C{str}     :type tags: ``str``
1912     @return: a mask which marks samples that are touching the boundary tagged by any of the given tags.     :return: a mask which marks samples that are touching the boundary tagged
1913     @rtype: L{escript.Data} of rank 0              by any of the given tags
1914       :rtype: `escript.Data` of rank 0
1915     """     """
1916     pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)     pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1917     d=escript.Scalar(0.,escript.FunctionOnBoundary(domain))     d=escript.Scalar(0.,escript.FunctionOnBoundary(domain))
1918     for t in tags: d.setTaggedValue(t,1.)     for t in tags: d.setTaggedValue(t,1.)
1919     pde.setValue(y=d)     pde.setValue(y=d)
1920     return util.whereNonZero(pde.getRightHandSide())     return util.whereNonZero(pde.getRightHandSide())
 #============================================================================================================================================  
 class SaddlePointProblem(object):  
    """  
    This implements a solver for a saddlepoint problem  
   
    M{f(u,p)=0}  
    M{g(u)=0}  
1921    
1922     for u and p. The problem is solved with an inexact Uszawa scheme for p:  def MaskFromTag(domain,*tags):
   
    M{Q_f (u^{k+1}-u^{k}) = - f(u^{k},p^{k})}  
    M{Q_g (p^{k+1}-p^{k}) =   g(u^{k+1})}  
   
    where Q_f is an approximation of the Jacobiean A_f of f with respect to u  and Q_f is an approximation of  
    A_g A_f^{-1} A_g with A_g is the jacobiean of g with respect to p. As a the construction of a 'proper'  
    Q_g can be difficult, non-linear conjugate gradient method is applied to solve for p, so Q_g plays  
    in fact the role of a preconditioner.  
1923     """     """
1924     def __init__(self,verbose=False,*args):     Creates a mask on the Solution(domain) function space where the value is
1925         """     one for samples that touch regions tagged by tags.
        initializes the problem  
   
        @param verbose: switches on the printing out some information  
        @type verbose: C{bool}  
        @note: this method may be overwritten by a particular saddle point problem  
        """  
        print "SaddlePointProblem should not be used anymore!"  
        if not isinstance(verbose,bool):  
             raise TypeError("verbose needs to be of type bool.")  
        self.__verbose=verbose  
        self.relaxation=1.  
        DeprecationWarning("SaddlePointProblem should not be used anymore.")  
   
    def trace(self,text):  
        """  
        prints text if verbose has been set  
   
        @param text: a text message  
        @type text: C{str}  
        """  
        if self.__verbose: print "%s: %s"%(str(self),text)  
1926    
1927     def solve_f(self,u,p,tol=1.e-8):     Usage: m=MaskFromTag(domain, "ham")
        """  
        solves  
   
        A_f du = f(u,p)  
   
        with tolerance C{tol} and return du. A_f is Jacobiean of f with respect to u.  
   
        @param u: current approximation of u  
        @type u: L{escript.Data}  
        @param p: current approximation of p  
        @type p: L{escript.Data}  
        @param tol: tolerance expected for du  
        @type tol: C{float}  
        @return: increment du  
        @rtype: L{escript.Data}  
        @note: this method has to be overwritten by a particular saddle point problem  
        """  
        pass  
   
    def solve_g(self,u,tol=1.e-8):  
        """  
        solves  
1928    
1929         Q_g dp = g(u)     :param domain: domain to be used
1930       :type domain: `escript.Domain`
1931         with Q_g is a preconditioner for A_g A_f^{-1} A_g with  A_g is the jacobiean of g with respect to p.     :param tags: boundary tags
1932       :type tags: ``str``
1933         @param u: current approximation of u     :return: a mask which marks samples that are touching the boundary tagged
1934         @type u: L{escript.Data}              by any of the given tags
1935         @param tol: tolerance expected for dp     :rtype: `escript.Data` of rank 0
1936         @type tol: C{float}     """
1937         @return: increment dp     pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1938         @rtype: L{escript.Data}     d=escript.Scalar(0.,escript.Function(domain))
1939         @note: this method has to be overwritten by a particular saddle point problem     for t in tags: d.setTaggedValue(t,1.)
1940         """     pde.setValue(Y=d)
1941         pass     return util.whereNonZero(pde.getRightHandSide())
   
    def inner(self,p0,p1):  
        """  
        inner product of p0 and p1 approximating p. Typically this returns integrate(p0*p1)  
        @return: inner product of p0 and p1  
        @rtype: C{float}  
        """  
        pass  
1942    
    subiter_max=3  
    def solve(self,u0,p0,tolerance=1.e-6,tolerance_u=None,iter_max=100,accepted_reduction=0.995,relaxation=None):  
         """  
         runs the solver  
1943    
         @param u0: initial guess for C{u}  
         @type u0: L{esys.escript.Data}  
         @param p0: initial guess for C{p}  
         @type p0: L{esys.escript.Data}  
         @param tolerance: tolerance for relative error in C{u} and C{p}  
         @type tolerance: positive C{float}  
         @param tolerance_u: tolerance for relative error in C{u} if different from C{tolerance}  
         @type tolerance_u: positive C{float}  
         @param iter_max: maximum number of iteration steps.  
         @type iter_max: C{int}  
         @param accepted_reduction: if the norm  g cannot be reduced by C{accepted_reduction} backtracking to adapt the  
                                    relaxation factor. If C{accepted_reduction=None} no backtracking is used.  
         @type accepted_reduction: positive C{float} or C{None}  
         @param relaxation: initial relaxation factor. If C{relaxation==None}, the last relaxation factor is used.  
         @type relaxation: C{float} or C{None}  
         """  
         tol=1.e-2  
         if tolerance_u==None: tolerance_u=tolerance  
         if not relaxation==None: self.relaxation=relaxation  
         if accepted_reduction ==None:  
               angle_limit=0.  
         elif accepted_reduction>=1.:  
               angle_limit=0.  
         else:  
               angle_limit=util.sqrt(1-accepted_reduction**2)  
         self.iter=0  
         u=u0  
         p=p0  
         #  
         #   initialize things:  
         #  
         converged=False  
         #  
         #  start loop:  
         #  
         #  initial search direction is g  
         #  
         while not converged :  
             if self.iter>iter_max:  
                 raise ArithmeticError("no convergence after %s steps."%self.iter)  
             f_new=self.solve_f(u,p,tol)  
             norm_f_new = util.Lsup(f_new)  
             u_new=u-f_new  
             g_new=self.solve_g(u_new,tol)  
             self.iter+=1  
             norm_g_new = util.sqrt(self.inner(g_new,g_new))  
             if norm_f_new==0. and norm_g_new==0.: return u, p  
             if self.iter>1 and not accepted_reduction==None:  
                #  
                #   did we manage to reduce the norm of G? I  
                #   if not we start a backtracking procedure  
                #  
                # print "new/old norm = ",norm_g_new, norm_g, norm_g_new/norm_g  
                if norm_g_new > accepted_reduction * norm_g:  
                   sub_iter=0  
                   s=self.relaxation  
                   d=g  
                   g_last=g  
                   self.trace("    start substepping: f = %s, g = %s, relaxation = %s."%(norm_f_new, norm_g_new, s))  
                   while sub_iter < self.subiter_max and  norm_g_new > accepted_reduction * norm_g:  
                      dg= g_new-g_last  
                      norm_dg=abs(util.sqrt(self.inner(dg,dg))/self.relaxation)  
                      rad=self.inner(g_new,dg)/self.relaxation  
                      # print "   ",sub_iter,": rad, norm_dg:",abs(rad), norm_dg*norm_g_new * angle_limit  
                      # print "   ",sub_iter,": rad, norm_dg:",rad, norm_dg, norm_g_new, norm_g  
                      if abs(rad) < norm_dg*norm_g_new * angle_limit:  
                          if sub_iter>0: self.trace("    no further improvements expected from backtracking.")  
                          break  
                      r=self.relaxation  
                      self.relaxation= - rad/norm_dg**2  
                      s+=self.relaxation  
                      #####  
                      # a=g_new+self.relaxation*dg/r  
                      # print "predicted new norm = ",util.sqrt(self.inner(a,a)),util.sqrt(self.inner(g_new,g_new)), self.relaxation  
                      #####  
                      g_last=g_new  
                      p+=self.relaxation*d  
                      f_new=self.solve_f(u,p,tol)  
                      u_new=u-f_new  
                      g_new=self.solve_g(u_new,tol)  
                      self.iter+=1  
                      norm_f_new = util.Lsup(f_new)  
                      norm_g_new = util.sqrt(self.inner(g_new,g_new))  
                      # print "   ",sub_iter," new g norm",norm_g_new  
                      self.trace("    %s th sub-step: f = %s, g = %s, relaxation = %s."%(sub_iter, norm_f_new, norm_g_new, s))  
                      #  
                      #   can we expect reduction of g?  
                      #  
                      # u_last=u_new  
                      sub_iter+=1  
                   self.relaxation=s  
             #  
             #  check for convergence:  
             #  
             norm_u_new = util.Lsup(u_new)  
             p_new=p+self.relaxation*g_new  
             norm_p_new = util.sqrt(self.inner(p_new,p_new))  
             self.trace("%s th step: f = %s, f/u = %s, g = %s, g/p = %s, relaxation = %s."%(self.iter, norm_f_new ,norm_f_new/norm_u_new, norm_g_new, norm_g_new/norm_p_new, self.relaxation))  
   
             if self.iter>1:  
                dg2=g_new-g  
                df2=f_new-f  
                norm_dg2=util.sqrt(self.inner(dg2,dg2))  
                norm_df2=util.Lsup(df2)  
                # print norm_g_new, norm_g, norm_dg, norm_p, tolerance  
                tol_eq_g=tolerance*norm_dg2/(norm_g*abs(self.relaxation))*norm_p_new  
                tol_eq_f=tolerance_u*norm_df2/norm_f*norm_u_new  
                if norm_g_new <= tol_eq_g and norm_f_new <= tol_eq_f:  
                    converged=True  
             f, norm_f, u, norm_u, g, norm_g, p, norm_p = f_new, norm_f_new, u_new, norm_u_new, g_new, norm_g_new, p_new, norm_p_new  
         self.trace("convergence after %s steps."%self.iter)  
         return u,p  

Legend:
Removed from v.1956  
changed lines
  Added in v.2745

  ViewVC Help
Powered by ViewVC 1.1.26