/[escript]/trunk/escript/py_src/pdetools.py
ViewVC logotype

Diff of /trunk/escript/py_src/pdetools.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 351 by gross, Tue Dec 13 09:12:15 2005 UTC revision 2862 by gross, Thu Jan 21 04:45:39 2010 UTC
# Line 1  Line 1 
1  # $Id$  
2    ########################################################
3    #
4    # Copyright (c) 2003-2009 by University of Queensland
5    # Earth Systems Science Computational Center (ESSCC)
6    # http://www.uq.edu.au/esscc
7    #
8    # Primary Business: Queensland, Australia
9    # Licensed under the Open Software License version 3.0
10    # http://www.opensource.org/licenses/osl-3.0.php
11    #
12    ########################################################
13    
14    __copyright__="""Copyright (c) 2003-2009 by University of Queensland
15    Earth Systems Science Computational Center (ESSCC)
16    http://www.uq.edu.au/esscc
17    Primary Business: Queensland, Australia"""
18    __license__="""Licensed under the Open Software License version 3.0
19    http://www.opensource.org/licenses/osl-3.0.php"""
20    __url__="https://launchpad.net/escript-finley"
21    
22  """  """
23  Provides some tools related to PDEs.  Provides some tools related to PDEs.
24    
25  Currently includes:  Currently includes:
26      - Projector - to project a discontinuous      - Projector - to project a discontinuous function onto a continuous function
27      - Locator - to trace values in data objects at a certain location      - Locator - to trace values in data objects at a certain location
28      - TimeIntegrationManager - to handel extraplotion in time      - TimeIntegrationManager - to handle extrapolation in time
29        - SaddlePointProblem - solver for Saddle point problems using the inexact uszawa scheme
30    
31    :var __author__: name of author
32    :var __copyright__: copyrights
33    :var __license__: licence agreement
34    :var __url__: url entry point on documentation
35    :var __version__: version
36    :var __date__: date of the version
37  """  """
38    
39    __author__="Lutz Gross, l.gross@uq.edu.au"
40    
41    
42  import escript  import escript
43  import linearPDEs  import linearPDEs
44  import numarray  import numpy
45  import util  import util
46    import math
47    
48  class TimeIntegrationManager:  class TimeIntegrationManager:
49    """    """
50    a simple mechanism to manage time dependend values.    A simple mechanism to manage time dependend values.
51    
52    typical usage is:    Typical usage is::
53    
54    dt=0.1 # time increment       dt=0.1 # time increment
55    tm=TimeIntegrationManager(inital_value,p=1)       tm=TimeIntegrationManager(inital_value,p=1)
56    while t<1.       while t<1.
57        v_guess=tm.extrapolate(dt) # extrapolate to t+dt           v_guess=tm.extrapolate(dt) # extrapolate to t+dt
58        v=...           v=...
59        tm.checkin(dt,v)           tm.checkin(dt,v)
60        t+=dt           t+=dt
61    
62    @remark: currently only p=1 is supported.    :note: currently only p=1 is supported.
63    """    """
64    def __init__(self,*inital_values,**kwargs):    def __init__(self,*inital_values,**kwargs):
65       """       """
66       sets up the value manager where inital_value is the initial value and p is order used for extrapolation       Sets up the value manager where ``inital_values`` are the initial values
67         and p is the order used for extrapolation.
68       """       """
69       if kwargs.has_key("p"):       if kwargs.has_key("p"):
70              self.__p=kwargs["p"]              self.__p=kwargs["p"]
# Line 50  class TimeIntegrationManager: Line 82  class TimeIntegrationManager:
82    def getTime(self):    def getTime(self):
83        return self.__t        return self.__t
84    
85      def getValue(self):
86          out=self.__v_mem[0]
87          if len(out)==1:
88              return out[0]
89          else:
90              return out
91    
92    def checkin(self,dt,*values):    def checkin(self,dt,*values):
93        """        """
94        adds new values to the manager. the p+1 last value get lost        Adds new values to the manager. The p+1 last values are lost.
95        """        """
96        o=min(self.__order+1,self.__p)        o=min(self.__order+1,self.__p)
97        self.__order=min(self.__order+1,self.__p)        self.__order=min(self.__order+1,self.__p)
# Line 69  class TimeIntegrationManager: Line 108  class TimeIntegrationManager:
108    
109    def extrapolate(self,dt):    def extrapolate(self,dt):
110        """        """
111        extrapolates to dt forward in time.        Extrapolates to ``dt`` forward in time.
112        """        """
113        if self.__order==0:        if self.__order==0:
114           out=self.__v_mem[0]           out=self.__v_mem[0]
# Line 85  class TimeIntegrationManager: Line 124  class TimeIntegrationManager:
124        else:        else:
125           return out           return out
126    
127    
128  class Projector:  class Projector:
129    """    """
130    The Projector is a factory which projects a discontiuous function onto a    The Projector is a factory which projects a discontinuous function onto a
131    continuous function on the a given domain.    continuous function on a given domain.
132    """    """
133    def __init__(self, domain, reduce = True, fast=True):    def __init__(self, domain, reduce=True, fast=True):
134      """      """
135      Create a continuous function space projector for a domain.      Creates a continuous function space projector for a domain.
136    
137      @param domain: Domain of the projection.      :param domain: Domain of the projection.
138      @param reduce: Flag to reduce projection order (default is True)      :param reduce: Flag to reduce projection order
139      @param fast: Flag to use a fast method based on matrix lumping (default is true)      :param fast: Flag to use a fast method based on matrix lumping
140      """      """
141      self.__pde = linearPDEs.LinearPDE(domain)      self.__pde = linearPDEs.LinearPDE(domain)
142      if fast:      if fast:
143        self.__pde.setSolverMethod(linearPDEs.LinearPDE.LUMPING)          self.__pde.getSolverOptions().setSolverMethod(linearPDEs.SolverOptions.LUMPING)
144      self.__pde.setSymmetryOn()      self.__pde.setSymmetryOn()
145      self.__pde.setReducedOrderTo(reduce)      self.__pde.setReducedOrderTo(reduce)
146      self.__pde.setValue(D = 1.)      self.__pde.setValue(D = 1.)
147      return      return
148      def getSolverOptions(self):
149    def __del__(self):      """
150      return      Returns the solver options of the PDE solver.
151        
152        :rtype: `linearPDEs.SolverOptions`
153        """
154        return self.__pde.getSolverOptions()
155    
156    def __call__(self, input_data):    def __call__(self, input_data):
157      """      """
158      Projects input_data onto a continuous function      Projects ``input_data`` onto a continuous function.
159    
160      @param input_data: The input_data to be projected.      :param input_data: the data to be projected
161      """      """
162      out=escript.Data(0.,input_data.getShape(),what=escript.ContinuousFunction(self.__pde.getDomain()))      out=escript.Data(0.,input_data.getShape(),self.__pde.getFunctionSpaceForSolution())
163        self.__pde.setValue(Y = escript.Data(), Y_reduced = escript.Data())
164      if input_data.getRank()==0:      if input_data.getRank()==0:
165          self.__pde.setValue(Y = input_data)          self.__pde.setValue(Y = input_data)
166          out=self.__pde.getSolution()          out=self.__pde.getSolution()
# Line 143  class Projector: Line 188  class Projector:
188                      out[i0,i1,i2,i3]=self.__pde.getSolution()                      out[i0,i1,i2,i3]=self.__pde.getSolution()
189      return out      return out
190    
191    class NoPDE:
192         """
193         Solves the following problem for u:
194    
195         *kronecker[i,j]*D[j]*u[j]=Y[i]*
196    
197         with constraint
198    
199         *u[j]=r[j]*  where *q[j]>0*
200    
201         where *D*, *Y*, *r* and *q* are given functions of rank 1.
202    
203         In the case of scalars this takes the form
204    
205         *D*u=Y*
206    
207         with constraint
208    
209         *u=r* where *q>0*
210    
211         where *D*, *Y*, *r* and *q* are given scalar functions.
212    
213         The constraint overwrites any other condition.
214    
215         :note: This class is similar to the `linearPDEs.LinearPDE` class with
216                A=B=C=X=0 but has the intention that all input parameters are given
217                in `Solution` or `ReducedSolution`.
218         """
219         # The whole thing is a bit strange and I blame Rob Woodcock (CSIRO) for
220         # this.
221         def __init__(self,domain,D=None,Y=None,q=None,r=None):
222             """
223             Initializes the problem.
224    
225             :param domain: domain of the PDE
226             :type domain: `Domain`
227             :param D: coefficient of the solution
228             :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
229             :param Y: right hand side
230             :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
231             :param q: location of constraints
232             :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
233             :param r: value of solution at locations of constraints
234             :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
235             """
236             self.__domain=domain
237             self.__D=D
238             self.__Y=Y
239             self.__q=q
240             self.__r=r
241             self.__u=None
242             self.__function_space=escript.Solution(self.__domain)
243    
244         def setReducedOn(self):
245             """
246             Sets the `FunctionSpace` of the solution to `ReducedSolution`.
247             """
248             self.__function_space=escript.ReducedSolution(self.__domain)
249             self.__u=None
250    
251         def setReducedOff(self):
252             """
253             Sets the `FunctionSpace` of the solution to `Solution`.
254             """
255             self.__function_space=escript.Solution(self.__domain)
256             self.__u=None
257    
258         def setValue(self,D=None,Y=None,q=None,r=None):
259             """
260             Assigns values to the parameters.
261    
262             :param D: coefficient of the solution
263             :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
264             :param Y: right hand side
265             :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
266             :param q: location of constraints
267             :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
268             :param r: value of solution at locations of constraints
269             :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
270             """
271             if not D==None:
272                self.__D=D
273                self.__u=None
274             if not Y==None:
275                self.__Y=Y
276                self.__u=None
277             if not q==None:
278                self.__q=q
279                self.__u=None
280             if not r==None:
281                self.__r=r
282                self.__u=None
283    
284         def getSolution(self):
285             """
286             Returns the solution.
287    
288             :return: the solution of the problem
289             :rtype: `Data` object in the `FunctionSpace` `Solution` or
290                     `ReducedSolution`
291             """
292             if self.__u==None:
293                if self.__D==None:
294                   raise ValueError,"coefficient D is undefined"
295                D=escript.Data(self.__D,self.__function_space)
296                if D.getRank()>1:
297                   raise ValueError,"coefficient D must have rank 0 or 1"
298                if self.__Y==None:
299                   self.__u=escript.Data(0.,D.getShape(),self.__function_space)
300                else:
301                   self.__u=util.quotient(self.__Y,D)
302                if not self.__q==None:
303                    q=util.wherePositive(escript.Data(self.__q,self.__function_space))
304                    self.__u*=(1.-q)
305                    if not self.__r==None: self.__u+=q*self.__r
306             return self.__u
307    
308  class Locator:  class Locator:
309       """       """
310       Locator provides access to the values of data objects at a given       Locator provides access to the values of data objects at a given spatial
311       spatial coordinate x.         coordinate x.
312        
313       In fact, a Locator object finds the sample in the set of samples of a       In fact, a Locator object finds the sample in the set of samples of a
314       given function space or domain where which is closest to the given       given function space or domain which is closest to the given point x.
      point x.  
315       """       """
316    
317       def __init__(self,where,x=numarray.zeros((3,))):       def __init__(self,where,x=numpy.zeros((3,))):
318         """         """
319         Initializes a Locator to access values in Data objects on the Doamin         Initializes a Locator to access values in Data objects on the Doamin
320         or FunctionSpace where for the sample point which         or FunctionSpace for the sample point which is closest to the given
321         closest to the given point x.         point x.
322    
323           :param where: function space
324           :type where: `escript.FunctionSpace`
325           :param x: location(s) of the Locator
326           :type x: ``numpy.ndarray`` or ``list`` of ``numpy.ndarray``
327         """         """
328         if isinstance(where,escript.FunctionSpace):         if isinstance(where,escript.FunctionSpace):
329            self.__function_space=where            self.__function_space=where
330         else:         else:
331            self.__function_space=escript.ContinuousFunction(where)            self.__function_space=escript.ContinuousFunction(where)
332         self.__id=util.length(x[:self.__function_space.getDim()]-self.__function_space.getX()).mindp()         iterative=False
333           if isinstance(x, list):
334               if len(x)==0:
335                  raise "ValueError", "At least one point must be given."
336               try:
337                 iter(x[0])
338                 iterative=True
339               except TypeError:
340                 iterative=False
341           if iterative:
342               self.__id=[]
343               for p in x:
344                  self.__id.append(util.length(self.__function_space.getX()-p[:self.__function_space.getDim()]).minGlobalDataPoint())
345           else:
346               self.__id=util.length(self.__function_space.getX()-x[:self.__function_space.getDim()]).minGlobalDataPoint()
347    
348       def __str__(self):       def __str__(self):
349         """         """
350         Returns the coordinates of the Locator as a string.         Returns the coordinates of the Locator as a string.
351         """         """
352         return "<Locator %s>"%str(self.getX())         x=self.getX()
353           if isinstance(x,list):
354              out="["
355              first=True
356              for xx in x:
357                if not first:
358                    out+=","
359                else:
360                    first=False
361                out+=str(xx)
362              out+="]>"
363           else:
364              out=str(x)
365           return out
366    
367         def getX(self):
368            """
369            Returns the exact coordinates of the Locator.
370            """
371            return self(self.getFunctionSpace().getX())
372    
373       def getFunctionSpace(self):       def getFunctionSpace(self):
374          """          """
375      Returns the function space of the Locator.          Returns the function space of the Locator.
376      """          """
377          return self.__function_space          return self.__function_space
378    
379       def getId(self):       def getId(self,item=None):
380          """          """
381      Returns the identifier of the location.          Returns the identifier of the location.
     """  
         return self.__id  
   
      def getX(self):  
382          """          """
383      Returns the exact coordinates of the Locator.          if item == None:
384      """             return self.__id
385          return self(self.getFunctionSpace().getX())          else:
386               if isinstance(self.__id,list):
387                  return self.__id[item]
388               else:
389                  return self.__id
390    
391    
392       def __call__(self,data):       def __call__(self,data):
393          """          """
394      Returns the value of data at the Locator of a Data object otherwise          Returns the value of data at the Locator of a Data object.
395      the object is returned.          """
     """  
396          return self.getValue(data)          return self.getValue(data)
397    
398       def getValue(self,data):       def getValue(self,data):
399          """          """
400      Returns the value of data at the Locator if data is a Data object          Returns the value of ``data`` at the Locator if ``data`` is a `Data`
401      otherwise the object is returned.          object otherwise the object is returned.
402      """          """
403          if isinstance(data,escript.Data):          if isinstance(data,escript.Data):
404             if data.getFunctionSpace()==self.getFunctionSpace():             dat=util.interpolate(data,self.getFunctionSpace())
405               out=data.convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1])             id=self.getId()
406             else:             r=data.getRank()
407               out=data.interpolate(self.getFunctionSpace()).convertToNumArrayFromDPNo(self.getId()[0],self.getId()[1])             if isinstance(id,list):
408             if data.getRank()==0:                 out=[]
409                return out[0]                 for i in id:
410                      o=numpy.array(dat.getTupleForGlobalDataPoint(*i))
411                      if data.getRank()==0:
412                         out.append(o[0])
413                      else:
414                         out.append(o)
415                   return out
416             else:             else:
417                return out               out=numpy.array(dat.getTupleForGlobalDataPoint(*id))
418                 if data.getRank()==0:
419                    return out[0]
420                 else:
421                    return out
422          else:          else:
423             return data             return data
424    
425  # vim: expandtab shiftwidth=4:  
426    def getInfLocator(arg):
427        """
428        Return a Locator for a point with the inf value over all arg.
429        """
430        if not isinstance(arg, escript.Data):
431        raise TypeError,"getInfLocator: Unknown argument type."
432        a_inf=util.inf(arg)
433        loc=util.length(arg-a_inf).minGlobalDataPoint() # This gives us the location but not coords
434        x=arg.getFunctionSpace().getX()
435        x_min=x.getTupleForGlobalDataPoint(*loc)
436        return Locator(arg.getFunctionSpace(),x_min)
437    
438    def getSupLocator(arg):
439        """
440        Return a Locator for a point with the sup value over all arg.
441        """
442        if not isinstance(arg, escript.Data):
443        raise TypeError,"getInfLocator: Unknown argument type."
444        a_inf=util.sup(arg)
445        loc=util.length(arg-a_inf).minGlobalDataPoint() # This gives us the location but not coords
446        x=arg.getFunctionSpace().getX()
447        x_min=x.getTupleForGlobalDataPoint(*loc)
448        return Locator(arg.getFunctionSpace(),x_min)
449        
450    
451    class SolverSchemeException(Exception):
452       """
453       This is a generic exception thrown by solvers.
454       """
455       pass
456    
457    class IndefinitePreconditioner(SolverSchemeException):
458       """
459       Exception thrown if the preconditioner is not positive definite.
460       """
461       pass
462    
463    class MaxIterReached(SolverSchemeException):
464       """
465       Exception thrown if the maximum number of iteration steps is reached.
466       """
467       pass
468    
469    class CorrectionFailed(SolverSchemeException):
470       """
471       Exception thrown if no convergence has been achieved in the solution
472       correction scheme.
473       """
474       pass
475    
476    class IterationBreakDown(SolverSchemeException):
477       """
478       Exception thrown if the iteration scheme encountered an incurable breakdown.
479       """
480       pass
481    
482    class NegativeNorm(SolverSchemeException):
483       """
484       Exception thrown if a norm calculation returns a negative norm.
485       """
486       pass
487    
488    def PCG(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100, initial_guess=True, verbose=False):
489       """
490       Solver for
491    
492       *Ax=b*
493    
494       with a symmetric and positive definite operator A (more details required!).
495       It uses the conjugate gradient method with preconditioner M providing an
496       approximation of A.
497    
498       The iteration is terminated if
499    
500       *|r| <= atol+rtol*|r0|*
501    
502       where *r0* is the initial residual and *|.|* is the energy norm. In fact
503    
504       *|r| = sqrt( bilinearform(Msolve(r),r))*
505    
506       For details on the preconditioned conjugate gradient method see the book:
507    
508       I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
509       T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
510       C. Romine, and H. van der Vorst}.
511    
512       :param r: initial residual *r=b-Ax*. ``r`` is altered.
513       :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
514       :param x: an initial guess for the solution
515       :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
516       :param Aprod: returns the value Ax
517       :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
518                    argument ``x``. The returned object needs to be of the same type
519                    like argument ``r``.
520       :param Msolve: solves Mx=r
521       :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
522                     argument ``r``. The returned object needs to be of the same
523                     type like argument ``x``.
524       :param bilinearform: inner product ``<x,r>``
525       :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
526                           type like argument ``x`` and ``r`` is. The returned value
527                           is a ``float``.
528       :param atol: absolute tolerance
529       :type atol: non-negative ``float``
530       :param rtol: relative tolerance
531       :type rtol: non-negative ``float``
532       :param iter_max: maximum number of iteration steps
533       :type iter_max: ``int``
534       :return: the solution approximation and the corresponding residual
535       :rtype: ``tuple``
536       :warning: ``r`` and ``x`` are altered.
537       """
538       iter=0
539       rhat=Msolve(r)
540       d = rhat
541       rhat_dot_r = bilinearform(rhat, r)
542       if rhat_dot_r<0: raise NegativeNorm,"negative norm."
543       norm_r0=math.sqrt(rhat_dot_r)
544       atol2=atol+rtol*norm_r0
545       if atol2<=0:
546          raise ValueError,"Non-positive tolarance."
547       atol2=max(atol2, 100. * util.EPSILON * norm_r0)
548    
549       if verbose: print "PCG: initial residual norm = %e (absolute tolerance = %e)"%(norm_r0, atol2)
550    
551    
552       while not math.sqrt(rhat_dot_r) <= atol2:
553           iter+=1
554           if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
555    
556           q=Aprod(d)
557           alpha = rhat_dot_r / bilinearform(d, q)
558           x += alpha * d
559           if isinstance(q,ArithmeticTuple):
560           r += q * (-alpha)      # Doing it the other way calls the float64.__mul__ not AT.__rmul__
561           else:
562               r += (-alpha) * q
563           rhat=Msolve(r)
564           rhat_dot_r_new = bilinearform(rhat, r)
565           beta = rhat_dot_r_new / rhat_dot_r
566           rhat+=beta * d
567           d=rhat
568    
569           rhat_dot_r = rhat_dot_r_new
570           if rhat_dot_r<0: raise NegativeNorm,"negative norm."
571           if verbose: print "PCG: iteration step %s: residual norm = %e"%(iter, math.sqrt(rhat_dot_r))
572       if verbose: print "PCG: tolerance reached after %s steps."%iter
573       return x,r,math.sqrt(rhat_dot_r)
574    
575    class Defect(object):
576        """
577        Defines a non-linear defect F(x) of a variable x.
578        """
579        def __init__(self):
580            """
581            Initializes defect.
582            """
583            self.setDerivativeIncrementLength()
584    
585        def bilinearform(self, x0, x1):
586            """
587            Returns the inner product of x0 and x1
588    
589            :param x0: value for x0
590            :param x1: value for x1
591            :return: the inner product of x0 and x1
592            :rtype: ``float``
593            """
594            return 0
595    
596        def norm(self,x):
597            """
598            Returns the norm of argument ``x``.
599    
600            :param x: a value
601            :return: norm of argument x
602            :rtype: ``float``
603            :note: by default ``sqrt(self.bilinearform(x,x)`` is returned.
604            """
605            s=self.bilinearform(x,x)
606            if s<0: raise NegativeNorm,"negative norm."
607            return math.sqrt(s)
608    
609        def eval(self,x):
610            """
611            Returns the value F of a given ``x``.
612    
613            :param x: value for which the defect ``F`` is evaluated
614            :return: value of the defect at ``x``
615            """
616            return 0
617    
618        def __call__(self,x):
619            return self.eval(x)
620    
621        def setDerivativeIncrementLength(self,inc=1000.*math.sqrt(util.EPSILON)):
622            """
623            Sets the relative length of the increment used to approximate the
624            derivative of the defect. The increment is inc*norm(x)/norm(v)*v in the
625            direction of v with x as a starting point.
626    
627            :param inc: relative increment length
628            :type inc: positive ``float``
629            """
630            if inc<=0: raise ValueError,"positive increment required."
631            self.__inc=inc
632    
633        def getDerivativeIncrementLength(self):
634            """
635            Returns the relative increment length used to approximate the
636            derivative of the defect.
637            :return: value of the defect at ``x``
638            :rtype: positive ``float``
639            """
640            return self.__inc
641    
642        def derivative(self, F0, x0, v, v_is_normalised=True):
643            """
644            Returns the directional derivative at ``x0`` in the direction of ``v``.
645    
646            :param F0: value of this defect at x0
647            :param x0: value at which derivative is calculated
648            :param v: direction
649            :param v_is_normalised: True to indicate that ``v`` is nomalized
650                                    (self.norm(v)=0)
651            :return: derivative of this defect at x0 in the direction of ``v``
652            :note: by default numerical evaluation (self.eval(x0+eps*v)-F0)/eps is
653                   used but this method maybe overwritten to use exact evaluation.
654            """
655            normx=self.norm(x0)
656            if normx>0:
657                 epsnew = self.getDerivativeIncrementLength() * normx
658            else:
659                 epsnew = self.getDerivativeIncrementLength()
660            if not v_is_normalised:
661                normv=self.norm(v)
662                if normv<=0:
663                   return F0*0
664                else:
665                   epsnew /= normv
666            F1=self.eval(x0 + epsnew * v)
667            return (F1-F0)/epsnew
668    
669    ######################################
670    def NewtonGMRES(defect, x, iter_max=100, sub_iter_max=20, atol=0,rtol=1.e-4, subtol_max=0.5, gamma=0.9, verbose=False):
671       """
672       Solves a non-linear problem *F(x)=0* for unknown *x* using the stopping
673       criterion:
674    
675       *norm(F(x) <= atol + rtol * norm(F(x0)*
676    
677       where *x0* is the initial guess.
678    
679       :param defect: object defining the function *F*. ``defect.norm`` defines the
680                      *norm* used in the stopping criterion.
681       :type defect: `Defect`
682       :param x: initial guess for the solution, ``x`` is altered.
683       :type x: any object type allowing basic operations such as
684                ``numpy.ndarray``, `Data`
685       :param iter_max: maximum number of iteration steps
686       :type iter_max: positive ``int``
687       :param sub_iter_max: maximum number of inner iteration steps
688       :type sub_iter_max: positive ``int``
689       :param atol: absolute tolerance for the solution
690       :type atol: positive ``float``
691       :param rtol: relative tolerance for the solution
692       :type rtol: positive ``float``
693       :param gamma: tolerance safety factor for inner iteration
694       :type gamma: positive ``float``, less than 1
695       :param subtol_max: upper bound for inner tolerance
696       :type subtol_max: positive ``float``, less than 1
697       :return: an approximation of the solution with the desired accuracy
698       :rtype: same type as the initial guess
699       """
700       lmaxit=iter_max
701       if atol<0: raise ValueError,"atol needs to be non-negative."
702       if rtol<0: raise ValueError,"rtol needs to be non-negative."
703       if rtol+atol<=0: raise ValueError,"rtol or atol needs to be non-negative."
704       if gamma<=0 or gamma>=1: raise ValueError,"tolerance safety factor for inner iteration (gamma =%s) needs to be positive and less than 1."%gamma
705       if subtol_max<=0 or subtol_max>=1: raise ValueError,"upper bound for inner tolerance for inner iteration (subtol_max =%s) needs to be positive and less than 1."%subtol_max
706    
707       F=defect(x)
708       fnrm=defect.norm(F)
709       stop_tol=atol + rtol*fnrm
710       subtol=subtol_max
711       if verbose: print "NewtonGMRES: initial residual = %e."%fnrm
712       if verbose: print "             tolerance = %e."%subtol
713       iter=1
714       #
715       # main iteration loop
716       #
717       while not fnrm<=stop_tol:
718                if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
719                #
720            #   adjust subtol_
721            #
722                if iter > 1:
723               rat=fnrm/fnrmo
724                   subtol_old=subtol
725               subtol=gamma*rat**2
726               if gamma*subtol_old**2 > .1: subtol=max(subtol,gamma*subtol_old**2)
727               subtol=max(min(subtol,subtol_max), .5*stop_tol/fnrm)
728            #
729            # calculate newton increment xc
730                #     if iter_max in __FDGMRES is reached MaxIterReached is thrown
731                #     if iter_restart -1 is returned as sub_iter
732                #     if  atol is reached sub_iter returns the numer of steps performed to get there
733                #
734                #
735                if verbose: print "             subiteration (GMRES) is called with relative tolerance %e."%subtol
736                try:
737                   xc, sub_iter=__FDGMRES(F, defect, x, subtol*fnrm, iter_max=iter_max-iter, iter_restart=sub_iter_max)
738                except MaxIterReached:
739                   raise MaxIterReached,"maximum number of %s steps reached."%iter_max
740                if sub_iter<0:
741                   iter+=sub_iter_max
742                else:
743                   iter+=sub_iter
744                # ====
745            x+=xc
746                F=defect(x)
747            iter+=1
748                fnrmo, fnrm=fnrm, defect.norm(F)
749                if verbose: print "             step %s: residual %e."%(iter,fnrm)
750       if verbose: print "NewtonGMRES: completed after %s steps."%iter
751       return x
752    
753    def __givapp(c,s,vin):
754        """
755        Applies a sequence of Givens rotations (c,s) recursively to the vector
756        ``vin``
757    
758        :warning: ``vin`` is altered.
759        """
760        vrot=vin
761        if isinstance(c,float):
762            vrot=[c*vrot[0]-s*vrot[1],s*vrot[0]+c*vrot[1]]
763        else:
764            for i in range(len(c)):
765                w1=c[i]*vrot[i]-s[i]*vrot[i+1]
766            w2=s[i]*vrot[i]+c[i]*vrot[i+1]
767                vrot[i]=w1
768                vrot[i+1]=w2
769        return vrot
770    
771    def __FDGMRES(F0, defect, x0, atol, iter_max=100, iter_restart=20):
772       h=numpy.zeros((iter_restart,iter_restart),numpy.float64)
773       c=numpy.zeros(iter_restart,numpy.float64)
774       s=numpy.zeros(iter_restart,numpy.float64)
775       g=numpy.zeros(iter_restart,numpy.float64)
776       v=[]
777    
778       rho=defect.norm(F0)
779       if rho<=0.: return x0*0
780    
781       v.append(-F0/rho)
782       g[0]=rho
783       iter=0
784       while rho > atol and iter<iter_restart-1:
785            if iter  >= iter_max:
786                raise MaxIterReached,"maximum number of %s steps reached."%iter_max
787    
788            p=defect.derivative(F0,x0,v[iter], v_is_normalised=True)
789            v.append(p)
790    
791            v_norm1=defect.norm(v[iter+1])
792    
793            # Modified Gram-Schmidt
794            for j in range(iter+1):
795                h[j,iter]=defect.bilinearform(v[j],v[iter+1])
796                v[iter+1]-=h[j,iter]*v[j]
797    
798            h[iter+1,iter]=defect.norm(v[iter+1])
799            v_norm2=h[iter+1,iter]
800    
801            # Reorthogonalize if needed
802            if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
803                for j in range(iter+1):
804                    hr=defect.bilinearform(v[j],v[iter+1])
805                    h[j,iter]=h[j,iter]+hr
806                    v[iter+1] -= hr*v[j]
807    
808                v_norm2=defect.norm(v[iter+1])
809                h[iter+1,iter]=v_norm2
810            #   watch out for happy breakdown
811            if not v_norm2 == 0:
812                v[iter+1]=v[iter+1]/h[iter+1,iter]
813    
814            #   Form and store the information for the new Givens rotation
815            if iter > 0 :
816                hhat=numpy.zeros(iter+1,numpy.float64)
817                for i in range(iter+1) : hhat[i]=h[i,iter]
818                hhat=__givapp(c[0:iter],s[0:iter],hhat);
819                for i in range(iter+1) : h[i,iter]=hhat[i]
820    
821            mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
822    
823            if mu!=0 :
824                c[iter]=h[iter,iter]/mu
825                s[iter]=-h[iter+1,iter]/mu
826                h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
827                h[iter+1,iter]=0.0
828                gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
829                g[iter]=gg[0]
830                g[iter+1]=gg[1]
831    
832            # Update the residual norm
833            rho=abs(g[iter+1])
834            iter+=1
835    
836       # At this point either iter > iter_max or rho < tol.
837       # It's time to compute x and leave.
838       if iter > 0 :
839         y=numpy.zeros(iter,numpy.float64)
840         y[iter-1] = g[iter-1] / h[iter-1,iter-1]
841         if iter > 1 :
842            i=iter-2
843            while i>=0 :
844              y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
845              i=i-1
846         xhat=v[iter-1]*y[iter-1]
847         for i in range(iter-1):
848        xhat += v[i]*y[i]
849       else :
850          xhat=v[0] * 0
851    
852       if iter<iter_restart-1:
853          stopped=iter
854       else:
855          stopped=-1
856    
857       return xhat,stopped
858    
859    def GMRES(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100, iter_restart=20, verbose=False,P_R=None):
860       """
861       Solver for
862    
863       *Ax=b*
864    
865       with a general operator A (more details required!).
866       It uses the generalized minimum residual method (GMRES).
867    
868       The iteration is terminated if
869    
870       *|r| <= atol+rtol*|r0|*
871    
872       where *r0* is the initial residual and *|.|* is the energy norm. In fact
873    
874       *|r| = sqrt( bilinearform(r,r))*
875    
876       :param r: initial residual *r=b-Ax*. ``r`` is altered.
877       :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
878       :param x: an initial guess for the solution
879       :type x: same like ``r``
880       :param Aprod: returns the value Ax
881       :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
882                    argument ``x``. The returned object needs to be of the same
883                    type like argument ``r``.
884       :param bilinearform: inner product ``<x,r>``
885       :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
886                           type like argument ``x`` and ``r``. The returned value is
887                           a ``float``.
888       :param atol: absolute tolerance
889       :type atol: non-negative ``float``
890       :param rtol: relative tolerance
891       :type rtol: non-negative ``float``
892       :param iter_max: maximum number of iteration steps
893       :type iter_max: ``int``
894       :param iter_restart: in order to save memory the orthogonalization process
895                            is terminated after ``iter_restart`` steps and the
896                            iteration is restarted.
897       :type iter_restart: ``int``
898       :return: the solution approximation and the corresponding residual
899       :rtype: ``tuple``
900       :warning: ``r`` and ``x`` are altered.
901       """
902       m=iter_restart
903       restarted=False
904       iter=0
905       if rtol>0:
906          r_dot_r = bilinearform(r, r)
907          if r_dot_r<0: raise NegativeNorm,"negative norm."
908          atol2=atol+rtol*math.sqrt(r_dot_r)
909          if verbose: print "GMRES: norm of right hand side = %e (absolute tolerance = %e)"%(math.sqrt(r_dot_r), atol2)
910       else:
911          atol2=atol
912          if verbose: print "GMRES: absolute tolerance = %e"%atol2
913       if atol2<=0:
914          raise ValueError,"Non-positive tolarance."
915    
916       while True:
917          if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached"%iter_max
918          if restarted:
919             r2 = r-Aprod(x-x2)
920          else:
921             r2=1*r
922          x2=x*1.
923          x,stopped=_GMRESm(r2, Aprod, x, bilinearform, atol2, iter_max=iter_max-iter, iter_restart=m, verbose=verbose,P_R=P_R)
924          iter+=iter_restart
925          if stopped: break
926          if verbose: print "GMRES: restart."
927          restarted=True
928       if verbose: print "GMRES: tolerance has been reached."
929       return x
930    
931    def _GMRESm(r, Aprod, x, bilinearform, atol, iter_max=100, iter_restart=20, verbose=False, P_R=None):
932       iter=0
933    
934       h=numpy.zeros((iter_restart+1,iter_restart),numpy.float64)
935       c=numpy.zeros(iter_restart,numpy.float64)
936       s=numpy.zeros(iter_restart,numpy.float64)
937       g=numpy.zeros(iter_restart+1,numpy.float64)
938       v=[]
939    
940       r_dot_r = bilinearform(r, r)
941       if r_dot_r<0: raise NegativeNorm,"negative norm."
942       rho=math.sqrt(r_dot_r)
943    
944       v.append(r/rho)
945       g[0]=rho
946    
947       if verbose: print "GMRES: initial residual %e (absolute tolerance = %e)"%(rho,atol)
948       while not (rho<=atol or iter==iter_restart):
949    
950        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
951    
952            if P_R!=None:
953                p=Aprod(P_R(v[iter]))
954            else:
955            p=Aprod(v[iter])
956        v.append(p)
957    
958        v_norm1=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
959    
960    # Modified Gram-Schmidt
961        for j in range(iter+1):
962          h[j,iter]=bilinearform(v[j],v[iter+1])
963          v[iter+1]-=h[j,iter]*v[j]
964    
965        h[iter+1,iter]=math.sqrt(bilinearform(v[iter+1],v[iter+1]))
966        v_norm2=h[iter+1,iter]
967    
968    # Reorthogonalize if needed
969        if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
970         for j in range(iter+1):
971            hr=bilinearform(v[j],v[iter+1])
972                h[j,iter]=h[j,iter]+hr
973                v[iter+1] -= hr*v[j]
974    
975         v_norm2=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
976         h[iter+1,iter]=v_norm2
977    
978    #   watch out for happy breakdown
979            if not v_norm2 == 0:
980             v[iter+1]=v[iter+1]/h[iter+1,iter]
981    
982    #   Form and store the information for the new Givens rotation
983        if iter > 0: h[:iter+1,iter]=__givapp(c[:iter],s[:iter],h[:iter+1,iter])
984        mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])
985    
986        if mu!=0 :
987            c[iter]=h[iter,iter]/mu
988            s[iter]=-h[iter+1,iter]/mu
989            h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
990            h[iter+1,iter]=0.0
991                    gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
992                    g[iter]=gg[0]
993                    g[iter+1]=gg[1]
994    # Update the residual norm
995    
996            rho=abs(g[iter+1])
997            if verbose: print "GMRES: iteration step %s: residual %e"%(iter,rho)
998        iter+=1
999    
1000    # At this point either iter > iter_max or rho < tol.
1001    # It's time to compute x and leave.
1002    
1003       if verbose: print "GMRES: iteration stopped after %s step."%iter
1004       if iter > 0 :
1005         y=numpy.zeros(iter,numpy.float64)
1006         y[iter-1] = g[iter-1] / h[iter-1,iter-1]
1007         if iter > 1 :
1008            i=iter-2
1009            while i>=0 :
1010              y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
1011              i=i-1
1012         xhat=v[iter-1]*y[iter-1]
1013         for i in range(iter-1):
1014        xhat += v[i]*y[i]
1015       else:
1016         xhat=v[0] * 0
1017       if P_R!=None:
1018          x += P_R(xhat)
1019       else:
1020          x += xhat
1021       if iter<iter_restart-1:
1022          stopped=True
1023       else:
1024          stopped=False
1025    
1026       return x,stopped
1027    
1028    def MINRES(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1029        """
1030        Solver for
1031    
1032        *Ax=b*
1033    
1034        with a symmetric and positive definite operator A (more details required!).
1035        It uses the minimum residual method (MINRES) with preconditioner M
1036        providing an approximation of A.
1037    
1038        The iteration is terminated if
1039    
1040        *|r| <= atol+rtol*|r0|*
1041    
1042        where *r0* is the initial residual and *|.|* is the energy norm. In fact
1043    
1044        *|r| = sqrt( bilinearform(Msolve(r),r))*
1045    
1046        For details on the preconditioned conjugate gradient method see the book:
1047    
1048        I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
1049        T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
1050        C. Romine, and H. van der Vorst}.
1051    
1052        :param r: initial residual *r=b-Ax*. ``r`` is altered.
1053        :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1054        :param x: an initial guess for the solution
1055        :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1056        :param Aprod: returns the value Ax
1057        :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
1058                     argument ``x``. The returned object needs to be of the same
1059                     type like argument ``r``.
1060        :param Msolve: solves Mx=r
1061        :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
1062                      argument ``r``. The returned object needs to be of the same
1063                      type like argument ``x``.
1064        :param bilinearform: inner product ``<x,r>``
1065        :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
1066                            type like argument ``x`` and ``r`` is. The returned value
1067                            is a ``float``.
1068        :param atol: absolute tolerance
1069        :type atol: non-negative ``float``
1070        :param rtol: relative tolerance
1071        :type rtol: non-negative ``float``
1072        :param iter_max: maximum number of iteration steps
1073        :type iter_max: ``int``
1074        :return: the solution approximation and the corresponding residual
1075        :rtype: ``tuple``
1076        :warning: ``r`` and ``x`` are altered.
1077        """
1078        #------------------------------------------------------------------
1079        # Set up y and v for the first Lanczos vector v1.
1080        # y  =  beta1 P' v1,  where  P = C**(-1).
1081        # v is really P' v1.
1082        #------------------------------------------------------------------
1083        r1    = r
1084        y = Msolve(r)
1085        beta1 = bilinearform(y,r)
1086    
1087        if beta1< 0: raise NegativeNorm,"negative norm."
1088    
1089        #  If r = 0 exactly, stop with x
1090        if beta1==0: return x
1091    
1092        if beta1> 0: beta1  = math.sqrt(beta1)
1093    
1094        #------------------------------------------------------------------
1095        # Initialize quantities.
1096        # ------------------------------------------------------------------
1097        iter   = 0
1098        Anorm = 0
1099        ynorm = 0
1100        oldb   = 0
1101        beta   = beta1
1102        dbar   = 0
1103        epsln  = 0
1104        phibar = beta1
1105        rhs1   = beta1
1106        rhs2   = 0
1107        rnorm  = phibar
1108        tnorm2 = 0
1109        ynorm2 = 0
1110        cs     = -1
1111        sn     = 0
1112        w      = r*0.
1113        w2     = r*0.
1114        r2     = r1
1115        eps    = 0.0001
1116    
1117        #---------------------------------------------------------------------
1118        # Main iteration loop.
1119        # --------------------------------------------------------------------
1120        while not rnorm<=atol+rtol*Anorm*ynorm:    #  checks ||r|| < (||A|| ||x||) * TOL
1121    
1122        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1123            iter    = iter  +  1
1124    
1125            #-----------------------------------------------------------------
1126            # Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
1127            # The general iteration is similar to the case k = 1 with v0 = 0:
1128            #
1129            #   p1      = Operator * v1  -  beta1 * v0,
1130            #   alpha1  = v1'p1,
1131            #   q2      = p2  -  alpha1 * v1,
1132            #   beta2^2 = q2'q2,
1133            #   v2      = (1/beta2) q2.
1134            #
1135            # Again, y = betak P vk,  where  P = C**(-1).
1136            #-----------------------------------------------------------------
1137            s = 1/beta                 # Normalize previous vector (in y).
1138            v = s*y                    # v = vk if P = I
1139    
1140            y      = Aprod(v)
1141    
1142            if iter >= 2:
1143              y = y - (beta/oldb)*r1
1144    
1145            alfa   = bilinearform(v,y)              # alphak
1146            y      += (- alfa/beta)*r2
1147            r1     = r2
1148            r2     = y
1149            y = Msolve(r2)
1150            oldb   = beta                           # oldb = betak
1151            beta   = bilinearform(y,r2)             # beta = betak+1^2
1152            if beta < 0: raise NegativeNorm,"negative norm."
1153    
1154            beta   = math.sqrt( beta )
1155            tnorm2 = tnorm2 + alfa*alfa + oldb*oldb + beta*beta
1156    
1157            if iter==1:                 # Initialize a few things.
1158              gmax   = abs( alfa )      # alpha1
1159              gmin   = gmax             # alpha1
1160    
1161            # Apply previous rotation Qk-1 to get
1162            #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
1163            #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].
1164    
1165            oldeps = epsln
1166            delta  = cs * dbar  +  sn * alfa  # delta1 = 0         deltak
1167            gbar   = sn * dbar  -  cs * alfa  # gbar 1 = alfa1     gbar k
1168            epsln  =               sn * beta  # epsln2 = 0         epslnk+1
1169            dbar   =            -  cs * beta  # dbar 2 = beta2     dbar k+1
1170    
1171            # Compute the next plane rotation Qk
1172    
1173            gamma  = math.sqrt(gbar*gbar+beta*beta)  # gammak
1174            gamma  = max(gamma,eps)
1175            cs     = gbar / gamma             # ck
1176            sn     = beta / gamma             # sk
1177            phi    = cs * phibar              # phik
1178            phibar = sn * phibar              # phibark+1
1179    
1180            # Update  x.
1181    
1182            denom = 1/gamma
1183            w1    = w2
1184            w2    = w
1185            w     = (v - oldeps*w1 - delta*w2) * denom
1186            x     +=  phi*w
1187    
1188            # Go round again.
1189    
1190            gmax   = max(gmax,gamma)
1191            gmin   = min(gmin,gamma)
1192            z      = rhs1 / gamma
1193            ynorm2 = z*z  +  ynorm2
1194            rhs1   = rhs2 -  delta*z
1195            rhs2   =      -  epsln*z
1196    
1197            # Estimate various norms and test for convergence.
1198    
1199            Anorm  = math.sqrt( tnorm2 )
1200            ynorm  = math.sqrt( ynorm2 )
1201    
1202            rnorm  = phibar
1203    
1204        return x
1205    
1206    def TFQMR(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
1207      """
1208      Solver for
1209    
1210      *Ax=b*
1211    
1212      with a general operator A (more details required!).
1213      It uses the Transpose-Free Quasi-Minimal Residual method (TFQMR).
1214    
1215      The iteration is terminated if
1216    
1217      *|r| <= atol+rtol*|r0|*
1218    
1219      where *r0* is the initial residual and *|.|* is the energy norm. In fact
1220    
1221      *|r| = sqrt( bilinearform(r,r))*
1222    
1223      :param r: initial residual *r=b-Ax*. ``r`` is altered.
1224      :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
1225      :param x: an initial guess for the solution
1226      :type x: same like ``r``
1227      :param Aprod: returns the value Ax
1228      :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
1229                   argument ``x``. The returned object needs to be of the same type
1230                   like argument ``r``.
1231      :param bilinearform: inner product ``<x,r>``
1232      :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
1233                          type like argument ``x`` and ``r``. The returned value is
1234                          a ``float``.
1235      :param atol: absolute tolerance
1236      :type atol: non-negative ``float``
1237      :param rtol: relative tolerance
1238      :type rtol: non-negative ``float``
1239      :param iter_max: maximum number of iteration steps
1240      :type iter_max: ``int``
1241      :rtype: ``tuple``
1242      :warning: ``r`` and ``x`` are altered.
1243      """
1244      u1=0
1245      u2=0
1246      y1=0
1247      y2=0
1248    
1249      w = r
1250      y1 = r
1251      iter = 0
1252      d = 0
1253      v = Aprod(y1)
1254      u1 = v
1255    
1256      theta = 0.0;
1257      eta = 0.0;
1258      rho=bilinearform(r,r)
1259      if rho < 0: raise NegativeNorm,"negative norm."
1260      tau = math.sqrt(rho)
1261      norm_r0=tau
1262      while tau>atol+rtol*norm_r0:
1263        if iter  >= iter_max: raise MaxIterReached,"maximum number of %s steps reached."%iter_max
1264    
1265        sigma = bilinearform(r,v)
1266        if sigma == 0.0: raise IterationBreakDown,'TFQMR breakdown, sigma=0'
1267    
1268        alpha = rho / sigma
1269    
1270        for j in range(2):
1271    #
1272    #   Compute y2 and u2 only if you have to
1273    #
1274          if ( j == 1 ):
1275            y2 = y1 - alpha * v
1276            u2 = Aprod(y2)
1277    
1278          m = 2 * (iter+1) - 2 + (j+1)
1279          if j==0:
1280             w = w - alpha * u1
1281             d = y1 + ( theta * theta * eta / alpha ) * d
1282          if j==1:
1283             w = w - alpha * u2
1284             d = y2 + ( theta * theta * eta / alpha ) * d
1285    
1286          theta = math.sqrt(bilinearform(w,w))/ tau
1287          c = 1.0 / math.sqrt ( 1.0 + theta * theta )
1288          tau = tau * theta * c
1289          eta = c * c * alpha
1290          x = x + eta * d
1291    #
1292    #  Try to terminate the iteration at each pass through the loop
1293    #
1294        if rho == 0.0: raise IterationBreakDown,'TFQMR breakdown, rho=0'
1295    
1296        rhon = bilinearform(r,w)
1297        beta = rhon / rho;
1298        rho = rhon;
1299        y1 = w + beta * y2;
1300        u1 = Aprod(y1)
1301        v = u1 + beta * ( u2 + beta * v )
1302    
1303        iter += 1
1304    
1305      return x
1306    
1307    
1308    #############################################
1309    
1310    class ArithmeticTuple(object):
1311       """
1312       Tuple supporting inplace update x+=y and scaling x=a*y where ``x,y`` is an
1313       ArithmeticTuple and ``a`` is a float.
1314    
1315       Example of usage::
1316    
1317           from esys.escript import Data
1318           from numpy import array
1319           a=Data(...)
1320           b=array([1.,4.])
1321           x=ArithmeticTuple(a,b)
1322           y=5.*x
1323    
1324       """
1325       def __init__(self,*args):
1326           """
1327           Initializes object with elements ``args``.
1328    
1329           :param args: tuple of objects that support inplace add (x+=y) and
1330                        scaling (x=a*y)
1331           """
1332           self.__items=list(args)
1333    
1334       def __len__(self):
1335           """
1336           Returns the number of items.
1337    
1338           :return: number of items
1339           :rtype: ``int``
1340           """
1341           return len(self.__items)
1342    
1343       def __getitem__(self,index):
1344           """
1345           Returns item at specified position.
1346    
1347           :param index: index of item to be returned
1348           :type index: ``int``
1349           :return: item with index ``index``
1350           """
1351           return self.__items.__getitem__(index)
1352    
1353       def __mul__(self,other):
1354           """
1355           Scales by ``other`` from the right.
1356    
1357           :param other: scaling factor
1358           :type other: ``float``
1359           :return: itemwise self*other
1360           :rtype: `ArithmeticTuple`
1361           """
1362           out=[]
1363           try:
1364               l=len(other)
1365               if l!=len(self):
1366                   raise ValueError,"length of arguments don't match."
1367               for i in range(l): out.append(self[i]*other[i])
1368           except TypeError:
1369               for i in range(len(self)): out.append(self[i]*other)
1370           return ArithmeticTuple(*tuple(out))
1371    
1372       def __rmul__(self,other):
1373           """
1374           Scales by ``other`` from the left.
1375    
1376           :param other: scaling factor
1377           :type other: ``float``
1378           :return: itemwise other*self
1379           :rtype: `ArithmeticTuple`
1380           """
1381           out=[]
1382           try:
1383               l=len(other)
1384               if l!=len(self):
1385                   raise ValueError,"length of arguments don't match."
1386               for i in range(l): out.append(other[i]*self[i])
1387           except TypeError:
1388               for i in range(len(self)): out.append(other*self[i])
1389           return ArithmeticTuple(*tuple(out))
1390    
1391       def __div__(self,other):
1392           """
1393           Scales by (1/``other``) from the right.
1394    
1395           :param other: scaling factor
1396           :type other: ``float``
1397           :return: itemwise self/other
1398           :rtype: `ArithmeticTuple`
1399           """
1400           return self*(1/other)
1401    
1402       def __rdiv__(self,other):
1403           """
1404           Scales by (1/``other``) from the left.
1405    
1406           :param other: scaling factor
1407           :type other: ``float``
1408           :return: itemwise other/self
1409           :rtype: `ArithmeticTuple`
1410           """
1411           out=[]
1412           try:
1413               l=len(other)
1414               if l!=len(self):
1415                   raise ValueError,"length of arguments don't match."
1416               for i in range(l): out.append(other[i]/self[i])
1417           except TypeError:
1418               for i in range(len(self)): out.append(other/self[i])
1419           return ArithmeticTuple(*tuple(out))
1420    
1421       def __iadd__(self,other):
1422           """
1423           Inplace addition of ``other`` to self.
1424    
1425           :param other: increment
1426           :type other: ``ArithmeticTuple``
1427           """
1428           if len(self) != len(other):
1429               raise ValueError,"tuple lengths must match."
1430           for i in range(len(self)):
1431               self.__items[i]+=other[i]
1432           return self
1433    
1434       def __add__(self,other):
1435           """
1436           Adds ``other`` to self.
1437    
1438           :param other: increment
1439           :type other: ``ArithmeticTuple``
1440           """
1441           out=[]
1442           try:
1443               l=len(other)
1444               if l!=len(self):
1445                   raise ValueError,"length of arguments don't match."
1446               for i in range(l): out.append(self[i]+other[i])
1447           except TypeError:
1448               for i in range(len(self)): out.append(self[i]+other)
1449           return ArithmeticTuple(*tuple(out))
1450    
1451       def __sub__(self,other):
1452           """
1453           Subtracts ``other`` from self.
1454    
1455           :param other: decrement
1456           :type other: ``ArithmeticTuple``
1457           """
1458           out=[]
1459           try:
1460               l=len(other)
1461               if l!=len(self):
1462                   raise ValueError,"length of arguments don't match."
1463               for i in range(l): out.append(self[i]-other[i])
1464           except TypeError:
1465               for i in range(len(self)): out.append(self[i]-other)
1466           return ArithmeticTuple(*tuple(out))
1467    
1468       def __isub__(self,other):
1469           """
1470           Inplace subtraction of ``other`` from self.
1471    
1472           :param other: decrement
1473           :type other: ``ArithmeticTuple``
1474           """
1475           if len(self) != len(other):
1476               raise ValueError,"tuple length must match."
1477           for i in range(len(self)):
1478               self.__items[i]-=other[i]
1479           return self
1480    
1481       def __neg__(self):
1482           """
1483           Negates values.
1484           """
1485           out=[]
1486           for i in range(len(self)):
1487               out.append(-self[i])
1488           return ArithmeticTuple(*tuple(out))
1489    
1490    
1491    class HomogeneousSaddlePointProblem(object):
1492          """
1493          This class provides a framework for solving linear homogeneous saddle
1494          point problems of the form::
1495    
1496              *Av+B^*p=f*
1497              *Bv     =0*
1498    
1499          for the unknowns *v* and *p* and given operators *A* and *B* and
1500          given right hand side *f*. *B^** is the adjoint operator of *B*.
1501          *A* may depend weakly on *v* and *p*.
1502          """
1503          def __init__(self, **kwargs):
1504        """
1505        initializes the saddle point problem
1506        """
1507            self.resetControlParameters()
1508            self.setTolerance()
1509            self.setAbsoluteTolerance()
1510          def resetControlParameters(self, K_p=1., K_v=1., rtol_max=0.01, rtol_min = 1.e-7, chi_max=0.5, reduction_factor=0.3, theta = 0.1):
1511             """
1512             sets a control parameter
1513    
1514             :param K_p: initial value for constant to adjust pressure tolerance
1515             :type K_p: ``float``
1516             :param K_v: initial value for constant to adjust velocity tolerance
1517             :type K_v: ``float``
1518             :param rtol_max: maximuim relative tolerance used to calculate presssure and velocity increment.
1519             :type rtol_max: ``float``
1520             :param chi_max: maximum tolerable converegence rate.
1521             :type chi_max: ``float``
1522             :param reduction_factor: reduction factor for adjustment factors.
1523             :type reduction_factor: ``float``
1524             """
1525             self.setControlParameter(K_p, K_v, rtol_max, rtol_min, chi_max, reduction_factor, theta)
1526    
1527          def setControlParameter(self,K_p=None, K_v=None, rtol_max=None, rtol_min=None, chi_max=None, reduction_factor=None, theta=None):
1528             """
1529             sets a control parameter
1530    
1531    
1532             :param K_p: initial value for constant to adjust pressure tolerance
1533             :type K_p: ``float``
1534             :param K_v: initial value for constant to adjust velocity tolerance
1535             :type K_v: ``float``
1536             :param rtol_max: maximuim relative tolerance used to calculate presssure and velocity increment.
1537             :type rtol_max: ``float``
1538             :param chi_max: maximum tolerable converegence rate.
1539             :type chi_max: ``float``
1540             :type reduction_factor: ``float``
1541             """
1542             if not K_p == None:
1543                if K_p<1:
1544                   raise ValueError,"K_p need to be greater or equal to 1."
1545             else:
1546                K_p=self.__K_p
1547    
1548             if not K_v == None:
1549                if K_v<1:
1550                   raise ValueError,"K_v need to be greater or equal to 1."
1551             else:
1552                K_v=self.__K_v
1553    
1554             if not rtol_max == None:
1555                if rtol_max<=0 or rtol_max>=1:
1556                   raise ValueError,"rtol_max needs to be positive and less than 1."
1557             else:
1558                rtol_max=self.__rtol_max
1559    
1560             if not rtol_min == None:
1561                if rtol_min<=0 or rtol_min>=1:
1562                   raise ValueError,"rtol_min needs to be positive and less than 1."
1563             else:
1564                rtol_min=self.__rtol_min
1565    
1566             if not chi_max == None:
1567                if chi_max<=0 or chi_max>=1:
1568                   raise ValueError,"chi_max needs to be positive and less than 1."
1569             else:
1570                chi_max = self.__chi_max
1571    
1572             if not reduction_factor == None:
1573                if reduction_factor<=0 or reduction_factor>1:
1574                   raise ValueError,"reduction_factor need to be between zero and one."
1575             else:
1576                reduction_factor=self.__reduction_factor
1577    
1578             if not theta == None:
1579                if theta<=0 or theta>1:
1580                   raise ValueError,"theta need to be between zero and one."
1581             else:
1582                theta=self.__theta
1583    
1584             if rtol_min>=rtol_max:
1585                 raise ValueError,"rtol_max = %e needs to be greater than rtol_min = %e"%(rtol_max,rtol_min)
1586             self.__chi_max = chi_max
1587             self.__rtol_max = rtol_max
1588             self.__K_p = K_p
1589             self.__K_v = K_v
1590             self.__reduction_factor = reduction_factor
1591             self.__theta = theta
1592             self.__rtol_min=rtol_min
1593    
1594          #=============================================================
1595          def inner_pBv(self,p,Bv):
1596             """
1597             Returns inner product of element p and Bv (overwrite).
1598    
1599             :param p: a pressure increment
1600             :param Bv: a residual
1601             :return: inner product of element p and Bv
1602             :rtype: ``float``
1603             :note: used if PCG is applied.
1604             """
1605             raise NotImplementedError,"no inner product for p and Bv implemented."
1606    
1607          def inner_p(self,p0,p1):
1608             """
1609             Returns inner product of p0 and p1 (overwrite).
1610    
1611             :param p0: a pressure
1612             :param p1: a pressure
1613             :return: inner product of p0 and p1
1614             :rtype: ``float``
1615             """
1616             raise NotImplementedError,"no inner product for p implemented."
1617      
1618          def norm_v(self,v):
1619             """
1620             Returns the norm of v (overwrite).
1621    
1622             :param v: a velovity
1623             :return: norm of v
1624             :rtype: non-negative ``float``
1625             """
1626             raise NotImplementedError,"no norm of v implemented."
1627          def getDV(self, p, v, tol):
1628             """
1629             return a correction to the value for a given v and a given p with accuracy `tol` (overwrite)
1630    
1631             :param p: pressure
1632             :param v: pressure
1633             :return: dv given as *dv= A^{-1} (f-A v-B^*p)*
1634             :note: Only *A* may depend on *v* and *p*
1635             """
1636             raise NotImplementedError,"no dv calculation implemented."
1637    
1638            
1639          def Bv(self,v, tol):
1640            """
1641            Returns Bv with accuracy `tol` (overwrite)
1642    
1643            :rtype: equal to the type of p
1644            :note: boundary conditions on p should be zero!
1645            """
1646            raise NotImplementedError, "no operator B implemented."
1647    
1648          def norm_Bv(self,Bv):
1649            """
1650            Returns the norm of Bv (overwrite).
1651    
1652            :rtype: equal to the type of p
1653            :note: boundary conditions on p should be zero!
1654            """
1655            raise NotImplementedError, "no norm of Bv implemented."
1656    
1657          def solve_AinvBt(self,dp, tol):
1658             """
1659             Solves *A dv=B^*dp* with accuracy `tol`
1660    
1661             :param dp: a pressure increment
1662             :return: the solution of *A dv=B^*dp*
1663             :note: boundary conditions on dv should be zero! *A* is the operator used in ``getDV`` and must not be altered.
1664             """
1665             raise NotImplementedError,"no operator A implemented."
1666    
1667          def solve_prec(self,Bv, tol):
1668             """
1669             Provides a preconditioner for *(BA^{-1}B^ * )* applied to Bv with accuracy `tol`
1670    
1671             :rtype: equal to the type of p
1672             :note: boundary conditions on p should be zero!
1673             """
1674             raise NotImplementedError,"no preconditioner for Schur complement implemented."
1675          #=============================================================
1676          def __Aprod_PCG(self,dp):
1677              dv=self.solve_AinvBt(dp, self.__subtol)
1678              return ArithmeticTuple(dv,self.Bv(dv, self.__subtol))
1679    
1680          def __inner_PCG(self,p,r):
1681             return self.inner_pBv(p,r[1])
1682    
1683          def __Msolve_PCG(self,r):
1684              return self.solve_prec(r[1], self.__subtol)
1685          #=============================================================
1686          def __Aprod_GMRES(self,p):
1687              return self.solve_prec(self.Bv(self.solve_AinvBt(p, self.__subtol), self.__subtol), self.__subtol)
1688          def __inner_GMRES(self,p0,p1):
1689             return self.inner_p(p0,p1)
1690    
1691          #=============================================================
1692          def norm_p(self,p):
1693              """
1694              calculates the norm of ``p``
1695              
1696              :param p: a pressure
1697              :return: the norm of ``p`` using the inner product for pressure
1698              :rtype: ``float``
1699              """
1700              f=self.inner_p(p,p)
1701              if f<0: raise ValueError,"negative pressure norm."
1702              return math.sqrt(f)
1703          
1704          def solve(self,v,p,max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
1705             """
1706             Solves the saddle point problem using initial guesses v and p.
1707    
1708             :param v: initial guess for velocity
1709             :param p: initial guess for pressure
1710             :type v: `Data`
1711             :type p: `Data`
1712             :param usePCG: indicates the usage of the PCG rather than GMRES scheme.
1713             :param max_iter: maximum number of iteration steps per correction
1714                              attempt
1715             :param verbose: if True, shows information on the progress of the
1716                             saddlepoint problem solver.
1717             :param iter_restart: restart the iteration after ``iter_restart`` steps
1718                                  (only used if useUzaw=False)
1719             :type usePCG: ``bool``
1720             :type max_iter: ``int``
1721             :type verbose: ``bool``
1722             :type iter_restart: ``int``
1723             :rtype: ``tuple`` of `Data` objects
1724             :note: typically this method is overwritten by a subclass. It provides a wrapper for the ``_solve`` method.
1725             """
1726             return self._solve(v=v,p=p,max_iter=max_iter,verbose=verbose, usePCG=usePCG, iter_restart=iter_restart, max_correction_steps=max_correction_steps)
1727    
1728          def _solve(self,v,p,max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
1729             """
1730             see `_solve` method.
1731             """
1732             self.verbose=verbose
1733             rtol=self.getTolerance()
1734             atol=self.getAbsoluteTolerance()
1735    
1736             K_p=self.__K_p
1737             K_v=self.__K_v
1738             correction_step=0
1739             converged=False
1740             chi=None
1741             eps=None
1742    
1743             if self.verbose: print "HomogeneousSaddlePointProblem: start iteration: rtol= %e, atol=%e"%(rtol, atol)
1744             while not converged:
1745    
1746                 # get tolerance for velecity increment:
1747                 if chi == None:
1748                    rtol_v=self.__rtol_max
1749                 else:
1750                    rtol_v=min(chi/K_v,self.__rtol_max)
1751                 rtol_v=max(rtol_v, self.__rtol_min)
1752                 if self.verbose: print "HomogeneousSaddlePointProblem: step %s: rtol_v= %e"%(correction_step,rtol_v)
1753                 # get velocity increment:
1754                 dv1=self.getDV(p,v,rtol_v)
1755                 v1=v+dv1
1756                 Bv1=self.Bv(v1, rtol_v)
1757                 norm_Bv1=self.norm_Bv(Bv1)
1758                 norm_dv1=self.norm_v(dv1)
1759                 if self.verbose: print "HomogeneousSaddlePointProblem: step %s: norm_Bv1 = %e, norm_dv1 = %e"%(correction_step, norm_Bv1, norm_dv1)
1760                 if norm_dv1*self.__theta < norm_Bv1:
1761                    # get tolerance for pressure increment:
1762                    large_Bv1=True
1763                    if chi == None or eps == None:
1764                       rtol_p=self.__rtol_max
1765                    else:
1766                       rtol_p=min(chi**2*eps/K_p/norm_Bv1, self.__rtol_max)
1767                    self.__subtol=max(rtol_p**2, self.__rtol_min)
1768                    if self.verbose: print "HomogeneousSaddlePointProblem: step %s: rtol_p= %e"%(correction_step,rtol_p)
1769                    # now we solve for the pressure increment dp from B*A^{-1}B^* dp = Bv1
1770                    if usePCG:
1771                        dp,r,a_norm=PCG(ArithmeticTuple(v1,Bv1),self.__Aprod_PCG,0*p,self.__Msolve_PCG,self.__inner_PCG,atol=0, rtol=rtol_p,iter_max=max_iter, verbose=self.verbose)
1772                        v2=r[0]
1773                        Bv2=r[1]
1774                    else:
1775                        # don't use!!!!
1776                        dp=GMRES(self.solve_prec(Bv1,self.__subtol),self.__Aprod_GMRES, 0*p, self.__inner_GMRES,atol=0, rtol=rtol_p,iter_max=max_iter, iter_restart=iter_restart, verbose=self.verbose)
1777                        dv2=self.solve_AinvBt(dp, self.__subtol)
1778                        v2=v1-dv2
1779                        Bv2=self.Bv(v2, self.__subtol)
1780                    p2=p+dp
1781                 else:
1782                    large_Bv1=False
1783                    v2=v1
1784                    p2=p
1785                 # update business:
1786                 norm_dv2=self.norm_v(v2-v)
1787                 norm_v2=self.norm_v(v2)
1788                 if self.verbose: print "HomogeneousSaddlePointProblem: step %s: v2 = %e, norm_dv2 = %e"%(correction_step, norm_v2, self.norm_v(v2-v))
1789                 eps, eps_old = max(norm_Bv1, norm_dv2), eps
1790                 if eps_old == None:
1791                      chi, chi_old = None, chi
1792                 else:
1793                      chi, chi_old = min(eps/ eps_old, self.__chi_max), chi
1794                 if eps != None:
1795                     if chi !=None:
1796                        if self.verbose: print "HomogeneousSaddlePointProblem: step %s: convergence rate = %e, correction = %e"%(correction_step,chi, eps)
1797                     else:
1798                        if self.verbose: print "HomogeneousSaddlePointProblem: step %s: correction = %e"%(correction_step, eps)
1799                 if eps <= rtol*norm_v2+atol :
1800                     converged = True
1801                 else:
1802                     if correction_step>=max_correction_steps:
1803                          raise CorrectionFailed,"Given up after %d correction steps."%correction_step
1804                     if chi_old!=None:
1805                        K_p=max(1,self.__reduction_factor*K_p,(chi-chi_old)/chi_old**2*K_p)
1806                        K_v=max(1,self.__reduction_factor*K_v,(chi-chi_old)/chi_old**2*K_p)
1807                        if self.verbose: print "HomogeneousSaddlePointProblem: step %s: new adjustment factor K = %e"%(correction_step,K_p)
1808                 correction_step+=1
1809                 v,p =v2, p2
1810             if self.verbose: print "HomogeneousSaddlePointProblem: tolerance reached after %s steps."%correction_step
1811         return v,p
1812          #========================================================================
1813          def setTolerance(self,tolerance=1.e-4):
1814             """
1815             Sets the relative tolerance for (v,p).
1816    
1817             :param tolerance: tolerance to be used
1818             :type tolerance: non-negative ``float``
1819             """
1820             if tolerance<0:
1821                 raise ValueError,"tolerance must be positive."
1822             self.__rtol=tolerance
1823    
1824          def getTolerance(self):
1825             """
1826             Returns the relative tolerance.
1827    
1828             :return: relative tolerance
1829             :rtype: ``float``
1830             """
1831             return self.__rtol
1832    
1833          def setAbsoluteTolerance(self,tolerance=0.):
1834             """
1835             Sets the absolute tolerance.
1836    
1837             :param tolerance: tolerance to be used
1838             :type tolerance: non-negative ``float``
1839             """
1840             if tolerance<0:
1841                 raise ValueError,"tolerance must be non-negative."
1842             self.__atol=tolerance
1843    
1844          def getAbsoluteTolerance(self):
1845             """
1846             Returns the absolute tolerance.
1847    
1848             :return: absolute tolerance
1849             :rtype: ``float``
1850             """
1851             return self.__atol
1852    
1853    
1854    def MaskFromBoundaryTag(domain,*tags):
1855       """
1856       Creates a mask on the Solution(domain) function space where the value is
1857       one for samples that touch the boundary tagged by tags.
1858    
1859       Usage: m=MaskFromBoundaryTag(domain, "left", "right")
1860    
1861       :param domain: domain to be used
1862       :type domain: `escript.Domain`
1863       :param tags: boundary tags
1864       :type tags: ``str``
1865       :return: a mask which marks samples that are touching the boundary tagged
1866                by any of the given tags
1867       :rtype: `escript.Data` of rank 0
1868       """
1869       pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1870       d=escript.Scalar(0.,escript.FunctionOnBoundary(domain))
1871       for t in tags: d.setTaggedValue(t,1.)
1872       pde.setValue(y=d)
1873       return util.whereNonZero(pde.getRightHandSide())
1874    
1875    def MaskFromTag(domain,*tags):
1876       """
1877       Creates a mask on the Solution(domain) function space where the value is
1878       one for samples that touch regions tagged by tags.
1879    
1880       Usage: m=MaskFromTag(domain, "ham")
1881    
1882       :param domain: domain to be used
1883       :type domain: `escript.Domain`
1884       :param tags: boundary tags
1885       :type tags: ``str``
1886       :return: a mask which marks samples that are touching the boundary tagged
1887                by any of the given tags
1888       :rtype: `escript.Data` of rank 0
1889       """
1890       pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
1891       d=escript.Scalar(0.,escript.Function(domain))
1892       for t in tags: d.setTaggedValue(t,1.)
1893       pde.setValue(Y=d)
1894       return util.whereNonZero(pde.getRightHandSide())
1895    
1896    

Legend:
Removed from v.351  
changed lines
  Added in v.2862

  ViewVC Help
Powered by ViewVC 1.1.26