/[escript]/trunk/escript/py_src/pdetools.py
ViewVC logotype

Annotation of /trunk/escript/py_src/pdetools.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3922 - (hide annotations)
Mon Jul 9 02:19:51 2012 UTC (7 years, 7 months ago) by jfenwick
File MIME type: text/x-python
File size: 64404 byte(s)
Adds setvalue to locator
Addresses mantis #531

1 ksteube 1809
2     ########################################################
3 ksteube 1312 #
4 jfenwick 3911 # Copyright (c) 2003-2012 by University of Queensland
5 ksteube 1809 # Earth Systems Science Computational Center (ESSCC)
6     # http://www.uq.edu.au/esscc
7 ksteube 1312 #
8 ksteube 1809 # Primary Business: Queensland, Australia
9     # Licensed under the Open Software License version 3.0
10     # http://www.opensource.org/licenses/osl-3.0.php
11 ksteube 1312 #
12 ksteube 1809 ########################################################
13 jgs 121
14 jfenwick 3911 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
15 ksteube 1809 Earth Systems Science Computational Center (ESSCC)
16     http://www.uq.edu.au/esscc
17     Primary Business: Queensland, Australia"""
18     __license__="""Licensed under the Open Software License version 3.0
19     http://www.opensource.org/licenses/osl-3.0.php"""
20 jfenwick 2344 __url__="https://launchpad.net/escript-finley"
21 ksteube 1809
22 jgs 121 """
23 caltinay 2169 Provides some tools related to PDEs.
24 jgs 121
25 jgs 149 Currently includes:
26 caltinay 2169 - Projector - to project a discontinuous function onto a continuous function
27 gross 351 - Locator - to trace values in data objects at a certain location
28 caltinay 2169 - TimeIntegrationManager - to handle extrapolation in time
29 gross 867 - SaddlePointProblem - solver for saddle point problems using the inexact Uzawa scheme
30 gross 637
31 jfenwick 2625 :var __author__: name of author
32     :var __copyright__: copyrights
33     :var __license__: licence agreement
34     :var __url__: url entry point on documentation
35     :var __version__: version
36     :var __date__: date of the version
37 jgs 121 """
38    
39 gross 637 __author__="Lutz Gross, l.gross@uq.edu.au"
40 elspeth 609
41 gross 637
42 jfenwick 3771 from . import escript
43     from . import linearPDEs
44 jfenwick 2455 import numpy
45 jfenwick 3771 from . import util
46 ksteube 1312 import math
47 jgs 121
class TimeIntegrationManager:
    """
    A simple mechanism to manage time dependent values.

    Typical usage is::

       dt=0.1 # time increment
       tm=TimeIntegrationManager(initial_value,p=1)
       while t<1.:
           v_guess=tm.extrapolate(dt) # extrapolate to t+dt
           v=...
           tm.checkin(dt,v)
           t+=dt

    :note: currently only p=1 is supported.
    """
    def __init__(self, *initial_values, **kwargs):
        """
        Sets up the value manager.

        :param initial_values: the initial values (any number of them)
        :keyword p: order used for extrapolation (default 1)
        :keyword time: initial time (default 0.)
        """
        self.__p = kwargs.get("p", 1)
        self.__t = kwargs.get("time", 0.)
        # history of checked-in value tuples, newest first
        self.__v_mem = [initial_values]
        self.__order = 0
        # history of time increments, newest first
        self.__dt_mem = []
        self.__num_val = len(initial_values)

    def getTime(self):
        """
        Returns the current time, i.e. the initial time plus all increments
        passed to `checkin`.
        """
        return self.__t

    def getValue(self):
        """
        Returns the most recently checked-in values (the initial values
        before the first check-in). A single value is returned unwrapped,
        otherwise the tuple of values is returned.
        """
        latest = self.__v_mem[0]
        return latest[0] if len(latest) == 1 else latest

    def checkin(self, dt, *values):
        """
        Adds new values reached after the time increment ``dt``. Only the
        newest value sets required by the extrapolation order are retained.
        """
        new_order = min(self.__order + 1, self.__p)
        self.__order = new_order
        carried = range(new_order - 1)
        self.__v_mem = ([values]
                        + [self.__v_mem[k] for k in carried]
                        + [self.__v_mem[new_order - 1]])
        self.__dt_mem = [dt] + [self.__dt_mem[k] for k in carried]
        self.__t += dt

    def extrapolate(self, dt):
        """
        Extrapolates to ``dt`` forward in time.

        Before the first check-in the current values are returned; afterwards
        a linear extrapolation from the two newest value sets is used.
        Returns ``None`` if no values are managed, the single value if there
        is exactly one, and a list otherwise.
        """
        if self.__order == 0:
            result = self.__v_mem[0]
        else:
            result = [(1. + dt / self.__dt_mem[0]) * self.__v_mem[0][k]
                      - dt / self.__dt_mem[0] * self.__v_mem[1][k]
                      for k in range(self.__num_val)]

        if len(result) == 0:
            return None
        if len(result) == 1:
            return result[0]
        return result
126    
127 caltinay 2169
class Projector:
    """
    The Projector is a factory which projects a discontinuous function onto a
    continuous function on a given domain.

    The projection solves a `linearPDEs.LinearPDE` with ``D=1`` and the
    input as right hand side ``Y``, one component at a time.
    """
    def __init__(self, domain, reduce=True, fast=True):
        """
        Creates a continuous function space projector for a domain.

        :param domain: Domain of the projection.
        :param reduce: Flag to reduce projection order
        :param fast: Flag to use a fast method based on matrix lumping
        """
        pde = linearPDEs.LinearPDE(domain)
        self.__pde = pde
        if fast:
            pde.getSolverOptions().setSolverMethod(linearPDEs.SolverOptions.LUMPING)
        pde.setSymmetryOn()
        pde.setReducedOrderTo(reduce)
        pde.setValue(D=1.)

    def getSolverOptions(self):
        """
        Returns the solver options of the PDE solver.

        :rtype: `linearPDEs.SolverOptions`
        """
        return self.__pde.getSolverOptions()

    def getValue(self, input_data):
        """
        Projects ``input_data`` onto a continuous function.

        :param input_data: the data to be projected
        """
        return self(input_data)

    def __call__(self, input_data):
        """
        Projects ``input_data`` onto a continuous function.

        Each component of ``input_data`` is projected separately by using it
        as right hand side ``Y`` of the projection PDE.

        :param input_data: the data to be projected
        """
        pde = self.__pde
        shape = input_data.getShape()
        out = escript.Data(0., shape, pde.getFunctionSpaceForSolution())
        # clear right hand sides left over from a previous projection
        pde.setValue(Y=escript.Data(), Y_reduced=escript.Data())
        if input_data.getRank() == 0:
            pde.setValue(Y=input_data)
            return pde.getSolution()
        # visit the components in row-major order; rank-1 data is indexed
        # with a plain int, higher ranks with an index tuple
        for index in numpy.ndindex(*shape):
            key = index[0] if len(index) == 1 else index
            pde.setValue(Y=input_data[key])
            out[key] = pde.getSolution()
        return out
198    
class NoPDE:
    """
    Solves the following problem for u:

    *kronecker[i,j]*D[j]*u[j]=Y[i]*

    with constraint

    *u[j]=r[j]* where *q[j]>0*

    where *D*, *Y*, *r* and *q* are given functions of rank 1.

    In the case of scalars this takes the form

    *D*u=Y*

    with constraint

    *u=r* where *q>0*

    where *D*, *Y*, *r* and *q* are given scalar functions.

    The constraint overwrites any other condition.

    :note: This class is similar to the `linearPDEs.LinearPDE` class with
           A=B=C=X=0 but has the intention that all input parameters are given
           in `Solution` or `ReducedSolution`.
    """
    # The whole thing is a bit strange and I blame Rob Woodcock (CSIRO) for
    # this.
    def __init__(self,domain,D=None,Y=None,q=None,r=None):
        """
        Initializes the problem.

        :param domain: domain of the PDE
        :type domain: `Domain`
        :param D: coefficient of the solution
        :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param Y: right hand side
        :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param q: location of constraints
        :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param r: value of solution at locations of constraints
        :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
        """
        self.__domain=domain
        self.__D=D
        self.__Y=Y
        self.__q=q
        self.__r=r
        # cached solution; reset whenever a coefficient changes
        self.__u=None
        self.__function_space=escript.Solution(self.__domain)

    def setReducedOn(self):
        """
        Sets the `FunctionSpace` of the solution to `ReducedSolution`.
        """
        self.__function_space=escript.ReducedSolution(self.__domain)
        self.__u=None

    def setReducedOff(self):
        """
        Sets the `FunctionSpace` of the solution to `Solution`.
        """
        self.__function_space=escript.Solution(self.__domain)
        self.__u=None

    def setValue(self,D=None,Y=None,q=None,r=None):
        """
        Assigns values to the parameters. Arguments left as ``None`` keep
        their current value; setting any argument discards the cached
        solution.

        :param D: coefficient of the solution
        :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param Y: right hand side
        :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param q: location of constraints
        :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param r: value of solution at locations of constraints
        :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
        """
        # Use identity tests: for numpy arrays ``X == None`` is evaluated
        # elementwise and ``not <array>`` raises ValueError.
        if D is not None:
            self.__D=D
            self.__u=None
        if Y is not None:
            self.__Y=Y
            self.__u=None
        if q is not None:
            self.__q=q
            self.__u=None
        if r is not None:
            self.__r=r
            self.__u=None

    def getSolution(self):
        """
        Returns the solution, computing it on first use after a change.

        :return: the solution of the problem
        :rtype: `Data` object in the `FunctionSpace` `Solution` or
                `ReducedSolution`
        :raise ValueError: if ``D`` is undefined or has rank greater than 1
        """
        if self.__u is None:
            if self.__D is None:
                raise ValueError("coefficient D is undefined")
            D=escript.Data(self.__D,self.__function_space)
            if D.getRank()>1:
                raise ValueError("coefficient D must have rank 0 or 1")
            if self.__Y is None:
                self.__u=escript.Data(0.,D.getShape(),self.__function_space)
            else:
                self.__u=util.quotient(self.__Y,D)
            if self.__q is not None:
                # q is 1 where the constraint applies and 0 elsewhere
                q=util.wherePositive(escript.Data(self.__q,self.__function_space))
                self.__u*=(1.-q)
                if self.__r is not None:
                    self.__u+=q*self.__r
        return self.__u
315 caltinay 2169
class Locator:
    """
    Locator provides access to the values of data objects at a given spatial
    coordinate x.

    In fact, a Locator object finds the sample in the set of samples of a
    given function space or domain which is closest to the given point x.
    """

    def __init__(self,where,x=None):
        """
        Initializes a Locator to access values in Data objects on the Domain
        or FunctionSpace for the sample point which is closest to the given
        point x.

        :param where: function space, or a domain whose
                      `escript.ContinuousFunction` space is used
        :type where: `escript.FunctionSpace` or `Domain`
        :param x: location(s) of the Locator; defaults to the origin
        :type x: ``numpy.ndarray`` or ``list`` of ``numpy.ndarray``
        :raise ValueError: if ``x`` is an empty list
        """
        if x is None:
            # fresh array per call instead of a shared mutable default
            x=numpy.zeros((3,))
        if isinstance(where,escript.FunctionSpace):
            self.__function_space=where
        else:
            self.__function_space=escript.ContinuousFunction(where)
        # a list whose entries are themselves iterable is a list of points
        iterative=False
        if isinstance(x, list):
            if len(x)==0:
                raise ValueError("At least one point must be given.")
            try:
                iter(x[0])
                iterative=True
            except TypeError:
                iterative=False
        xxx=self.__function_space.getX()
        dim=self.__function_space.getDim()
        if iterative:
            self.__id=[util.length(xxx-p[:dim]).minGlobalDataPoint() for p in x]
        else:
            self.__id=util.length(xxx-x[:dim]).minGlobalDataPoint()

    def __str__(self):
        """
        Returns the coordinates of the Locator as a string.
        """
        x=self.getX()
        if isinstance(x,list):
            return "["+",".join(str(xx) for xx in x)+"]"
        return str(x)

    def getX(self):
        """
        Returns the exact coordinates of the Locator.
        """
        return self(self.getFunctionSpace().getX())

    def getFunctionSpace(self):
        """
        Returns the function space of the Locator.
        """
        return self.__function_space

    def getId(self,item=None):
        """
        Returns the identifier of the location.

        :param item: index of the location if the Locator holds several
                     locations; ignored for a single location
        """
        if item is None or not isinstance(self.__id,list):
            return self.__id
        return self.__id[item]

    def __call__(self,data):
        """
        Returns the value of data at the Locator of a Data object.
        """
        return self.getValue(data)

    def getValue(self,data):
        """
        Returns the value of ``data`` at the Locator if ``data`` is a `Data`
        object otherwise the object is returned.

        For rank-0 data a scalar is returned, otherwise a ``numpy.ndarray``;
        a Locator for several points returns a list of these.
        """
        if not isinstance(data,escript.Data):
            return data
        dat=util.interpolate(data,self.getFunctionSpace())
        scalar=data.getRank()==0

        def _value_at(point_id):
            value=numpy.array(dat.getTupleForGlobalDataPoint(*point_id))
            return value[0] if scalar else value

        loc_id=self.getId()
        if isinstance(loc_id,list):
            return [_value_at(i) for i in loc_id]
        return _value_at(loc_id)

    def setValue(self, data, v):
        """
        Sets the value of the ``data`` at the Locator.

        ``data`` is expanded in place before the value is written.

        :raise TypeError: if ``data`` is not a `Data` object
        """
        # check the type first so non-Data arguments get the documented
        # TypeError rather than an AttributeError from expand()
        if not isinstance(data, escript.Data):
            raise TypeError("setValue: Invalid argument type.")
        data.expand()
        loc_id=self.getId()
        ids=loc_id if isinstance(loc_id, list) else [loc_id]
        for i in ids:
            # NOTE(review): the setter takes the identifier components in the
            # reverse order of getTupleForGlobalDataPoint(*i); presumably
            # (data point, processor) vs (processor, data point) - confirm.
            data._setTupleForGlobalDataPoint(i[1], i[0], v)
448 jgs 149
449 jfenwick 2745
def getInfLocator(arg):
    """
    Return a Locator for a point with the inf value over all arg.

    :param arg: the data to search
    :type arg: `escript.Data`
    :raise TypeError: if ``arg`` is not a `Data` object
    """
    if not isinstance(arg, escript.Data):
        raise TypeError("getInfLocator: Unknown argument type.")
    space = arg.getFunctionSpace()
    # identifier of the data point closest to the inf (not its coordinates)
    where = util.length(arg - util.inf(arg)).minGlobalDataPoint()
    coords = space.getX().getTupleForGlobalDataPoint(*where)
    return Locator(space, coords)
461    
def getSupLocator(arg):
    """
    Return a Locator for a point with the sup value over all arg.

    :param arg: the data to search
    :type arg: `escript.Data`
    :raise TypeError: if ``arg`` is not a `Data` object
    """
    if not isinstance(arg, escript.Data):
        raise TypeError("getSupLocator: Unknown argument type.")
    a_sup=util.sup(arg)
    # identifier of the data point closest to the sup (not its coordinates)
    loc=util.length(arg-a_sup).minGlobalDataPoint()
    x=arg.getFunctionSpace().getX()
    x_max=x.getTupleForGlobalDataPoint(*loc)
    return Locator(arg.getFunctionSpace(),x_max)
473    
474    
class SolverSchemeException(Exception):
    """
    Generic base class of the exceptions thrown by the solvers.
    """


class IndefinitePreconditioner(SolverSchemeException):
    """
    Thrown when the preconditioner turns out not to be positive definite.
    """


class MaxIterReached(SolverSchemeException):
    """
    Thrown when the maximum number of iteration steps has been reached.
    """


class CorrectionFailed(SolverSchemeException):
    """
    Thrown when the solution correction scheme fails to converge.
    """


class IterationBreakDown(SolverSchemeException):
    """
    Thrown when the iteration scheme runs into an incurable breakdown.
    """


class NegativeNorm(SolverSchemeException):
    """
    Thrown when a norm calculation yields a negative value.
    """
511    
def PCG(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100, initial_guess=True, verbose=False):
    """
    Preconditioned conjugate gradient solver for

    *Ax=b*

    with a symmetric and positive definite operator A. The preconditioner M,
    applied through ``Msolve``, provides an approximation of A.

    The iteration is terminated if

    *|r| <= atol+rtol*|r0|*

    where *r0* is the initial residual and *|.|* is the energy norm, i.e.

    *|r| = sqrt( bilinearform(Msolve(r),r))*

    The tolerance is never taken smaller than ``100*util.EPSILON*|r0|``.

    For details on the preconditioned conjugate gradient method see the book:

    I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
    T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
    C. Romine, and H. van der Vorst}.

    :param r: initial residual *r=b-Ax*. ``r`` is altered.
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same type
                 like argument ``r``.
    :param Msolve: solves Mx=r
    :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
                  argument ``r``. The returned object needs to be of the same
                  type like argument ``x``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r`` is. The returned value
                        is a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance
    :type rtol: non-negative ``float``
    :param iter_max: the step counter must stay below this value
    :type iter_max: ``int``
    :param initial_guess: not used by this implementation
    :param verbose: if True, the residual norms are printed
    :return: the solution approximation, the corresponding residual and the
             energy norm of that residual
    :rtype: ``tuple``
    :raise NegativeNorm: if ``bilinearform(Msolve(r),r)`` is negative
    :raise ValueError: if ``atol+rtol*|r0|`` is not positive
    :raise MaxIterReached: if the step counter reaches ``iter_max``
    :warning: ``r`` and ``x`` are altered.
    """
    step = 0
    z = Msolve(r)
    p = z  # search direction
    rz = bilinearform(z, r)
    if rz < 0:
        raise NegativeNorm("negative norm.")
    norm_r0 = math.sqrt(rz)
    tol = atol + rtol * norm_r0
    if tol <= 0:
        raise ValueError("Non-positive tolarance.")
    tol = max(tol, 100. * util.EPSILON * norm_r0)

    if verbose:
        print("PCG: initial residual norm = %e (absolute tolerance = %e)" % (norm_r0, tol))

    while not math.sqrt(rz) <= tol:
        step += 1
        if step >= iter_max:
            raise MaxIterReached("maximum number of %s steps reached." % iter_max)

        q = Aprod(p)
        alpha = rz / bilinearform(p, q)
        x += alpha * p
        if isinstance(q, ArithmeticTuple):
            # Doing it the other way calls the float64.__mul__ not AT.__rmul__
            r += q * (-alpha)
        else:
            r += (-alpha) * q
        # z is rebound to a fresh object here, so the in-place update below
        # does not touch the previous direction p
        z = Msolve(r)
        rz_new = bilinearform(z, r)
        beta = rz_new / rz
        z += beta * p
        p = z

        rz = rz_new
        if rz < 0:
            raise NegativeNorm("negative norm.")
        if verbose:
            print("PCG: iteration step %s: residual norm = %e" % (step, math.sqrt(rz)))
    if verbose:
        print("PCG: tolerance reached after %s steps." % step)
    return x, r, math.sqrt(rz)
598 ksteube 1312
class Defect(object):
    """
    Defines a non-linear defect F(x) of a variable x.

    The default `bilinearform` and `eval` both return 0, so a usable defect
    has to provide its own implementations.
    """
    def __init__(self):
        """
        Initializes defect with the default derivative increment length.
        """
        self.setDerivativeIncrementLength()

    def bilinearform(self, x0, x1):
        """
        Returns the inner product of x0 and x1

        :param x0: value for x0
        :param x1: value for x1
        :return: the inner product of x0 and x1
        :rtype: ``float``
        """
        return 0

    def norm(self,x):
        """
        Returns the norm of argument ``x``.

        :param x: a value
        :return: norm of argument x
        :rtype: ``float``
        :raise NegativeNorm: if ``self.bilinearform(x,x)`` is negative
        :note: by default ``sqrt(self.bilinearform(x,x))`` is returned.
        """
        s=self.bilinearform(x,x)
        if s<0: raise NegativeNorm("negative norm.")
        return math.sqrt(s)

    def eval(self,x):
        """
        Returns the value F of a given ``x``.

        :param x: value for which the defect ``F`` is evaluated
        :return: value of the defect at ``x``
        """
        return 0

    def __call__(self,x):
        """
        Shortcut for `eval`.
        """
        return self.eval(x)

    def setDerivativeIncrementLength(self,inc=1000.*math.sqrt(util.EPSILON)):
        """
        Sets the relative length of the increment used to approximate the
        derivative of the defect. The increment is inc*norm(x)/norm(v)*v in the
        direction of v with x as a starting point (inc alone replaces
        inc*norm(x) if norm(x) is zero).

        :param inc: relative increment length
        :type inc: positive ``float``
        :raise ValueError: if ``inc`` is not positive
        """
        if inc<=0: raise ValueError("positive increment required.")
        self.__inc=inc

    def getDerivativeIncrementLength(self):
        """
        Returns the relative increment length used to approximate the
        derivative of the defect.

        :return: the relative increment length
        :rtype: positive ``float``
        """
        return self.__inc

    def derivative(self, F0, x0, v, v_is_normalised=True):
        """
        Returns the directional derivative at ``x0`` in the direction of ``v``.

        :param F0: value of this defect at x0
        :param x0: value at which derivative is calculated
        :param v: direction
        :param v_is_normalised: True to indicate that ``v`` is normalized
                                (self.norm(v)=1)
        :return: derivative of this defect at x0 in the direction of ``v``
        :note: by default numerical evaluation (self.eval(x0+eps*v)-F0)/eps is
               used but this method maybe overwritten to use exact evaluation.
               If ``v`` is not normalised and has non-positive norm, ``F0*0``
               is returned.
        """
        normx=self.norm(x0)
        if normx>0:
            epsnew = self.getDerivativeIncrementLength() * normx
        else:
            epsnew = self.getDerivativeIncrementLength()
        if not v_is_normalised:
            normv=self.norm(v)
            if normv<=0:
                return F0*0
            else:
                # scale so that the step epsnew*v has length inc*norm(x)
                epsnew /= normv
        F1=self.eval(x0 + epsnew * v)
        return (F1-F0)/epsnew
692    
693 caltinay 2169 ######################################
def NewtonGMRES(defect, x, iter_max=100, sub_iter_max=20, atol=0,rtol=1.e-4, subtol_max=0.5, gamma=0.9, verbose=False):
    """
    Solves a non-linear problem *F(x)=0* for unknown *x* using the stopping
    criterion:

    *norm(F(x)) <= atol + rtol * norm(F(x0))*

    where *x0* is the initial guess.

    :param defect: object defining the function *F*. ``defect.norm`` defines the
                   *norm* used in the stopping criterion.
    :type defect: `Defect`
    :param x: initial guess for the solution, ``x`` is altered.
    :type x: any object type allowing basic operations such as
             ``numpy.ndarray``, `Data`
    :param iter_max: maximum number of iteration steps (inner steps included)
    :type iter_max: positive ``int``
    :param sub_iter_max: maximum number of inner iteration steps
    :type sub_iter_max: positive ``int``
    :param atol: absolute tolerance for the solution
    :type atol: non-negative ``float``
    :param rtol: relative tolerance for the solution
    :type rtol: non-negative ``float``
    :param gamma: tolerance safety factor for inner iteration
    :type gamma: positive ``float``, less than 1
    :param subtol_max: upper bound for inner tolerance
    :type subtol_max: positive ``float``, less than 1
    :param verbose: if True, progress is printed
    :return: an approximation of the solution with the desired accuracy
    :rtype: same type as the initial guess
    :raise ValueError: if a tolerance parameter is out of range
    :raise MaxIterReached: if ``iter_max`` steps are reached
    """
    if atol<0: raise ValueError("atol needs to be non-negative.")
    if rtol<0: raise ValueError("rtol needs to be non-negative.")
    if rtol+atol<=0: raise ValueError("rtol or atol needs to be positive.")
    if gamma<=0 or gamma>=1: raise ValueError("tolerance safety factor for inner iteration (gamma =%s) needs to be positive and less than 1."%gamma)
    if subtol_max<=0 or subtol_max>=1: raise ValueError("upper bound for inner tolerance for inner iteration (subtol_max =%s) needs to be positive and less than 1."%subtol_max)

    F=defect(x)
    fnrm=defect.norm(F)
    stop_tol=atol + rtol*fnrm
    subtol=subtol_max
    if verbose: print("NewtonGMRES: initial residual = %e."%fnrm)
    if verbose: print(" tolerance = %e."%subtol)
    step=1
    #
    # main iteration loop
    #
    while not fnrm<=stop_tol:
        if step >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)
        #
        # adjust the inner tolerance from the observed residual reduction
        #
        if step > 1:
            rat=fnrm/fnrmo
            subtol_old=subtol
            subtol=gamma*rat**2
            if gamma*subtol_old**2 > .1: subtol=max(subtol,gamma*subtol_old**2)
            subtol=max(min(subtol,subtol_max), .5*stop_tol/fnrm)
        #
        # calculate newton increment xc
        #     if iter_max in __FDGMRES is reached MaxIterReached is thrown
        #     if iter_restart -1 is returned as sub_iter
        #     if atol is reached sub_iter returns the number of steps performed to get there
        #
        if verbose: print(" subiteration (GMRES) is called with relative tolerance %e."%subtol)
        try:
            xc, sub_iter=__FDGMRES(F, defect, x, subtol*fnrm, iter_max=iter_max-step, iter_restart=sub_iter_max)
        except MaxIterReached:
            raise MaxIterReached("maximum number of %s steps reached."%iter_max)
        if sub_iter<0:
            step+=sub_iter_max
        else:
            step+=sub_iter
        # ====
        x+=xc
        F=defect(x)
        step+=1
        fnrmo, fnrm=fnrm, defect.norm(F)
        if verbose: print(" step %s: residual %e."%(step,fnrm))
    if verbose: print("NewtonGMRES: completed after %s steps."%step)
    return x
776    
def __givapp(c,s,vin):
    """
    Applies a sequence of Givens rotations (c,s) recursively to the vector
    ``vin``.

    If ``c`` is a single float, one rotation is applied to the first two
    entries of ``vin`` and a new two-element list is returned. Otherwise
    rotation ``k`` acts on entries ``k`` and ``k+1`` and ``vin`` itself is
    updated and returned.

    :warning: ``vin`` is altered when ``c`` is a sequence.
    """
    if isinstance(c, float):
        first, second = vin[0], vin[1]
        return [c*first - s*second, s*first + c*second]
    for k in range(len(c)):
        cos_k, sin_k = c[k], s[k]
        upper, lower = vin[k], vin[k+1]
        vin[k] = cos_k*upper - sin_k*lower
        vin[k+1] = sin_k*upper + cos_k*lower
    return vin
794    
def __FDGMRES(F0, defect, x0, atol, iter_max=100, iter_restart=20):
    """
    Approximately solves *J xc = -F0* for the Newton correction ``xc`` by a
    single GMRES cycle, where the products of the derivative *J* of the
    defect at ``x0`` with basis vectors are provided by
    ``defect.derivative``.

    :param F0: defect value at ``x0``
    :param defect: object providing ``norm``, ``bilinearform`` and
                   ``derivative``
    :type defect: `Defect`
    :param x0: point at which the derivative is evaluated
    :param atol: absolute tolerance for the GMRES residual norm
    :param iter_max: maximum number of GMRES steps
    :param iter_restart: size of the Krylov space; at most
                         ``iter_restart-1`` steps are performed
    :return: the correction and the number of steps performed, or -1 as
             step count if the Krylov space was exhausted
    :rtype: ``tuple``
    :raise MaxIterReached: if ``iter_max`` steps are reached
    """
    h=numpy.zeros((iter_restart,iter_restart),numpy.float64)
    c=numpy.zeros(iter_restart,numpy.float64)
    s=numpy.zeros(iter_restart,numpy.float64)
    g=numpy.zeros(iter_restart,numpy.float64)
    v=[]

    rho=defect.norm(F0)
    if rho<=0.:
        # F0 already vanishes: zero correction after zero steps. Return the
        # same (correction, steps) pair as the regular exit so callers can
        # always unpack the result.
        return x0*0, 0

    v.append(-F0/rho)
    g[0]=rho
    it=0
    while rho > atol and it<iter_restart-1:
        if it >= iter_max:
            raise MaxIterReached("maximum number of %s steps reached."%iter_max)

        p=defect.derivative(F0,x0,v[it], v_is_normalised=True)
        v.append(p)

        v_norm1=defect.norm(v[it+1])

        # Modified Gram-Schmidt
        for j in range(it+1):
            h[j,it]=defect.bilinearform(v[j],v[it+1])
            v[it+1]-=h[j,it]*v[j]

        h[it+1,it]=defect.norm(v[it+1])
        v_norm2=h[it+1,it]

        # Reorthogonalize if needed
        if v_norm1 + 0.001*v_norm2 == v_norm1:   #Brown/Hindmarsh condition (default)
            for j in range(it+1):
                hr=defect.bilinearform(v[j],v[it+1])
                h[j,it]=h[j,it]+hr
                v[it+1] -= hr*v[j]

            v_norm2=defect.norm(v[it+1])
            h[it+1,it]=v_norm2

        # watch out for happy breakdown
        if not v_norm2 == 0:
            v[it+1]=v[it+1]/h[it+1,it]

        # apply the previous Givens rotations to the new column of h
        if it > 0:
            hhat=numpy.zeros(it+1,numpy.float64)
            for i in range(it+1): hhat[i]=h[i,it]
            hhat=__givapp(c[0:it],s[0:it],hhat)
            for i in range(it+1): h[i,it]=hhat[i]

        # form and store the information for the new Givens rotation
        mu=math.sqrt(h[it,it]*h[it,it]+h[it+1,it]*h[it+1,it])

        if mu!=0:
            c[it]=h[it,it]/mu
            s[it]=-h[it+1,it]/mu
            h[it,it]=c[it]*h[it,it]-s[it]*h[it+1,it]
            h[it+1,it]=0.0
            gg=__givapp(c[it],s[it],[g[it],g[it+1]])
            g[it]=gg[0]
            g[it+1]=gg[1]

        # Update the residual norm
        rho=abs(g[it+1])
        it+=1

    # At this point either the Krylov space is exhausted or rho <= atol.
    # Compute the correction by back substitution and leave.
    if it > 0:
        y=numpy.zeros(it,numpy.float64)
        y[it-1] = g[it-1] / h[it-1,it-1]
        if it > 1:
            i=it-2
            while i>=0:
                y[i] = ( g[i] - numpy.dot(h[i,i+1:it], y[i+1:it])) / h[i,i]
                i=i-1
        xhat=v[it-1]*y[it-1]
        for i in range(it-1):
            xhat += v[i]*y[i]
    else:
        xhat=v[0] * 0

    if it<iter_restart-1:
        stopped=it
    else:
        stopped=-1

    return xhat,stopped
867 artak 1465 while i>=0 :
868 jfenwick 2455 y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
869 artak 1465 i=i-1
870     xhat=v[iter-1]*y[iter-1]
871     for i in range(iter-1):
872 jfenwick 3852 xhat += v[i]*y[i]
873 caltinay 2169 else :
874 gross 1878 xhat=v[0] * 0
875 artak 1557
876 caltinay 2169 if iter<iter_restart-1:
877 gross 1878 stopped=iter
878 caltinay 2169 else:
879 gross 1878 stopped=-1
880 artak 1465
881 gross 1878 return xhat,stopped
882 artak 1481
def GMRES(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100, iter_restart=20, verbose=False,P_R=None):
    """
    Solver for

    *Ax=b*

    with a general operator A (more details required!).
    It uses the restarted generalized minimum residual method (GMRES).

    The iteration is terminated if

    *|r| <= atol+rtol*|r0|*

    where *r0* is the initial residual and *|.|* is the energy norm. In fact

    *|r| = sqrt( bilinearform(r,r))*

    :param r: initial residual *r=b-Ax*. ``r`` itself is not modified; a copy
              is handed to the inner GMRES cycle.
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: same like ``r``
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same
                 type like argument ``r``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r``. The returned value is
                        a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance; if zero only ``atol`` is used
    :type rtol: non-negative ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :param iter_restart: in order to save memory the orthogonalization process
                         is terminated after ``iter_restart`` steps and the
                         iteration is restarted.
    :type iter_restart: ``int``
    :param verbose: if True, progress information is printed
    :type verbose: ``bool``
    :param P_R: optional right preconditioner passed on to the inner GMRES
                cycle (applied to Krylov vectors before ``Aprod``)
    :type P_R: function or ``None``
    :return: the solution approximation
    :rtype: same like ``x``
    :raise ValueError: if the effective absolute tolerance (``atol+rtol*|r0|``
                       if ``rtol>0``, otherwise ``atol``) is not positive
    :raise NegativeNorm: if ``bilinearform(r,r)`` is negative
    :raise MaxIterReached: if the step budget ``iter_max`` is used up
    :warning: ``x`` may be updated in place by the inner GMRES cycle.
    """
    m=iter_restart
    restarted=False
    iter=0
    if rtol>0:
        r_dot_r = bilinearform(r, r)
        if r_dot_r<0: raise NegativeNorm("negative norm.")
        atol2=atol+rtol*math.sqrt(r_dot_r)
        if verbose: print(("GMRES: norm of right hand side = %e (absolute tolerance = %e)"%(math.sqrt(r_dot_r), atol2)))
    else:
        atol2=atol
        if verbose: print(("GMRES: absolute tolerance = %e"%atol2))
    if atol2<=0:
        raise ValueError("Non-positive tolarance.")

    while True:
        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached"%iter_max)
        if restarted:
            # residual of the current approximation: r0 - A(x - x_initial)
            r2 = r-Aprod(x-x2)
        else:
            r2=1*r
            x2=x*1.
        x,stopped=_GMRESm(r2, Aprod, x, bilinearform, atol2, iter_max=iter_max-iter, iter_restart=m, verbose=verbose,P_R=P_R)
        iter+=iter_restart
        if stopped: break
        if verbose: print("GMRES: restart.")
        restarted=True
    if verbose: print("GMRES: tolerance has been reached.")
    return x
954 artak 1550
def _GMRESm(r, Aprod, x, bilinearform, atol, iter_max=100, iter_restart=20, verbose=False, P_R=None):
    """
    Performs one GMRES cycle of at most ``iter_restart`` steps (helper of `GMRES`).

    Builds an orthonormal Krylov basis starting from ``r`` (modified
    Gram-Schmidt with an optional re-orthogonalisation pass), reduces the
    Hessenberg matrix with Givens rotations and finally adds the minimising
    correction to ``x``.

    :param r: residual of the current approximation ``x``
    :param Aprod: function returning the operator applied to its argument
    :param x: current approximation; incremented with ``+=`` and returned
    :param bilinearform: inner product ``bilinearform(a,b)`` returning a ``float``
    :param atol: absolute tolerance for the (estimated) residual norm
    :param iter_max: maximum number of steps allowed in this cycle
    :param iter_restart: maximum dimension of the Krylov basis
    :param verbose: if True, progress is printed
    :param P_R: optional right preconditioner: the basis is built for
                ``Aprod(P_R(.))`` and ``P_R`` is applied to the correction
                before it is added to ``x``
    :return: tuple ``(x, stopped)``; ``stopped`` is True if the estimated
             residual norm is ``<= atol``, i.e. no restart is needed
    :raise NegativeNorm: if ``bilinearform(r,r)`` is negative
    :raise MaxIterReached: if more than ``iter_max`` steps would be needed
    """
    iter=0

    h=numpy.zeros((iter_restart+1,iter_restart),numpy.float64)
    c=numpy.zeros(iter_restart,numpy.float64)
    s=numpy.zeros(iter_restart,numpy.float64)
    g=numpy.zeros(iter_restart+1,numpy.float64)
    v=[]

    r_dot_r = bilinearform(r, r)
    if r_dot_r<0: raise NegativeNorm("negative norm.")
    rho=math.sqrt(r_dot_r)

    if verbose: print(("GMRES: initial residual %e (absolute tolerance = %e)"%(rho,atol)))
    if rho<=atol:
        # Already converged. Returning here also avoids dividing by a zero
        # norm below when r vanishes exactly (which would turn x into NaN).
        return x,True

    v.append(r/rho)
    g[0]=rho

    while not (rho<=atol or iter==iter_restart):

        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)

        if P_R is not None:
            p=Aprod(P_R(v[iter]))
        else:
            p=Aprod(v[iter])
        v.append(p)

        v_norm1=math.sqrt(bilinearform(v[iter+1], v[iter+1]))

        # Modified Gram-Schmidt
        for j in range(iter+1):
            h[j,iter]=bilinearform(v[j],v[iter+1])
            v[iter+1]-=h[j,iter]*v[j]

        h[iter+1,iter]=math.sqrt(bilinearform(v[iter+1],v[iter+1]))
        v_norm2=h[iter+1,iter]

        # Reorthogonalize if needed
        if v_norm1 + 0.001*v_norm2 == v_norm1: #Brown/Hindmarsh condition (default)
            for j in range(iter+1):
                hr=bilinearform(v[j],v[iter+1])
                h[j,iter]=h[j,iter]+hr
                v[iter+1] -= hr*v[j]

            v_norm2=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
            h[iter+1,iter]=v_norm2

        # watch out for happy breakdown
        if not v_norm2 == 0:
            v[iter+1]=v[iter+1]/h[iter+1,iter]

        # Form and store the information for the new Givens rotation
        if iter > 0: h[:iter+1,iter]=__givapp(c[:iter],s[:iter],h[:iter+1,iter])
        mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])

        if mu!=0 :
            c[iter]=h[iter,iter]/mu
            s[iter]=-h[iter+1,iter]/mu
            h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
            h[iter+1,iter]=0.0
            gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
            g[iter]=gg[0]
            g[iter+1]=gg[1]

        # Update the residual norm
        rho=abs(g[iter+1])
        if verbose: print(("GMRES: iteration step %s: residual %e"%(iter,rho)))
        iter+=1

    # The loop ends when rho <= atol or after iter_restart steps
    # (exceeding iter_max raises instead). Compute the correction by back
    # substitution with the upper triangular part of h.

    if verbose: print(("GMRES: iteration stopped after %s step."%iter))
    if iter > 0 :
        y=numpy.zeros(iter,numpy.float64)
        y[iter-1] = g[iter-1] / h[iter-1,iter-1]
        if iter > 1 :
            i=iter-2
            while i>=0 :
                y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
                i=i-1
        xhat=v[iter-1]*y[iter-1]
        for i in range(iter-1):
            xhat += v[i]*y[i]
    else:
        xhat=v[0] * 0
    if P_R is not None:
        x += P_R(xhat)
    else:
        x += xhat
    # Report convergence from the residual estimate itself. Comparing the
    # step count with iter_restart-1 reported convergence reached in the last
    # one or two steps of a cycle as failure and forced a needless restart.
    stopped=bool(rho<=atol)

    return x,stopped
1051    
def MINRES(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
    """
    Solver for

    *Ax=b*

    with a symmetric operator A (more details required!).
    It uses the minimum residual method (MINRES) with preconditioner M
    providing an approximation of A.

    The iteration is terminated if

    *rnorm <= atol+rtol*Anorm*ynorm*

    where *rnorm* is the MINRES estimate of the preconditioned residual norm
    (initially *sqrt(bilinearform(Msolve(r),r))*), *Anorm* is a running
    estimate of the norm of the (preconditioned) operator and *ynorm* is a
    running estimate of the norm of the update added to ``x``. Both estimates
    start at zero, so before the first step the test is *rnorm <= atol*.

    For details on MINRES see the book:

    I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
    T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
    C. Romine, and H. van der Vorst}.

    :param r: initial residual *r=b-Ax*
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same
                 type like argument ``r``.
    :param Msolve: solves Mx=r
    :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
                  argument ``r``. The returned object needs to be of the same
                  type like argument ``x``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r`` is. The returned value
                        is a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance
    :type rtol: non-negative ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :return: the solution approximation
    :rtype: same like ``x``
    :raise NegativeNorm: if ``bilinearform(Msolve(.),.)`` becomes negative
    :raise MaxIterReached: if ``iter_max`` steps are exceeded
    :warning: ``x`` is updated in place (``x += ...``) and returned.
    """
    #------------------------------------------------------------------
    # Set up y and v for the first Lanczos vector v1.
    # y = beta1 P' v1, where P = C**(-1).
    # v is really P' v1.
    #------------------------------------------------------------------
    r1 = r
    y = Msolve(r)
    beta1 = bilinearform(y,r)

    if beta1< 0: raise NegativeNorm("negative norm.")

    # If r = 0 exactly, stop with x
    if beta1==0: return x

    beta1 = math.sqrt(beta1)

    #------------------------------------------------------------------
    # Initialize quantities.
    # ------------------------------------------------------------------
    iter = 0
    Anorm = 0
    ynorm = 0
    oldb = 0
    beta = beta1
    dbar = 0
    epsln = 0
    phibar = beta1
    rhs1 = beta1
    rhs2 = 0
    rnorm = phibar
    tnorm2 = 0
    ynorm2 = 0
    cs = -1
    sn = 0
    w = r*0.
    w2 = r*0.
    r2 = r1
    # lower bound for gamma, guards the divisions by gamma below
    eps = 0.0001

    #---------------------------------------------------------------------
    # Main iteration loop.
    # --------------------------------------------------------------------
    while not rnorm<=atol+rtol*Anorm*ynorm: # checks ||r|| < (||A|| ||x||) * TOL

        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)
        iter = iter + 1

        #-----------------------------------------------------------------
        # Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
        # The general iteration is similar to the case k = 1 with v0 = 0:
        #
        #   p1      = Operator * v1  -  beta1 * v0,
        #   alpha1  = v1'p1,
        #   q2      = p2  -  alpha1 * v1,
        #   beta2^2 = q2'q2,
        #   v2      = (1/beta2) q2.
        #
        # Again, y = betak P vk,  where  P = C**(-1).
        #-----------------------------------------------------------------
        s = 1/beta                 # Normalize previous vector (in y).
        v = s*y                    # v = vk if P = I

        y = Aprod(v)

        if iter >= 2:
            y = y - (beta/oldb)*r1

        alfa = bilinearform(v,y)              # alphak
        y += (- alfa/beta)*r2
        r1 = r2
        r2 = y
        y = Msolve(r2)
        oldb = beta                           # oldb = betak
        beta = bilinearform(y,r2)             # beta = betak+1^2
        if beta < 0: raise NegativeNorm("negative norm.")

        beta = math.sqrt( beta )
        tnorm2 = tnorm2 + alfa*alfa + oldb*oldb + beta*beta

        # Apply previous rotation Qk-1 to get
        #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
        #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].

        oldeps = epsln
        delta = cs * dbar + sn * alfa   # delta1 = 0         deltak
        gbar = sn * dbar - cs * alfa    # gbar 1 = alfa1     gbar k
        epsln = sn * beta               # epsln2 = 0         epslnk+1
        dbar = - cs * beta              # dbar 2 = beta2     dbar k+1

        # Compute the next plane rotation Qk

        gamma = math.sqrt(gbar*gbar+beta*beta)  # gammak
        gamma = max(gamma,eps)
        cs = gbar / gamma               # ck
        sn = beta / gamma               # sk
        phi = cs * phibar               # phik
        phibar = sn * phibar            # phibark+1

        # Update x.

        denom = 1/gamma
        w1 = w2
        w2 = w
        w = (v - oldeps*w1 - delta*w2) * denom
        x += phi*w

        # Go round again.

        z = rhs1 / gamma
        ynorm2 = z*z + ynorm2
        rhs1 = rhs2 - delta*z
        rhs2 = - epsln*z

        # Estimate various norms and test for convergence.

        Anorm = math.sqrt( tnorm2 )
        ynorm = math.sqrt( ynorm2 )

        rnorm = phibar

    return x
1229 artak 1489
def TFQMR(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
    """
    Solver for

    *Ax=b*

    with a general operator A (more details required!).
    It uses the Transpose-Free Quasi-Minimal Residual method (TFQMR).

    The iteration is terminated if

    *tau <= atol+rtol*|r0|*

    where *r0* is the initial residual, *|.|* is the energy norm
    *|r| = sqrt( bilinearform(r,r))* and *tau* is the TFQMR quasi-residual
    norm, which starts at *|r0|*. The test is made after every half-step.

    :param r: initial residual *r=b-Ax*
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: same like ``r``
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same type
                 like argument ``r``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r``. The returned value is
                        a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance
    :type rtol: non-negative ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :return: the solution approximation
    :rtype: same like ``x``
    :raise NegativeNorm: if ``bilinearform(r,r)`` is negative
    :raise MaxIterReached: if ``iter_max`` steps are exceeded
    :raise IterationBreakDown: if ``sigma`` or ``rho`` becomes zero
    :note: neither ``r`` nor ``x`` is modified in place; the new approximation
           is returned.
    """
    w = r
    y1 = r
    iter = 0
    d = 0
    v = Aprod(y1)
    u1 = v

    theta = 0.0
    eta = 0.0
    rho=bilinearform(r,r)
    if rho < 0: raise NegativeNorm("negative norm.")
    tau = math.sqrt(rho)
    norm_r0=tau
    tol = atol+rtol*norm_r0
    converged = False
    while tau>tol:
        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)

        sigma = bilinearform(r,v)
        if sigma == 0.0: raise IterationBreakDown('TFQMR breakdown, sigma=0')

        alpha = rho / sigma

        for j in range(2):
            if j==0:
                w = w - alpha * u1
                d = y1 + ( theta * theta * eta / alpha ) * d
            else:
                # y2 and u2 are only needed in the second half-step
                y2 = y1 - alpha * v
                u2 = Aprod(y2)
                w = w - alpha * u2
                d = y2 + ( theta * theta * eta / alpha ) * d

            theta = math.sqrt(bilinearform(w,w))/ tau
            c = 1.0 / math.sqrt ( 1.0 + theta * theta )
            tau = tau * theta * c
            eta = c * c * alpha
            x = x + eta * d
            #
            # Try to terminate the iteration at each pass through the loop.
            # Besides saving work this is required for correctness: if tau
            # dropped to zero, the next half-step would divide by it.
            #
            if tau <= tol:
                converged = True
                break
        if converged:
            break

        if rho == 0.0: raise IterationBreakDown('TFQMR breakdown, rho=0')

        rhon = bilinearform(r,w)
        beta = rhon / rho
        rho = rhon
        y1 = w + beta * y2
        u1 = Aprod(y1)
        v = u1 + beta * ( u2 + beta * v )

        iter += 1

    return x
1330    
1331    
1332 artak 1465 #############################################
1333    
class ArithmeticTuple(object):
    """
    Tuple supporting inplace update x+=y and scaling x=a*y where ``x,y`` is an
    ArithmeticTuple and ``a`` is a float.

    Example of usage::

        from esys.escript import Data
        from numpy import array
        a=Data(...)
        b=array([1.,4.])
        x=ArithmeticTuple(a,b)
        y=5.*x

    """
    def __init__(self,*args):
        """
        Initializes object with elements ``args``.

        :param args: tuple of objects that support inplace add (x+=y) and
                     scaling (x=a*y)
        """
        self.__items=list(args)

    def __len__(self):
        """
        Returns the number of items.

        :return: number of items
        :rtype: ``int``
        """
        return len(self.__items)

    def __getitem__(self,index):
        """
        Returns item at specified position.

        :param index: index of item to be returned
        :type index: ``int``
        :return: item with index ``index``
        """
        return self.__items.__getitem__(index)

    def __itemwise(self, other, op):
        """
        Combines the items with ``other`` using ``op(item, operand)``.

        If ``other`` has a length, it must equal ``len(self)`` and item ``i``
        is combined with ``other[i]``. If ``other`` has no length, or an
        itemwise operation raises a ``TypeError``, every item is combined with
        ``other`` as a whole (partial itemwise results are discarded).

        :raise ValueError: if ``other`` has a length different from ``len(self)``
        :rtype: `ArithmeticTuple`
        """
        try:
            l=len(other)
        except TypeError:
            l=None
        if l is not None:
            if l!=len(self):
                raise ValueError("length of arguments don't match.")
            try:
                return ArithmeticTuple(*[op(self[i], other[i]) for i in range(l)])
            except TypeError:
                pass
        return ArithmeticTuple(*[op(item, other) for item in self.__items])

    def __mul__(self,other):
        """
        Scales by ``other`` from the right.

        :param other: scaling factor
        :type other: ``float``
        :return: itemwise self*other
        :rtype: `ArithmeticTuple`
        """
        return self.__itemwise(other, lambda item, o: item*o)

    def __rmul__(self,other):
        """
        Scales by ``other`` from the left.

        :param other: scaling factor
        :type other: ``float``
        :return: itemwise other*self
        :rtype: `ArithmeticTuple`
        """
        return self.__itemwise(other, lambda item, o: o*item)

    def __div__(self,other):
        """
        Scales by (1/``other``) from the right.

        :param other: scaling factor
        :type other: ``float``
        :return: itemwise self/other
        :rtype: `ArithmeticTuple`
        """
        # 1. instead of 1 avoids integer division under Python 2
        return self*(1./other)

    # Python 3 uses __truediv__ for the ``/`` operator
    __truediv__ = __div__

    def __rdiv__(self,other):
        """
        Divides ``other`` by the items, i.e. computes itemwise ``other/self``.

        :param other: dividend
        :type other: ``float``
        :return: itemwise other/self
        :rtype: `ArithmeticTuple`
        """
        return self.__itemwise(other, lambda item, o: o/item)

    # Python 3 uses __rtruediv__ for ``other / tuple``
    __rtruediv__ = __rdiv__

    def __iadd__(self,other):
        """
        Inplace addition of ``other`` to self.

        :param other: increment
        :type other: ``ArithmeticTuple``
        """
        if len(self) != len(other):
            raise ValueError("tuple lengths must match.")
        for i in range(len(self)):
            self.__items[i]+=other[i]
        return self

    def __add__(self,other):
        """
        Adds ``other`` to self.

        :param other: increment
        :type other: ``ArithmeticTuple``
        """
        return self.__itemwise(other, lambda item, o: item+o)

    def __sub__(self,other):
        """
        Subtracts ``other`` from self.

        :param other: decrement
        :type other: ``ArithmeticTuple``
        """
        return self.__itemwise(other, lambda item, o: item-o)

    def __isub__(self,other):
        """
        Inplace subtraction of ``other`` from self.

        :param other: decrement
        :type other: ``ArithmeticTuple``
        """
        if len(self) != len(other):
            raise ValueError("tuple length must match.")
        for i in range(len(self)):
            self.__items[i]-=other[i]
        return self

    def __neg__(self):
        """
        Negates values.
        """
        return ArithmeticTuple(*[-item for item in self.__items])
1513 artak 1550
1514    
1515 gross 1414 class HomogeneousSaddlePointProblem(object):
1516     """
1517 caltinay 2169 This class provides a framework for solving linear homogeneous saddle
1518     point problems of the form::
1519 gross 1414
1520 jfenwick 2625 *Av+B^*p=f*
1521     *Bv =0*
1522 gross 1414
1523 jfenwick 2625 for the unknowns *v* and *p* and given operators *A* and *B* and
1524     given right hand side *f*. *B^** is the adjoint operator of *B*.
1525 gross 2719 *A* may depend weakly on *v* and *p*.
1526 gross 1414 """
1527 gross 2719 def __init__(self, **kwargs):
1528 jfenwick 3852 """
1529     initializes the saddle point problem
1530     """
1531 gross 2719 self.resetControlParameters()
1532 gross 1414 self.setTolerance()
1533 gross 2100 self.setAbsoluteTolerance()
1534 gross 2862 def resetControlParameters(self, K_p=1., K_v=1., rtol_max=0.01, rtol_min = 1.e-7, chi_max=0.5, reduction_factor=0.3, theta = 0.1):
1535 gross 2719 """
1536     sets a control parameter
1537    
1538 gross 2843 :param K_p: initial value for constant to adjust pressure tolerance
1539     :type K_p: ``float``
1540     :param K_v: initial value for constant to adjust velocity tolerance
1541     :type K_v: ``float``
1542     :param rtol_max: maximuim relative tolerance used to calculate presssure and velocity increment.
1543     :type rtol_max: ``float``
1544 gross 2719 :param chi_max: maximum tolerable converegence rate.
1545     :type chi_max: ``float``
1546 gross 2843 :param reduction_factor: reduction factor for adjustment factors.
1547     :type reduction_factor: ``float``
1548 gross 2719 """
1549 gross 2843 self.setControlParameter(K_p, K_v, rtol_max, rtol_min, chi_max, reduction_factor, theta)
1550 gross 2719
1551 gross 2843 def setControlParameter(self,K_p=None, K_v=None, rtol_max=None, rtol_min=None, chi_max=None, reduction_factor=None, theta=None):
1552 gross 2719 """
1553     sets a control parameter
1554    
1555 gross 2843
1556     :param K_p: initial value for constant to adjust pressure tolerance
1557     :type K_p: ``float``
1558     :param K_v: initial value for constant to adjust velocity tolerance
1559     :type K_v: ``float``
1560     :param rtol_max: maximuim relative tolerance used to calculate presssure and velocity increment.
1561     :type rtol_max: ``float``
1562 gross 2719 :param chi_max: maximum tolerable converegence rate.
1563     :type chi_max: ``float``
1564 gross 2843 :type reduction_factor: ``float``
1565 gross 2719 """
1566 gross 2843 if not K_p == None:
1567     if K_p<1:
1568 jfenwick 3771 raise ValueError("K_p need to be greater or equal to 1.")
1569 gross 2719 else:
1570 gross 2843 K_p=self.__K_p
1571 gross 2719
1572 gross 2843 if not K_v == None:
1573     if K_v<1:
1574 jfenwick 3771 raise ValueError("K_v need to be greater or equal to 1.")
1575 gross 2719 else:
1576 gross 2843 K_v=self.__K_v
1577 gross 2719
1578 gross 2843 if not rtol_max == None:
1579     if rtol_max<=0 or rtol_max>=1:
1580 jfenwick 3771 raise ValueError("rtol_max needs to be positive and less than 1.")
1581 gross 2719 else:
1582 gross 2843 rtol_max=self.__rtol_max
1583 gross 2719
1584     if not rtol_min == None:
1585     if rtol_min<=0 or rtol_min>=1:
1586 jfenwick 3771 raise ValueError("rtol_min needs to be positive and less than 1.")
1587 gross 2719 else:
1588     rtol_min=self.__rtol_min
1589    
1590 gross 2843 if not chi_max == None:
1591     if chi_max<=0 or chi_max>=1:
1592 jfenwick 3771 raise ValueError("chi_max needs to be positive and less than 1.")
1593 gross 2719 else:
1594 gross 2843 chi_max = self.__chi_max
1595 gross 2719
1596 gross 2843 if not reduction_factor == None:
1597     if reduction_factor<=0 or reduction_factor>1:
1598 jfenwick 3771 raise ValueError("reduction_factor need to be between zero and one.")
1599 gross 2719 else:
1600 gross 2843 reduction_factor=self.__reduction_factor
1601 gross 2719
1602 gross 2843 if not theta == None:
1603     if theta<=0 or theta>1:
1604 jfenwick 3771 raise ValueError("theta need to be between zero and one.")
1605 gross 2719 else:
1606 gross 2843 theta=self.__theta
1607 gross 2719
1608 gross 2843 if rtol_min>=rtol_max:
1609 jfenwick 3771 raise ValueError("rtol_max = %e needs to be greater than rtol_min = %e"%(rtol_max,rtol_min))
1610 gross 2719 self.__chi_max = chi_max
1611     self.__rtol_max = rtol_max
1612 gross 2843 self.__K_p = K_p
1613     self.__K_v = K_v
1614     self.__reduction_factor = reduction_factor
1615     self.__theta = theta
1616     self.__rtol_min=rtol_min
1617 gross 2719
1618 gross 2100 #=============================================================
1619 gross 2445 def inner_pBv(self,p,Bv):
1620 gross 1414 """
1621 caltinay 2169 Returns inner product of element p and Bv (overwrite).
1622    
1623 jfenwick 2625 :param p: a pressure increment
1624     :param Bv: a residual
1625     :return: inner product of element p and Bv
1626     :rtype: ``float``
1627     :note: used if PCG is applied.
1628 gross 1414 """
1629 jfenwick 3771 raise NotImplementedError("no inner product for p and Bv implemented.")
1630 gross 1414
1631 gross 2100 def inner_p(self,p0,p1):
1632 gross 1414 """
1633 gross 2251 Returns inner product of p0 and p1 (overwrite).
1634 caltinay 2169
1635 jfenwick 2625 :param p0: a pressure
1636     :param p1: a pressure
1637     :return: inner product of p0 and p1
1638     :rtype: ``float``
1639 gross 1414 """
1640 jfenwick 3771 raise NotImplementedError("no inner product for p implemented.")
1641 gross 2251
1642     def norm_v(self,v):
1643     """
1644     Returns the norm of v (overwrite).
1645 gross 2100
1646 jfenwick 2625 :param v: a velovity
1647     :return: norm of v
1648     :rtype: non-negative ``float``
1649 gross 2100 """
1650 jfenwick 3771 raise NotImplementedError("no norm of v implemented.")
def getDV(self, p, v, tol):
    """
    Returns a correction to the velocity for a given velocity ``v`` and
    pressure ``p``, computed with accuracy ``tol`` (overwrite).

    :param p: pressure
    :param v: velocity
    :param tol: tolerance to be used when computing the correction
    :return: dv given as *dv= A^{-1} (f-A v-B^*p)*
    :note: Only *A* may depend on *v* and *p*
    :raise NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError("no dv calculation implemented.")
1661 gross 2445
1662 gross 2251
def Bv(self,v, tol):
    """
    Applies the operator *B* to the velocity ``v`` with accuracy ``tol``
    (overwrite).

    :param v: a velocity
    :param tol: tolerance to be used
    :rtype: equal to the type of p
    :note: boundary conditions on p should be zero!
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no operator B implemented."
    raise NotImplementedError(message)
1671 gross 2251
def norm_Bv(self,Bv):
    """
    Returns the norm of Bv (overwrite).

    :param Bv: a value as returned by `Bv`
    :return: the norm of ``Bv``
    :rtype: non-negative ``float``
    :note: boundary conditions on p should be zero!
    :raise NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError("no norm of Bv implemented.")
1680 gross 2445
def solve_AinvBt(self,dp, tol):
    """
    Solves *A dv=B^*dp* for ``dv`` with accuracy ``tol`` (overwrite).

    :param dp: a pressure increment
    :param tol: tolerance to be used for the solve
    :return: the solution of *A dv=B^*dp*
    :note: boundary conditions on dv should be zero! *A* is the operator used
           in ``getDV`` and must not be altered.
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no operator A implemented."
    raise NotImplementedError(message)
1690 gross 1414
def solve_prec(self,Bv, tol):
    """
    Applies a preconditioner for the Schur complement *(B A^{-1} B^*)* to
    ``Bv`` with accuracy ``tol`` (overwrite).

    :param Bv: the value to be preconditioned (of the type returned by `Bv`)
    :param tol: tolerance to be used
    :rtype: equal to the type of p
    :note: boundary conditions on p should be zero!
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no preconditioner for Schur complement implemented."
    raise NotImplementedError(message)
1699 gross 2100 #=============================================================
def __Aprod_PCG(self,dp):
    """
    Operator application used by PCG: maps a pressure increment ``dp`` to
    the pair ``(dv, Bv(dv))`` with ``dv`` the solution of *A dv=B^*dp*.
    Both inner solves use the current sub-tolerance.
    """
    tol = self.__subtol
    dv = self.solve_AinvBt(dp, tol)
    b_dv = self.Bv(dv, tol)
    return ArithmeticTuple(dv, b_dv)
1703 gross 1414
def __inner_PCG(self,p,r):
    """
    Inner product used by PCG: pairs ``p`` with the second component
    ``r[1]`` of the residual tuple via `inner_pBv`.
    """
    residual = r[1]
    return self.inner_pBv(p, residual)
1706 gross 1414
def __Msolve_PCG(self,r):
    """
    Preconditioner used by PCG: applies `solve_prec` to the second component
    ``r[1]`` of the residual tuple at the current sub-tolerance.
    """
    residual, tol = r[1], self.__subtol
    return self.solve_prec(residual, tol)
1709 gross 2100 #=============================================================
def __Aprod_GMRES(self,p):
    """
    Operator application used by GMRES: the preconditioned Schur complement
    ``solve_prec(Bv(solve_AinvBt(p)))``, all at the current sub-tolerance.
    """
    tol = self.__subtol
    dv = self.solve_AinvBt(p, tol)
    b_dv = self.Bv(dv, tol)
    return self.solve_prec(b_dv, tol)
def __inner_GMRES(self,p0,p1):
    """Inner product used by GMRES: delegates to the pressure inner product `inner_p`."""
    pressure_inner = self.inner_p
    return pressure_inner(p0, p1)
1714 gross 2474
1715 gross 2100 #=============================================================
def norm_p(self,p):
    """
    Computes the norm of ``p`` as the square root of ``inner_p(p, p)``.

    :param p: a pressure
    :return: the norm of ``p`` using the inner product for pressure
    :rtype: ``float``
    :raise ValueError: if ``inner_p(p, p)`` is negative
    """
    squared = self.inner_p(p, p)
    if squared < 0:
        raise ValueError("negative pressure norm.")
    return math.sqrt(squared)
1727 gross 2474
def solve(self,v,p,max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
    """
    Solves the saddle point problem using initial guesses v and p.

    :param v: initial guess for velocity
    :param p: initial guess for pressure
    :type v: `Data`
    :type p: `Data`
    :param usePCG: indicates the usage of the PCG rather than GMRES scheme.
    :param max_iter: maximum number of iteration steps per correction
                     attempt
    :param verbose: if True, shows information on the progress of the
                    saddlepoint problem solver.
    :param iter_restart: restart the iteration after ``iter_restart`` steps
                         (only used if usePCG=False)
    :param max_correction_steps: maximum number of correction steps before
                                 the default ``_solve`` gives up by raising
                                 ``CorrectionFailed``
    :type usePCG: ``bool``
    :type max_iter: ``int``
    :type verbose: ``bool``
    :type iter_restart: ``int``
    :type max_correction_steps: ``int``
    :return: the approximate solution ``(v, p)``
    :rtype: ``tuple`` of `Data` objects
    :note: typically this method is overwritten by a subclass. It provides a wrapper for the ``_solve`` method.
    """
    return self._solve(v=v,p=p,max_iter=max_iter,verbose=verbose, usePCG=usePCG, iter_restart=iter_restart, max_correction_steps=max_correction_steps)
1751    
def _solve(self,v,p,max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
    """
    Default solver for the saddle point problem; see the `solve` method for
    the meaning of the arguments.

    Each correction step first computes a velocity update from the current
    pressure (`getDV`). If the norm of ``Bv`` of the updated velocity is large
    compared to the velocity change (``theta*norm_dv < norm_Bv``), a pressure
    increment ``dp`` is obtained from *B A^{-1} B^* dp = Bv* (PCG, or GMRES if
    ``usePCG`` is False) and the velocity is corrected. The tolerances of the
    inner solves are adapted from the observed rate ``chi`` at which the
    correction size ``eps`` decreases.

    :return: the pair ``(v, p)`` once ``eps <= rtol*norm_v(v)+atol``
    :rtype: ``tuple``
    :raise CorrectionFailed: if the tolerance is not reached within
                             ``max_correction_steps`` correction steps
    """
    self.verbose=verbose
    rtol=self.getTolerance()
    atol=self.getAbsoluteTolerance()

    # factors relating the observed rate chi to the inner tolerances:
    K_p=self.__K_p
    K_v=self.__K_v
    correction_step=0
    converged=False
    chi=None   # rate eps/eps_old; unknown until two corrections exist
    eps=None   # size of the last correction

    if self.verbose: print(("HomogeneousSaddlePointProblem: start iteration: rtol= %e, atol=%e"%(rtol, atol)))
    while not converged:

        # get tolerance for velocity increment:
        if chi is None:
            rtol_v=self.__rtol_max
        else:
            rtol_v=min(chi/K_v,self.__rtol_max)
        rtol_v=max(rtol_v, self.__rtol_min)
        if self.verbose: print(("HomogeneousSaddlePointProblem: step %s: rtol_v= %e"%(correction_step,rtol_v)))
        # get velocity increment:
        dv1=self.getDV(p,v,rtol_v)
        v1=v+dv1
        Bv1=self.Bv(v1, rtol_v)
        norm_Bv1=self.norm_Bv(Bv1)
        norm_dv1=self.norm_v(dv1)
        if self.verbose: print(("HomogeneousSaddlePointProblem: step %s: norm_Bv1 = %e, norm_dv1 = %e"%(correction_step, norm_Bv1, norm_dv1)))
        if norm_dv1*self.__theta < norm_Bv1:
            # get tolerance for pressure increment:
            if chi is None or eps is None:
                rtol_p=self.__rtol_max
            else:
                rtol_p=min(chi**2*eps/K_p/norm_Bv1, self.__rtol_max)
            # tolerance used by the inner solves invoked from PCG/GMRES:
            self.__subtol=max(rtol_p**2, self.__rtol_min)
            if self.verbose: print(("HomogeneousSaddlePointProblem: step %s: rtol_p= %e"%(correction_step,rtol_p)))
            # now we solve for the pressure increment dp from B*A^{-1}B^* dp = Bv1
            if usePCG:
                dp,r,_=PCG(ArithmeticTuple(v1,Bv1),self.__Aprod_PCG,0*p,self.__Msolve_PCG,self.__inner_PCG,atol=0, rtol=rtol_p,iter_max=max_iter, verbose=self.verbose)
                # the first component of the returned tuple is used as the corrected velocity
                v2=r[0]
            else:
                # don't use!!!!
                dp=GMRES(self.solve_prec(Bv1,self.__subtol),self.__Aprod_GMRES, 0*p, self.__inner_GMRES,atol=0, rtol=rtol_p,iter_max=max_iter, iter_restart=iter_restart, verbose=self.verbose)
                dv2=self.solve_AinvBt(dp, self.__subtol)
                v2=v1-dv2
            p2=p+dp
        else:
            # Bv is small relative to the velocity change: keep the pressure
            v2=v1
            p2=p
        # update business:
        norm_dv2=self.norm_v(v2-v)
        norm_v2=self.norm_v(v2)
        if self.verbose: print(("HomogeneousSaddlePointProblem: step %s: v2 = %e, norm_dv2 = %e"%(correction_step, norm_v2, norm_dv2)))
        eps, eps_old = max(norm_Bv1, norm_dv2), eps
        if eps_old is None:
            chi, chi_old = None, chi
        else:
            chi, chi_old = min(eps/ eps_old, self.__chi_max), chi
        if self.verbose:
            if chi is not None:
                print(("HomogeneousSaddlePointProblem: step %s: convergence rate = %e, correction = %e"%(correction_step,chi, eps)))
            else:
                print(("HomogeneousSaddlePointProblem: step %s: correction = %e"%(correction_step, eps)))
        if eps <= rtol*norm_v2+atol :
            converged = True
        else:
            if correction_step>=max_correction_steps:
                raise CorrectionFailed("Given up after %d correction steps."%correction_step)
            if chi_old is not None:
                K_p=max(1,self.__reduction_factor*K_p,(chi-chi_old)/chi_old**2*K_p)
                # NOTE(review): the last term scales with K_p rather than K_v --
                # possibly a copy/paste slip; kept unchanged pending confirmation.
                K_v=max(1,self.__reduction_factor*K_v,(chi-chi_old)/chi_old**2*K_p)
                if self.verbose: print(("HomogeneousSaddlePointProblem: step %s: new adjustment factor K = %e"%(correction_step,K_p)))
        correction_step+=1
        v,p =v2, p2
    if self.verbose: print(("HomogeneousSaddlePointProblem: tolerance reached after %s steps."%correction_step))
    return v,p
1836 caltinay 2169 #========================================================================
def setTolerance(self,tolerance=1.e-4):
    """
    Sets the relative tolerance for (v,p).

    :param tolerance: tolerance to be used
    :type tolerance: non-negative ``float``
    :raise ValueError: if ``tolerance`` is negative
    """
    # zero is accepted (the check is strict), so the message says
    # "non-negative", consistent with setAbsoluteTolerance
    if tolerance<0:
        raise ValueError("tolerance must be non-negative.")
    self.__rtol=tolerance
1847 gross 2156
def getTolerance(self):
    """
    Gives the relative tolerance stored by `setTolerance`.

    :return: relative tolerance
    :rtype: ``float``
    """
    relative_tolerance = self.__rtol
    return relative_tolerance
1856 caltinay 2169
def setAbsoluteTolerance(self,tolerance=0.):
    """
    Stores the absolute tolerance.

    :param tolerance: tolerance to be used
    :type tolerance: non-negative ``float``
    :raise ValueError: if ``tolerance`` is negative
    """
    is_negative = tolerance < 0
    if is_negative:
        raise ValueError("tolerance must be non-negative.")
    self.__atol = tolerance
1867 caltinay 2169
def getAbsoluteTolerance(self):
    """
    Gives the absolute tolerance stored by `setAbsoluteTolerance`.

    :return: absolute tolerance
    :rtype: ``float``
    """
    absolute_tolerance = self.__atol
    return absolute_tolerance
1876 gross 1469
1877 caltinay 2169
def MaskFromBoundaryTag(domain,*tags):
    """
    Builds a mask on the Solution(domain) function space which is one for
    samples that touch the boundary tagged by any of ``tags``.

    Usage: m=MaskFromBoundaryTag(domain, "left", "right")

    :param domain: domain to be used
    :type domain: `escript.Domain`
    :param tags: boundary tags
    :type tags: ``str``
    :return: a mask which marks samples that are touching the boundary tagged
             by any of the given tags
    :rtype: `escript.Data` of rank 0
    """
    pde = linearPDEs.LinearPDE(domain, numEquations=1, numSolutions=1)
    # indicator on the boundary: 1 on the tagged pieces, 0 elsewhere
    indicator = escript.Scalar(0., escript.FunctionOnBoundary(domain))
    for tag in tags:
        indicator.setTaggedValue(tag, 1.)
    # samples whose assembled right hand side is non-zero are marked
    pde.setValue(y=indicator)
    rhs = pde.getRightHandSide()
    return util.whereNonZero(rhs)
1898 caltinay 2169
def MaskFromTag(domain,*tags):
    """
    Creates a mask on the Solution(domain) function space where the value is
    one for samples that touch regions tagged by tags.

    Usage: m=MaskFromTag(domain, "ham")

    :param domain: domain to be used
    :type domain: `escript.Domain`
    :param tags: region tags (tags of the ``Function(domain)`` space)
    :type tags: ``str``
    :return: a mask which marks samples that are touching a region tagged
             by any of the given tags
    :rtype: `escript.Data` of rank 0
    """
    pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
    # indicator on the Function space: 1 in the tagged regions, 0 elsewhere
    d=escript.Scalar(0.,escript.Function(domain))
    for t in tags: d.setTaggedValue(t,1.)
    # samples whose assembled right hand side is non-zero are marked
    pde.setValue(Y=d)
    return util.whereNonZero(pde.getRightHandSide())
1919    
1920    

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26