/[escript]/trunk/escript/py_src/pdetools.py
ViewVC logotype

Annotation of /trunk/escript/py_src/pdetools.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4407 - (hide annotations)
Tue May 14 05:17:24 2013 UTC (6 years, 9 months ago) by jfenwick
File MIME type: text/x-python
File size: 67354 byte(s)
Added output function to arithmetic tuple
1 ksteube 1809
2 jfenwick 3981 ##############################################################################
3 ksteube 1312 #
4 jfenwick 4154 # Copyright (c) 2003-2013 by University of Queensland
5 jfenwick 3981 # http://www.uq.edu.au
6 ksteube 1312 #
7 ksteube 1809 # Primary Business: Queensland, Australia
8     # Licensed under the Open Software License version 3.0
9     # http://www.opensource.org/licenses/osl-3.0.php
10 ksteube 1312 #
11 jfenwick 3981 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12     # Development since 2012 by School of Earth Sciences
13     #
14     ##############################################################################
15 jgs 121
# Module metadata (copyright, license, project URL and author).
__copyright__="""Copyright (c) 2003-2013 by University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

# NOTE(review): this string follows the assignments above, so Python does not
# bind it to the module's __doc__; move it before them if help() should show it.
"""
Provides some tools related to PDEs.

Currently includes:
    - Projector - to project a discontinuous function onto a continuous function
    - Locator - to trace values in data objects at a certain location
    - TimeIntegrationManager - to handle extrapolation in time
    - SaddlePointProblem - solver for Saddle point problems using the inexact Uzawa scheme

:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
:var __version__: version
:var __date__: date of the version
"""

__author__="Lutz Gross, l.gross@uq.edu.au"
41 elspeth 609
42 gross 637
43 jfenwick 3995 #from . import escript
44     from . import escriptcpp
45     escore=escriptcpp
46 jfenwick 3771 from . import linearPDEs
47 jfenwick 2455 import numpy
48 jfenwick 3771 from . import util
49 ksteube 1312 import math
50 jgs 121
class TimeIntegrationManager(object):
    """
    A simple mechanism to manage time dependent values.

    Typical usage is::

       dt=0.1 # time increment
       tm=TimeIntegrationManager(initial_value,p=1)
       while t<1.:
           v_guess=tm.extrapolate(dt) # extrapolate to t+dt
           v=...
           tm.checkin(dt,v)
           t+=dt

    :note: currently only p=1 is supported.
    """
    def __init__(self,*inital_values,**kwargs):
        """
        Sets up the value manager where ``inital_values`` are the initial values
        and p is the order used for extrapolation.

        :param inital_values: the initial values (any number of them)
        :keyword p: order used for extrapolation (default 1)
        :keyword time: initial time (default 0.)
        """
        self.__p=kwargs.get("p", 1)
        self.__t=kwargs.get("time", 0.)
        # history of checked-in value tuples, newest first
        self.__v_mem=[inital_values]
        self.__order=0
        # time increments between consecutive entries of __v_mem, newest first
        self.__dt_mem=[]
        self.__num_val=len(inital_values)

    def getTime(self):
        """
        Returns the current time, i.e. the initial time plus all increments
        passed to `checkin`.
        """
        return self.__t

    def getValue(self):
        """
        Returns the most recently checked-in values: a single value if the
        manager handles one value, otherwise the tuple of values.
        """
        out=self.__v_mem[0]
        if len(out)==1:
            return out[0]
        else:
            return out

    def checkin(self,dt,*values):
        """
        Adds new values to the manager. Only the most recent values needed for
        extrapolation of order p are kept; older values are discarded.

        :param dt: time increment from the previously checked-in values to
                   ``values``
        :param values: the new values (same number as the initial values)
        """
        o=min(self.__order+1,self.__p)
        v_mem_new=[values]
        dt_mem_new=[dt]
        for i in range(o-1):
            v_mem_new.append(self.__v_mem[i])
            dt_mem_new.append(self.__dt_mem[i])
        v_mem_new.append(self.__v_mem[o-1])
        self.__order=o
        self.__v_mem=v_mem_new
        self.__dt_mem=dt_mem_new
        self.__t+=dt

    def extrapolate(self,dt):
        """
        Extrapolates to ``dt`` forward in time.

        Before the first `checkin` the current values are returned unchanged;
        afterwards linear extrapolation from the last two value sets is used.

        :return: ``None`` if no values are managed, the single extrapolated
                 value if one value is managed, otherwise a collection of values
        """
        if self.__order==0:
            out=self.__v_mem[0]
        else:
            # ratio of the requested step to the last checked-in step
            ratio=dt/self.__dt_mem[0]
            v_new=self.__v_mem[0]
            v_old=self.__v_mem[1]
            out=[(1.+ratio)*v_new[i]-ratio*v_old[i] for i in range(self.__num_val)]

        if len(out)==0:
            return None
        elif len(out)==1:
            return out[0]
        else:
            return out
129    
130 caltinay 2169
class Projector(object):
    """
    The Projector is a factory which projects a discontinuous function onto a
    continuous function on a given domain.

    The projection is computed with a `linearPDEs.LinearPDE` with ``D=1``
    which is solved once per component of the input, using that component
    as the right hand side ``Y``.
    """
    def __init__(self, domain, reduce=True, fast=True):
        """
        Creates a continuous function space projector for a domain.

        :param domain: Domain of the projection.
        :param reduce: Flag to reduce projection order
        :param fast: Flag to use a fast method based on matrix lumping
        """
        pde = linearPDEs.LinearPDE(domain)
        self.__pde = pde
        if fast:
            pde.getSolverOptions().setSolverMethod(linearPDEs.SolverOptions.LUMPING)
        pde.setSymmetryOn()
        pde.setReducedOrderTo(reduce)
        pde.setValue(D = 1.)

    def getSolverOptions(self):
        """
        Returns the solver options of the PDE solver used for the projection.

        :rtype: `linearPDEs.SolverOptions`
        """
        return self.__pde.getSolverOptions()

    def getValue(self, input_data):
        """
        Projects ``input_data`` onto a continuous function (same as calling
        the projector directly).

        :param input_data: the data to be projected
        """
        return self(input_data)

    def __call__(self, input_data):
        """
        Projects ``input_data`` onto a continuous function.

        :param input_data: the data to be projected
        :return: the projection in the solution function space of the PDE
        """
        pde = self.__pde
        shape = input_data.getShape()
        out = escore.Data(0., shape, pde.getFunctionSpaceForSolution())
        # clear any right hand side left over from a previous projection
        pde.setValue(Y = escore.Data(), Y_reduced = escore.Data())
        if input_data.getRank() == 0:
            pde.setValue(Y = input_data)
            return pde.getSolution()
        # one solve per component, last index varying fastest; at most the
        # first four indices are iterated
        for index in numpy.ndindex(*shape[:4]):
            # rank-1 components are addressed with a plain integer index
            key = index[0] if len(index) == 1 else index
            pde.setValue(Y = input_data[key])
            out[key] = pde.getSolution()
        return out
201    
class NoPDE(object):
    """
    Solves the following problem for u:

    *kronecker[i,j]*D[j]*u[j]=Y[i]*

    with constraint

    *u[j]=r[j]* where *q[j]>0*

    where *D*, *Y*, *r* and *q* are given functions of rank 1.

    In the case of scalars this takes the form

    *D*u=Y*

    with constraint

    *u=r* where *q>0*

    where *D*, *Y*, *r* and *q* are given scalar functions.

    The constraint overwrites any other condition.

    :note: This class is similar to the `linearPDEs.LinearPDE` class with
           A=B=C=X=0 but has the intention that all input parameters are given
           in `Solution` or `ReducedSolution`.
    """
    # The whole thing is a bit strange and I blame Rob Woodcock (CSIRO) for
    # this.
    def __init__(self,domain,D=None,Y=None,q=None,r=None):
        """
        Initializes the problem.

        :param domain: domain of the PDE
        :type domain: `Domain`
        :param D: coefficient of the solution
        :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param Y: right hand side
        :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param q: location of constraints
        :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param r: value of solution at locations of constraints
        :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
        """
        self.__domain=domain
        self.__D=D
        self.__Y=Y
        self.__q=q
        self.__r=r
        # cached solution; reset to None whenever an input changes
        self.__u=None
        self.__function_space=escore.Solution(self.__domain)

    def setReducedOn(self):
        """
        Sets the `FunctionSpace` of the solution to `ReducedSolution`.
        """
        self.__function_space=escore.ReducedSolution(self.__domain)
        self.__u=None

    def setReducedOff(self):
        """
        Sets the `FunctionSpace` of the solution to `Solution`.
        """
        self.__function_space=escore.Solution(self.__domain)
        self.__u=None

    def setValue(self,D=None,Y=None,q=None,r=None):
        """
        Assigns values to the parameters. Arguments left as ``None`` keep
        their current value; setting any argument invalidates a previously
        computed solution.

        :param D: coefficient of the solution
        :type D: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param Y: right hand side
        :type Y: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param q: location of constraints
        :type q: ``float``, ``int``, ``numpy.ndarray``, `Data`
        :param r: value of solution at locations of constraints
        :type r: ``float``, ``int``, ``numpy.ndarray``, `Data`
        """
        # identity tests: ``arr==None`` on a numpy array is elementwise and
        # its truth value is ambiguous
        if D is not None:
            self.__D=D
            self.__u=None
        if Y is not None:
            self.__Y=Y
            self.__u=None
        if q is not None:
            self.__q=q
            self.__u=None
        if r is not None:
            self.__r=r
            self.__u=None

    def getSolution(self):
        """
        Returns the solution.

        :return: the solution of the problem
        :rtype: `Data` object in the `FunctionSpace` `Solution` or
                `ReducedSolution`
        :raise ValueError: if D is undefined or has rank greater than 1
        """
        if self.__u is None:
            if self.__D is None:
                raise ValueError("coefficient D is undefined")
            D=escore.Data(self.__D,self.__function_space)
            if D.getRank()>1:
                raise ValueError("coefficient D must have rank 0 or 1")
            if self.__Y is None:
                self.__u=escore.Data(0.,D.getShape(),self.__function_space)
            else:
                self.__u=1./D*self.__Y
            if self.__q is not None:
                # q is 1 where a constraint is set and 0 elsewhere
                q=util.wherePositive(escore.Data(self.__q,self.__function_space))
                self.__u*=(1.-q)
                if self.__r is not None: self.__u+=q*self.__r
        return self.__u
318 caltinay 2169
class Locator(object):
    """
    Locator provides access to the values of data objects at a given spatial
    coordinate x.

    In fact, a Locator object finds the sample in the set of samples of a
    given function space or domain which is closest to the given point x.
    """

    def __init__(self,where,x=None):
        """
        Initializes a Locator to access values in Data objects on the Domain
        or FunctionSpace for the sample point which is closest to the given
        point x.

        :param where: function space, or a domain in which case the
                      `escript.ContinuousFunction` on it is used
        :type where: `escript.FunctionSpace` or `escript.Domain`
        :param x: location(s) of the Locator; defaults to the origin
                  ``numpy.zeros((3,))``
        :type x: ``numpy.ndarray`` or ``list`` of ``numpy.ndarray``
        :raise ValueError: if ``x`` is an empty list
        """
        if x is None:
            # fresh array per call rather than a shared default argument
            x=numpy.zeros((3,))
        if isinstance(where,escore.FunctionSpace):
            self.__function_space=where
        else:
            self.__function_space=escore.ContinuousFunction(where)
        # a list whose entries are themselves iterable is a list of points;
        # anything else is treated as a single point
        iterative=False
        if isinstance(x, list):
            if len(x)==0:
                raise ValueError("At least one point must be given.")
            try:
                iter(x[0])
                iterative=True
            except TypeError:
                iterative=False
        xxx=self.__function_space.getX()
        dim=self.__function_space.getDim()
        # coordinates beyond the dimension of the function space are ignored
        if iterative:
            self.__id=[util.length(xxx-p[:dim]).minGlobalDataPoint() for p in x]
        else:
            self.__id=util.length(xxx-x[:dim]).minGlobalDataPoint()

    def __str__(self):
        """
        Returns the coordinates of the Locator as a string.
        """
        x=self.getX()
        if isinstance(x,list):
            return "["+",".join(str(xx) for xx in x)+"]"
        return str(x)

    def getX(self):
        """
        Returns the exact coordinates of the Locator.
        """
        return self(self.getFunctionSpace().getX())

    def getFunctionSpace(self):
        """
        Returns the function space of the Locator.
        """
        return self.__function_space

    def getId(self,item=None):
        """
        Returns the identifier of the location.

        :param item: index of the point if the Locator was created for a list
                     of points; ignored for a single point
        :return: the identifier, or the list of identifiers if ``item`` is
                 ``None`` and the Locator holds several points
        """
        if item is not None and isinstance(self.__id,list):
            return self.__id[item]
        return self.__id

    def __call__(self,data):
        """
        Returns the value of data at the Locator of a Data object.
        """
        return self.getValue(data)

    def getValue(self,data):
        """
        Returns the value of ``data`` at the Locator if ``data`` is a `Data`
        object otherwise the object is returned.

        :return: a ``numpy.ndarray`` (a single number for rank-0 data), or a
                 list of those if the Locator holds several points
        """
        if not isinstance(data,escore.Data):
            return data
        dat=util.interpolate(data,self.getFunctionSpace())
        ids=self.getId()
        is_scalar=data.getRank()==0

        def value_at(dp_id):
            o=numpy.array(dat.getTupleForGlobalDataPoint(*dp_id))
            # for rank-0 data only the single entry is returned
            return o[0] if is_scalar else o

        if isinstance(ids,list):
            return [value_at(i) for i in ids]
        return value_at(ids)

    def setValue(self, data, v):
        """
        Sets the value of the ``data`` at the Locator.

        :param data: data object to modify; it is expanded in place
        :type data: `escript.Data`
        :param v: the value to set
        :raise TypeError: if ``data`` is not a `Data` object or its function
                          space differs from the one of the Locator
        """
        if not isinstance(data, escore.Data):
            raise TypeError("setValue: Invalid argument type.")
        if data.getFunctionSpace()!=self.getFunctionSpace():
            raise TypeError("setValue: FunctionSpace of Locator and Data object must match.")
        data.expand()
        ids=self.getId()
        if not isinstance(ids, list):
            ids=[ids]
        for i in ids:
            data._setTupleForGlobalDataPoint(i[1], i[0], v)
453 jgs 149
454 jfenwick 2745
def getInfLocator(arg):
    """
    Return a Locator for a point with the inf value over all arg.

    :param arg: the data whose infimum is located
    :type arg: `escript.Data`
    :return: a Locator on the function space of ``arg``
    :raise TypeError: if ``arg`` is not a `Data` object
    """
    if not isinstance(arg, escore.Data):
        raise TypeError("getInfLocator: Unknown argument type.")
    fs=arg.getFunctionSpace()
    # data point closest to util.inf(arg); this gives the sample identifier
    # but not its coordinates, which are looked up separately
    closest=util.length(arg-util.inf(arg)).minGlobalDataPoint()
    coords=fs.getX().getTupleForGlobalDataPoint(*closest)
    return Locator(fs,coords)
466    
def getSupLocator(arg):
    """
    Return a Locator for a point with the sup value over all arg.

    :param arg: the data whose supremum is located
    :type arg: `escript.Data`
    :return: a Locator on the function space of ``arg``
    :raise TypeError: if ``arg`` is not a `Data` object
    """
    if not isinstance(arg, escore.Data):
        raise TypeError("getSupLocator: Unknown argument type.")
    a_sup=util.sup(arg)
    # data point closest to the supremum; this gives the sample identifier
    # but not its coordinates, which are looked up separately
    loc=util.length(arg-a_sup).minGlobalDataPoint()
    x=arg.getFunctionSpace().getX()
    x_max=x.getTupleForGlobalDataPoint(*loc)
    return Locator(arg.getFunctionSpace(),x_max)
478 jfenwick 4237
479 jfenwick 2745
class SolverSchemeException(Exception):
    """
    Generic exception raised by the solvers; base class of the more
    specific solver exceptions.
    """
485    
class IndefinitePreconditioner(SolverSchemeException):
    """
    Raised when the preconditioner turns out not to be positive definite.
    """
491 caltinay 2169
class MaxIterReached(SolverSchemeException):
    """
    Raised when the maximum number of iteration steps is reached.
    """
497 caltinay 2169
class CorrectionFailed(SolverSchemeException):
    """
    Raised when the solution correction scheme fails to converge.
    """
504 caltinay 2169
class IterationBreakDown(SolverSchemeException):
    """
    Raised when the iteration scheme runs into a breakdown it cannot
    recover from.
    """
510 caltinay 2169
511 ksteube 1312 class NegativeNorm(SolverSchemeException):
512     """
513 caltinay 2169 Exception thrown if a norm calculation returns a negative norm.
514 ksteube 1312 """
515     pass
516    
def _as_python_scalar(value):
    """
    Converts a numpy scalar (e.g. ``numpy.float64``) into the equivalent
    Python scalar; any other value is returned unchanged.

    A numpy scalar on the left of ``scalar*y`` handles the product with its
    own ``__mul__`` instead of deferring to ``y.__rmul__``, which breaks for
    container types such as `ArithmeticTuple`.
    """
    if isinstance(value, numpy.generic):
        return value.item()
    return value

def PCG(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100, initial_guess=True, verbose=False):
    """
    Solver for

    *Ax=b*

    with a symmetric and positive definite operator A (more details required!).
    It uses the conjugate gradient method with preconditioner M providing an
    approximation of A.

    The iteration is terminated if

    *|r| <= atol+rtol*|r0|*

    where *r0* is the initial residual and *|.|* is the energy norm. In fact

    *|r| = sqrt( bilinearform(Msolve(r),r))*

    The tolerance is never set below ``100*util.EPSILON*|r0|``.

    For details on the preconditioned conjugate gradient method see the book:

    I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
    T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
    C. Romine, and H. van der Vorst}.

    :param r: initial residual *r=b-Ax*. ``r`` is altered.
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same type
                 like argument ``r``.
    :param Msolve: solves Mx=r
    :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
                  argument ``r``. The returned object needs to be of the same
                  type like argument ``x``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r`` is. The returned value
                        is a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance
    :type rtol: non-negative ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :param initial_guess: not used
    :param verbose: if True, progress information is printed
    :return: the solution approximation, the corresponding residual and the
             norm of the residual
    :rtype: ``tuple``
    :raise NegativeNorm: if ``bilinearform(Msolve(r),r)`` is negative
    :raise ValueError: if ``atol+rtol*|r0|`` is not positive
    :raise MaxIterReached: if the tolerance is not reached in ``iter_max`` steps
    :warning: ``r`` and ``x`` are altered.
    """
    n_iter=0
    rhat=Msolve(r)
    d = rhat
    rhat_dot_r = bilinearform(rhat, r)
    if rhat_dot_r<0: raise NegativeNorm("negative norm.")
    norm_r0=math.sqrt(rhat_dot_r)
    atol2=atol+rtol*norm_r0
    if atol2<=0:
        raise ValueError("Non-positive tolarance.")
    # do not ask for more accuracy than floating point can deliver
    atol2=max(atol2, 100. * util.EPSILON * norm_r0)

    if verbose: print("PCG: initial residual norm = %e (absolute tolerance = %e)"%(norm_r0, atol2))

    while not math.sqrt(rhat_dot_r) <= atol2:
        n_iter+=1
        if n_iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)

        q=Aprod(d)
        # Python scalars so that ``alpha*d`` / ``beta*d`` dispatch to
        # the __rmul__ of the operand types (see _as_python_scalar)
        alpha = _as_python_scalar(rhat_dot_r / bilinearform(d, q))
        x += alpha * d
        if isinstance(q,ArithmeticTuple):
            r += q * (-alpha) # keep the tuple on the left so its own __mul__ is used
        else:
            r += (-alpha) * q
        rhat=Msolve(r)
        rhat_dot_r_new = bilinearform(rhat, r)
        beta = _as_python_scalar(rhat_dot_r_new / rhat_dot_r)
        # new search direction d = rhat + beta*d (rhat is updated in place)
        rhat+=beta * d
        d=rhat

        rhat_dot_r = rhat_dot_r_new
        if rhat_dot_r<0: raise NegativeNorm("negative norm.")
        if verbose: print("PCG: iteration step %s: residual norm = %e"%(n_iter, math.sqrt(rhat_dot_r)))
    if verbose: print("PCG: tolerance reached after %s steps."%n_iter)
    return x,r,math.sqrt(rhat_dot_r)
603 ksteube 1312
class Defect(object):
    """
    Defines a non-linear defect F(x) of a variable x.

    Subclasses are expected to overwrite `eval` and `bilinearform`; the
    default implementations simply return 0.
    """
    def __init__(self):
        """
        Initializes defect with the default relative increment length used
        for the numerical derivative.
        """
        self.setDerivativeIncrementLength()

    def bilinearform(self, x0, x1):
        """
        Returns the inner product of x0 and x1

        :param x0: value for x0
        :param x1: value for x1
        :return: the inner product of x0 and x1
        :rtype: ``float``
        :note: this default implementation returns 0.
        """
        return 0

    def norm(self,x):
        """
        Returns the norm of argument ``x``.

        :param x: a value
        :return: norm of argument x
        :rtype: ``float``
        :raise NegativeNorm: if ``self.bilinearform(x,x)`` is negative
        :note: by default ``sqrt(self.bilinearform(x,x))`` is returned.
        """
        s=self.bilinearform(x,x)
        if s<0: raise NegativeNorm("negative norm.")
        return math.sqrt(s)

    def eval(self,x):
        """
        Returns the value F of a given ``x``.

        :param x: value for which the defect ``F`` is evaluated
        :return: value of the defect at ``x``
        :note: this default implementation returns 0.
        """
        return 0

    def __call__(self,x):
        """
        Returns the value F of a given ``x``; same as `eval`.
        """
        return self.eval(x)

    def setDerivativeIncrementLength(self,inc=1000.*math.sqrt(util.EPSILON)):
        """
        Sets the relative length of the increment used to approximate the
        derivative of the defect. The increment is inc*norm(x)/norm(v)*v in the
        direction of v with x as a starting point (inc/norm(v)*v if norm(x)
        is zero).

        :param inc: relative increment length
        :type inc: positive ``float``
        :raise ValueError: if ``inc`` is not positive
        """
        if inc<=0: raise ValueError("positive increment required.")
        self.__inc=inc

    def getDerivativeIncrementLength(self):
        """
        Returns the relative increment length used to approximate the
        derivative of the defect.

        :return: relative increment length
        :rtype: positive ``float``
        """
        return self.__inc

    def derivative(self, F0, x0, v, v_is_normalised=True):
        """
        Returns the directional derivative at ``x0`` in the direction of ``v``.

        :param F0: value of this defect at x0
        :param x0: value at which derivative is calculated
        :param v: direction
        :param v_is_normalised: True to indicate that ``v`` is normalized
                                (self.norm(v)=1)
        :return: derivative of this defect at x0 in the direction of ``v``;
                 ``F0*0`` if ``v`` is not normalised and has non-positive norm
        :note: by default numerical evaluation (self.eval(x0+eps*v)-F0)/eps is
               used but this method maybe overwritten to use exact evaluation.
        """
        normx=self.norm(x0)
        # scale the increment with |x0| (absolute increment if x0 is zero)
        if normx>0:
            epsnew = self.getDerivativeIncrementLength() * normx
        else:
            epsnew = self.getDerivativeIncrementLength()
        if not v_is_normalised:
            normv=self.norm(v)
            if normv<=0:
                return F0*0
            else:
                epsnew /= normv
        F1=self.eval(x0 + epsnew * v)
        return (F1-F0)/epsnew
697    
698 caltinay 2169 ######################################
def NewtonGMRES(defect, x, iter_max=100, sub_iter_max=20, atol=0,rtol=1.e-4, subtol_max=0.5, gamma=0.9, verbose=False):
    """
    Solves a non-linear problem *F(x)=0* for unknown *x* using the stopping
    criterion:

    *norm(F(x)) <= atol + rtol * norm(F(x0))*

    where *x0* is the initial guess.

    :param defect: object defining the function *F*. ``defect.norm`` defines the
                   *norm* used in the stopping criterion.
    :type defect: `Defect`
    :param x: initial guess for the solution, ``x`` is altered.
    :type x: any object type allowing basic operations such as
             ``numpy.ndarray``, `Data`
    :param iter_max: maximum number of iteration steps
    :type iter_max: positive ``int``
    :param sub_iter_max: maximum number of inner iteration steps
    :type sub_iter_max: positive ``int``
    :param atol: absolute tolerance for the solution
    :type atol: positive ``float``
    :param rtol: relative tolerance for the solution
    :type rtol: positive ``float``
    :param gamma: tolerance safety factor for inner iteration
    :type gamma: positive ``float``, less than 1
    :param subtol_max: upper bound for inner tolerance
    :type subtol_max: positive ``float``, less than 1
    :param verbose: if True, progress information is printed
    :return: an approximation of the solution with the desired accuracy
    :rtype: same type as the initial guess
    :raise ValueError: if a tolerance or safety parameter is out of range
    :raise MaxIterReached: if the tolerance is not reached in ``iter_max`` steps
    """
    if atol<0: raise ValueError("atol needs to be non-negative.")
    if rtol<0: raise ValueError("rtol needs to be non-negative.")
    if rtol+atol<=0: raise ValueError("rtol or atol needs to be positive.")
    if gamma<=0 or gamma>=1: raise ValueError("tolerance safety factor for inner iteration (gamma =%s) needs to be positive and less than 1."%gamma)
    if subtol_max<=0 or subtol_max>=1: raise ValueError("upper bound for inner tolerance for inner iteration (subtol_max =%s) needs to be positive and less than 1."%subtol_max)

    F=defect(x)
    fnrm=defect.norm(F)
    stop_tol=atol + rtol*fnrm
    subtol=subtol_max
    if verbose: print("NewtonGMRES: initial residual = %e."%fnrm)
    if verbose: print(" tolerance = %e."%subtol)
    n_iter=1
    #
    # main iteration loop
    #
    while not fnrm<=stop_tol:
        if n_iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)
        #
        # adjust subtol (inner tolerance) from the residual reduction of
        # the previous step
        #
        if n_iter > 1:
            rat=fnrm/fnrmo
            subtol_old=subtol
            subtol=gamma*rat**2
            if gamma*subtol_old**2 > .1: subtol=max(subtol,gamma*subtol_old**2)
            subtol=max(min(subtol,subtol_max), .5*stop_tol/fnrm)
        #
        # calculate newton increment xc
        #     if iter_max in __FDGMRES is reached MaxIterReached is thrown
        #     if iter_restart -1 is returned as sub_iter
        #     if atol is reached sub_iter returns the number of steps performed to get there
        #
        if verbose: print(" subiteration (GMRES) is called with relative tolerance %e."%subtol)
        try:
            xc, sub_iter=__FDGMRES(F, defect, x, subtol*fnrm, iter_max=iter_max-n_iter, iter_restart=sub_iter_max)
        except MaxIterReached:
            raise MaxIterReached("maximum number of %s steps reached."%iter_max)
        if sub_iter<0:
            n_iter+=sub_iter_max
        else:
            n_iter+=sub_iter
        # ====
        x+=xc
        F=defect(x)
        n_iter+=1
        fnrmo, fnrm=fnrm, defect.norm(F)
        if verbose: print(" step %s: residual %e."%(n_iter,fnrm))
    if verbose: print("NewtonGMRES: completed after %s steps."%n_iter)
    return x
781    
def __givapp(c,s,vin):
    """
    Applies a sequence of Givens rotations (c,s) to the vector ``vin``.

    If ``c`` is a single ``float`` one rotation is applied to the first two
    entries of ``vin`` and a new two-element list is returned. Otherwise
    ``c`` and ``s`` are sequences, and rotation ``k`` is applied to entries
    ``k`` and ``k+1`` of ``vin``, one rotation after the other.

    :warning: in the sequence case ``vin`` is altered in place and returned.
    """
    if isinstance(c,float):
        first, second = vin[0], vin[1]
        return [c*first-s*second, s*first+c*second]
    for k in range(len(c)):
        first, second = vin[k], vin[k+1]
        vin[k], vin[k+1] = c[k]*first-s[k]*second, s[k]*first+c[k]*second
    return vin
799    
def __FDGMRES(F0, defect, x0, atol, iter_max=100, iter_restart=20):
    """
    Matrix-free GMRES iteration used by the Newton-GMRES solver to compute
    the increment ``xc`` for the current iterate ``x0``.

    The Krylov space is started from ``-F0`` and every operator application
    is delegated to ``defect.derivative(F0, x0, v, v_is_normalised=True)``.

    :param F0: defect at ``x0``
    :param defect: object providing ``norm(F)``, ``bilinearform(a,b)`` and
                   ``derivative(F0, x0, v, v_is_normalised=True)``
    :param x0: current iterate
    :param atol: absolute tolerance for the estimated residual norm
    :type atol: ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :param iter_restart: the iteration stops after ``iter_restart-1`` steps
    :type iter_restart: ``int``
    :return: the increment and the number of steps performed, where the
             number of steps is ``-1`` if the step limit ``iter_restart-1``
             was reached
    :rtype: ``tuple``
    :raise MaxIterReached: if ``iter_max`` steps are exceeded
    """
    h=numpy.zeros((iter_restart,iter_restart),numpy.float64)
    c=numpy.zeros(iter_restart,numpy.float64)
    s=numpy.zeros(iter_restart,numpy.float64)
    g=numpy.zeros(iter_restart,numpy.float64)
    v=[]

    rho=defect.norm(F0)
    # The caller unpacks ``(xc, steps)``, so the trivial case must return a
    # pair as well (zero increment, zero steps).
    if rho<=0.: return x0*0, 0

    v.append(-F0/rho)
    g[0]=rho
    iter=0
    while rho > atol and iter<iter_restart-1:
        if iter >= iter_max:
            raise MaxIterReached("maximum number of %s steps reached."%iter_max)

        p=defect.derivative(F0,x0,v[iter], v_is_normalised=True)
        v.append(p)

        v_norm1=defect.norm(v[iter+1])

        # Modified Gram-Schmidt
        for j in range(iter+1):
            h[j,iter]=defect.bilinearform(v[j],v[iter+1])
            v[iter+1]-=h[j,iter]*v[j]

        h[iter+1,iter]=defect.norm(v[iter+1])
        v_norm2=h[iter+1,iter]

        # Reorthogonalize if needed
        if v_norm1 + 0.001*v_norm2 == v_norm1: #Brown/Hindmarsh condition (default)
            for j in range(iter+1):
                hr=defect.bilinearform(v[j],v[iter+1])
                h[j,iter]=h[j,iter]+hr
                v[iter+1] -= hr*v[j]

            v_norm2=defect.norm(v[iter+1])
            h[iter+1,iter]=v_norm2

        # watch out for happy breakdown
        if not v_norm2 == 0:
            v[iter+1]=v[iter+1]/h[iter+1,iter]

        # Form and store the information for the new Givens rotation
        # (previous rotations are applied in place to column ``iter``)
        if iter > 0:
            h[:iter+1,iter]=__givapp(c[:iter],s[:iter],h[:iter+1,iter])

        mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])

        if mu!=0 :
            c[iter]=h[iter,iter]/mu
            s[iter]=-h[iter+1,iter]/mu
            h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
            h[iter+1,iter]=0.0
            gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
            g[iter]=gg[0]
            g[iter+1]=gg[1]

        # Update the residual norm
        rho=abs(g[iter+1])
        iter+=1

    # At this point either iter > iter_max or rho < tol.
    # It's time to compute x and leave: back substitution with the upper
    # triangular part of h.
    if iter > 0 :
        y=numpy.zeros(iter,numpy.float64)
        y[iter-1] = g[iter-1] / h[iter-1,iter-1]
        if iter > 1 :
            i=iter-2
            while i>=0 :
                y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
                i=i-1
        xhat=v[iter-1]*y[iter-1]
        for i in range(iter-1):
            xhat += v[i]*y[i]
    else :
        xhat=v[0] * 0

    if iter<iter_restart-1:
        stopped=iter
    else:
        stopped=-1

    return xhat,stopped
887 artak 1481
def GMRES(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100, iter_restart=20, verbose=False,P_R=None):
    """
    Solver for

    *Ax=b*

    with a general operator A (more details required!).
    It uses the restarted generalized minimum residual method (GMRES).

    The iteration is terminated if

    *|r| <= atol+rtol*|r0|*

    where *r0* is the initial residual and *|.|* is the energy norm. In fact

    *|r| = sqrt( bilinearform(r,r))*

    :param r: initial residual *r=b-Ax*. ``r`` is copied, not altered here.
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: same like ``r``
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same
                 type like argument ``r``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r``. The returned value is
                        a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance
    :type rtol: non-negative ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :param iter_restart: in order to save memory the orthogonalization process
                         is terminated after ``iter_restart`` steps and the
                         iteration is restarted.
    :type iter_restart: ``int``
    :param verbose: if ``True`` progress information is printed
    :type verbose: ``bool``
    :param P_R: optional right preconditioner. If given, ``Aprod`` is applied
                to ``P_R(v)`` and the solution is updated by ``P_R`` applied
                to the Krylov combination.
    :type P_R: function ``P_R(v)`` or ``None``
    :return: the solution approximation
    :rtype: same like ``x``
    :raise ValueError: if the resulting absolute tolerance is not positive
    :raise MaxIterReached: if ``iter_max`` steps are exceeded
    :raise NegativeNorm: if ``bilinearform(r,r)`` is negative
    :warning: ``x`` is updated in place.
    """
    m=iter_restart
    restarted=False
    iter=0
    if rtol>0:
        r_dot_r = bilinearform(r, r)
        if r_dot_r<0: raise NegativeNorm("negative norm.")
        atol2=atol+rtol*math.sqrt(r_dot_r)
        if verbose: print(("GMRES: norm of right hand side = %e (absolute tolerance = %e)"%(math.sqrt(r_dot_r), atol2)))
    else:
        atol2=atol
        if verbose: print(("GMRES: absolute tolerance = %e"%atol2))
    if atol2<=0:
        raise ValueError("Non-positive tolarance.")

    while True:
        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached"%iter_max)
        if restarted:
            # r is the residual at x2, hence b-Ax = r-A(x-x2)
            r2 = r-Aprod(x-x2)
        else:
            r2=1*r
            x2=x*1.
        x,stopped=_GMRESm(r2, Aprod, x, bilinearform, atol2, iter_max=iter_max-iter, iter_restart=m, verbose=verbose,P_R=P_R)
        iter+=iter_restart
        if stopped: break
        if verbose: print("GMRES: restart.")
        restarted=True
    if verbose: print("GMRES: tolerance has been reached.")
    return x
959 artak 1550
def _GMRESm(r, Aprod, x, bilinearform, atol, iter_max=100, iter_restart=20, verbose=False, P_R=None):
    """
    Performs one cycle of at most ``iter_restart`` steps of the restarted
    GMRES method (see `GMRES`).

    :param r: residual at ``x``
    :param Aprod: function returning the operator applied to its argument
    :param x: current approximation; updated in place and returned
    :param bilinearform: inner product ``bilinearform(a,b)`` returning a ``float``
    :param atol: absolute tolerance for the estimated residual norm
    :param iter_max: maximum number of iteration steps
    :param iter_restart: maximum number of steps in this cycle
    :param verbose: if ``True`` progress information is printed
    :param P_R: optional right preconditioner ``P_R(v)`` or ``None``
    :return: the updated ``x`` and ``True`` if the estimated residual norm
             has dropped to ``atol`` or below, ``False`` if the cycle ended
             because ``iter_restart`` steps were used
    :rtype: ``tuple``
    :raise MaxIterReached: if ``iter_max`` steps are exceeded
    :raise NegativeNorm: if ``bilinearform(r,r)`` is negative
    """
    iter=0

    h=numpy.zeros((iter_restart+1,iter_restart),numpy.float64)
    c=numpy.zeros(iter_restart,numpy.float64)
    s=numpy.zeros(iter_restart,numpy.float64)
    g=numpy.zeros(iter_restart+1,numpy.float64)
    v=[]

    r_dot_r = bilinearform(r, r)
    if r_dot_r<0: raise NegativeNorm("negative norm.")
    rho=math.sqrt(r_dot_r)

    if verbose: print(("GMRES: initial residual %e (absolute tolerance = %e)"%(rho,atol)))
    if rho<=atol:
        # Already converged. Returning here also avoids r/rho with rho==0,
        # which would turn x into NaN.
        if verbose: print(("GMRES: iteration stopped after %s step."%iter))
        return x, True

    v.append(r/rho)
    g[0]=rho

    while not (rho<=atol or iter==iter_restart):

        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)

        if P_R is not None:
            p=Aprod(P_R(v[iter]))
        else:
            p=Aprod(v[iter])
        v.append(p)

        v_norm1=math.sqrt(bilinearform(v[iter+1], v[iter+1]))

        # Modified Gram-Schmidt
        for j in range(iter+1):
            h[j,iter]=bilinearform(v[j],v[iter+1])
            v[iter+1]-=h[j,iter]*v[j]

        h[iter+1,iter]=math.sqrt(bilinearform(v[iter+1],v[iter+1]))
        v_norm2=h[iter+1,iter]

        # Reorthogonalize if needed
        if v_norm1 + 0.001*v_norm2 == v_norm1: #Brown/Hindmarsh condition (default)
            for j in range(iter+1):
                hr=bilinearform(v[j],v[iter+1])
                h[j,iter]=h[j,iter]+hr
                v[iter+1] -= hr*v[j]

            v_norm2=math.sqrt(bilinearform(v[iter+1], v[iter+1]))
            h[iter+1,iter]=v_norm2

        # watch out for happy breakdown
        if not v_norm2 == 0:
            v[iter+1]=v[iter+1]/h[iter+1,iter]

        # Form and store the information for the new Givens rotation
        if iter > 0: h[:iter+1,iter]=__givapp(c[:iter],s[:iter],h[:iter+1,iter])
        mu=math.sqrt(h[iter,iter]*h[iter,iter]+h[iter+1,iter]*h[iter+1,iter])

        if mu!=0 :
            c[iter]=h[iter,iter]/mu
            s[iter]=-h[iter+1,iter]/mu
            h[iter,iter]=c[iter]*h[iter,iter]-s[iter]*h[iter+1,iter]
            h[iter+1,iter]=0.0
            gg=__givapp(c[iter],s[iter],[g[iter],g[iter+1]])
            g[iter]=gg[0]
            g[iter+1]=gg[1]

        # Update the residual norm
        rho=abs(g[iter+1])
        if verbose: print(("GMRES: iteration step %s: residual %e"%(iter,rho)))
        iter+=1

    # At this point either iter > iter_max or rho < tol.
    # It's time to compute x and leave.

    if verbose: print(("GMRES: iteration stopped after %s step."%iter))
    if iter > 0 :
        y=numpy.zeros(iter,numpy.float64)
        y[iter-1] = g[iter-1] / h[iter-1,iter-1]
        if iter > 1 :
            i=iter-2
            while i>=0 :
                y[i] = ( g[i] - numpy.dot(h[i,i+1:iter], y[i+1:iter])) / h[i,i]
                i=i-1
        xhat=v[iter-1]*y[iter-1]
        for i in range(iter-1):
            xhat += v[i]*y[i]
    else:
        xhat=v[0] * 0
    if P_R is not None:
        x += P_R(xhat)
    else:
        x += xhat
    # Report convergence from the actual criterion. The former test
    # ``iter<iter_restart-1`` never holds for iter_restart==1, which made
    # GMRES restart until MaxIterReached.
    stopped=bool(rho<=atol)

    return x,stopped
1056    
def MINRES(r, Aprod, x, Msolve, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
    """
    Solver for

    *Ax=b*

    with a symmetric operator A (more details required!).
    It uses the minimum residual method (MINRES) with preconditioner M
    providing an approximation of A.

    The iteration is terminated if

    *rnorm <= atol+rtol*Anorm*ynorm*

    where *rnorm*, *Anorm* and *ynorm* are estimates of the (preconditioned)
    residual norm, the norm of A and the norm of the solution increment that
    are accumulated during the iteration. The initial residual norm is

    *|r0| = sqrt( bilinearform(Msolve(r0),r0))*

    For details on the minimum residual method see the book:

    I{Templates for the Solution of Linear Systems by R. Barrett, M. Berry,
    T.F. Chan, J. Demmel, J. Donato, J. Dongarra, V. Eijkhout, R. Pozo,
    C. Romine, and H. van der Vorst}.

    :param r: initial residual *r=b-Ax*.
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same
                 type like argument ``r``.
    :param Msolve: solves Mx=r
    :type Msolve: function ``Msolve(r)`` where ``r`` is of the same type like
                  argument ``r``. The returned object needs to be of the same
                  type like argument ``x``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r`` is. The returned value
                        is a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance
    :type rtol: non-negative ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :return: the solution approximation
    :rtype: same like ``x``
    :raise NegativeNorm: if ``bilinearform(Msolve(v),v)`` is negative
    :raise MaxIterReached: if ``iter_max`` steps are exceeded
    :warning: ``x`` is updated in place (``x += ...``).
    """
    #------------------------------------------------------------------
    # Set up y and v for the first Lanczos vector v1.
    # y = beta1 P' v1, where P = C**(-1).
    # v is really P' v1.
    #------------------------------------------------------------------
    r1 = r
    y = Msolve(r)
    beta1 = bilinearform(y,r)

    if beta1< 0: raise NegativeNorm("negative norm.")

    # If r = 0 exactly, stop with x
    if beta1==0: return x

    beta1 = math.sqrt(beta1)

    #------------------------------------------------------------------
    # Initialize quantities.
    # ------------------------------------------------------------------
    iter = 0
    Anorm = 0
    ynorm = 0
    oldb = 0
    beta = beta1
    dbar = 0
    epsln = 0
    phibar = beta1
    rhs1 = beta1
    rhs2 = 0
    rnorm = phibar
    tnorm2 = 0
    ynorm2 = 0
    cs = -1
    sn = 0
    w = r*0.
    w2 = r*0.
    r2 = r1
    # Lower bound for gamma that only guards against division by zero. It
    # must be machine precision: a fixed value like 1e-4 distorts the
    # solution of any problem whose operator is scaled below that level.
    eps = numpy.finfo(numpy.float64).eps

    #---------------------------------------------------------------------
    # Main iteration loop.
    # --------------------------------------------------------------------
    while not rnorm<=atol+rtol*Anorm*ynorm: # checks ||r|| < (||A|| ||x||) * TOL

        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)
        iter = iter + 1

        #-----------------------------------------------------------------
        # Obtain quantities for the next Lanczos vector vk+1, k = 1, 2,...
        # The general iteration is similar to the case k = 1 with v0 = 0:
        #
        #   p1      = Operator * v1  -  beta1 * v0,
        #   alpha1  = v1'p1,
        #   q2      = p2  -  alpha1 * v1,
        #   beta2^2 = q2'q2,
        #   v2      = (1/beta2) q2.
        #
        # Again, y = betak P vk,  where  P = C**(-1).
        #-----------------------------------------------------------------
        s = 1/beta                 # Normalize previous vector (in y).
        v = s*y                    # v = vk if P = I

        y = Aprod(v)

        if iter >= 2:
            y = y - (beta/oldb)*r1

        alfa = bilinearform(v,y)              # alphak
        y += (- alfa/beta)*r2
        r1 = r2
        r2 = y
        y = Msolve(r2)
        oldb = beta                           # oldb = betak
        beta = bilinearform(y,r2)             # beta = betak+1^2
        if beta < 0: raise NegativeNorm("negative norm.")

        beta = math.sqrt( beta )
        tnorm2 = tnorm2 + alfa*alfa + oldb*oldb + beta*beta

        # Apply previous rotation Qk-1 to get
        #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
        #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].

        oldeps = epsln
        delta = cs * dbar + sn * alfa # delta1 = 0         deltak
        gbar = sn * dbar - cs * alfa  # gbar 1 = alfa1     gbar k
        epsln = sn * beta             # epsln2 = 0         epslnk+1
        dbar = - cs * beta            # dbar 2 = beta2     dbar k+1

        # Compute the next plane rotation Qk

        gamma = math.sqrt(gbar*gbar+beta*beta)  # gammak
        gamma = max(gamma,eps)
        cs = gbar / gamma             # ck
        sn = beta / gamma             # sk
        phi = cs * phibar             # phik
        phibar = sn * phibar          # phibark+1

        # Update x.

        denom = 1/gamma
        w1 = w2
        w2 = w
        w = (v - oldeps*w1 - delta*w2) * denom
        x += phi*w

        # Go round again.

        z = rhs1 / gamma
        ynorm2 = z*z + ynorm2
        rhs1 = rhs2 - delta*z
        rhs2 = - epsln*z

        # Estimate various norms and test for convergence.

        Anorm = math.sqrt( tnorm2 )
        ynorm = math.sqrt( ynorm2 )

        rnorm = phibar

    return x
1234 artak 1489
def TFQMR(r, Aprod, x, bilinearform, atol=0, rtol=1.e-8, iter_max=100):
    """
    Solver for

    *Ax=b*

    with a general operator A (more details required!).
    It uses the Transpose-Free Quasi-Minimal Residual method (TFQMR).

    The iteration is terminated if the quasi-residual estimate *tau* satisfies

    *tau <= atol+rtol*|r0|*

    where *r0* is the initial residual and *|.|* is the energy norm. In fact

    *|r0| = sqrt( bilinearform(r0,r0))*

    The test is performed after every half step of the method.

    :param r: initial residual *r=b-Ax*.
    :type r: any object supporting inplace add (x+=y) and scaling (x=scalar*y)
    :param x: an initial guess for the solution
    :type x: same like ``r``
    :param Aprod: returns the value Ax
    :type Aprod: function ``Aprod(x)`` where ``x`` is of the same object like
                 argument ``x``. The returned object needs to be of the same type
                 like argument ``r``.
    :param bilinearform: inner product ``<x,r>``
    :type bilinearform: function ``bilinearform(x,r)`` where ``x`` is of the same
                        type like argument ``x`` and ``r``. The returned value is
                        a ``float``.
    :param atol: absolute tolerance
    :type atol: non-negative ``float``
    :param rtol: relative tolerance
    :type rtol: non-negative ``float``
    :param iter_max: maximum number of iteration steps
    :type iter_max: ``int``
    :return: the solution approximation
    :rtype: same like ``x``
    :raise NegativeNorm: if ``bilinearform(r,r)`` is negative
    :raise MaxIterReached: if ``iter_max`` steps are exceeded
    :raise IterationBreakDown: if the method breaks down
    :note: the updates are not done in place (``x = x + ...``), so the
           caller's ``x`` object is not modified.
    """
    u2=0
    y2=0

    w = r
    y1 = r
    iter = 0
    d = 0
    v = Aprod(y1)
    u1 = v

    theta = 0.0
    eta = 0.0
    rho=bilinearform(r,r)
    if rho < 0: raise NegativeNorm("negative norm.")
    tau = math.sqrt(rho)
    norm_r0=tau
    tol=atol+rtol*norm_r0
    while tau>tol:
        if iter >= iter_max: raise MaxIterReached("maximum number of %s steps reached."%iter_max)

        sigma = bilinearform(r,v)
        if sigma == 0.0: raise IterationBreakDown('TFQMR breakdown, sigma=0')

        alpha = rho / sigma

        for j in range(2):
            #
            #   Compute y2 and u2 only if you have to
            #
            if ( j == 1 ):
                y2 = y1 - alpha * v
                u2 = Aprod(y2)

            if j==0:
                w = w - alpha * u1
                d = y1 + ( theta * theta * eta / alpha ) * d
            else:
                w = w - alpha * u2
                d = y2 + ( theta * theta * eta / alpha ) * d

            theta = math.sqrt(bilinearform(w,w))/ tau
            c = 1.0 / math.sqrt ( 1.0 + theta * theta )
            tau = tau * theta * c
            eta = c * c * alpha
            x = x + eta * d
            #
            #  Try to terminate the iteration at each pass through the loop.
            #  Without this check an exact solution reached after the first
            #  half step makes tau==0 and the next half step divides by zero.
            #
            if tau <= tol: return x

        if rho == 0.0: raise IterationBreakDown('TFQMR breakdown, rho=0')

        rhon = bilinearform(r,w)
        beta = rhon / rho
        rho = rhon
        y1 = w + beta * y2
        u1 = Aprod(y1)
        v = u1 + beta * ( u2 + beta * v )

        iter += 1

    return x
1335    
1336    
1337 artak 1465 #############################################
1338    
class ArithmeticTuple(object):
    """
    Tuple supporting inplace update x+=y and scaling x=a*y where ``x,y`` is an
    ArithmeticTuple and ``a`` is a float.

    Binary operations accept either a sequence of the same length, which is
    applied itemwise, or a single object, which is applied to every item.
    Items that are empty `escore.Data` objects act like zero: a product with
    an empty factor is an empty `escore.Data`, and an empty summand is skipped.

    Example of usage::

        from esys.escript import Data
        from numpy import array
        a=Data(...)
        b=array([1.,4.])
        x=ArithmeticTuple(a,b)
        y=5.*x

    """
    def __init__(self,*args):
        """
        Initializes object with elements ``args``.

        :param args: tuple of objects that support inplace add (x+=y) and
                     scaling (x=a*y)
        """
        self.__items=list(args)

    def __len__(self):
        """
        Returns the number of items.

        :return: number of items
        :rtype: ``int``
        """
        return len(self.__items)

    def __getitem__(self,index):
        """
        Returns item at specified position.

        :param index: index of item to be returned
        :type index: ``int``
        :return: item with index ``index``
        """
        return self.__items.__getitem__(index)

    def __matchingLength(self, other):
        """
        Returns ``len(other)`` if ``other`` has a length, ``None`` otherwise.

        Only ``len`` is guarded, so a ``TypeError`` raised by the arithmetic
        of an item is no longer mistaken for "``other`` is a scalar" (which
        previously appended to a partially filled result).

        :raise ValueError: if ``len(other)`` differs from ``len(self)``
        """
        try:
            l=len(other)
        except TypeError:
            return None
        if l!=len(self):
            raise ValueError("length of arguments don't match.")
        return l

    def __mul__(self,other):
        """
        Scales by ``other`` from the right.

        :param other: scaling factor (or sequence of factors)
        :type other: ``float``
        :return: itemwise self*other
        :rtype: `ArithmeticTuple`
        """
        l=self.__matchingLength(other)
        out=[]
        if l is None:
            for i in range(len(self)):
                if self.__isEmpty(self[i]) or self.__isEmpty(other):
                    out.append(escore.Data())
                else:
                    out.append(self[i]*other)
        else:
            for i in range(l):
                if self.__isEmpty(self[i]) or self.__isEmpty(other[i]):
                    out.append(escore.Data())
                else:
                    out.append(self[i]*other[i])
        return ArithmeticTuple(*tuple(out))

    def __rmul__(self,other):
        """
        Scales by ``other`` from the left.

        :param other: scaling factor (or sequence of factors)
        :type other: ``float``
        :return: itemwise other*self
        :rtype: `ArithmeticTuple`
        """
        l=self.__matchingLength(other)
        out=[]
        if l is None:
            for i in range(len(self)):
                if self.__isEmpty(self[i]) or self.__isEmpty(other):
                    out.append(escore.Data())
                else:
                    out.append(other*self[i])
        else:
            for i in range(l):
                if self.__isEmpty(self[i]) or self.__isEmpty(other[i]):
                    out.append(escore.Data())
                else:
                    out.append(other[i]*self[i])
        return ArithmeticTuple(*tuple(out))

    def __div__(self,other):
        """
        Scales by (1/``other``) from the right.

        :param other: scaling factor
        :type other: ``float``
        :return: itemwise self/other
        :rtype: `ArithmeticTuple`
        """
        # 1. (not 1) avoids integer division for integer ``other`` on Python 2
        return self*(1./other)

    def __rdiv__(self,other):
        """
        Scales by (1/``other``) from the left.

        :param other: scaling factor (or sequence of factors)
        :type other: ``float``
        :return: itemwise other/self
        :rtype: `ArithmeticTuple`
        :raise ZeroDivisionError: if an item of self is an empty `escore.Data`
        """
        l=self.__matchingLength(other)
        out=[]
        if l is None:
            for i in range(len(self)):
                if self.__isEmpty(self[i]):
                    raise ZeroDivisionError("in component %s"%i)
                elif self.__isEmpty(other):
                    out.append(escore.Data())
                else:
                    out.append(other/self[i])
        else:
            for i in range(l):
                if self.__isEmpty(self[i]):
                    raise ZeroDivisionError("in component %s"%i)
                elif self.__isEmpty(other[i]):
                    out.append(escore.Data())
                else:
                    out.append(other[i]/self[i])
        return ArithmeticTuple(*tuple(out))

    # Python 3 dispatches ``/`` to __truediv__/__rtruediv__ only.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __iadd__(self,other):
        """
        Inplace addition of ``other`` to self.

        :param other: increment
        :type other: ``ArithmeticTuple``
        :raise ValueError: if the lengths do not match
        """
        if len(self) != len(other):
            raise ValueError("tuple lengths must match.")
        for i in range(len(self)):
            if self.__isEmpty(self.__items[i]):
                self.__items[i]=other[i]
            else:
                self.__items[i]+=other[i]

        return self

    def __add__(self,other):
        """
        Adds ``other`` to self.

        :param other: increment
        :type other: ``ArithmeticTuple``
        """
        l=self.__matchingLength(other)
        out=[]
        if l is None:
            for i in range(len(self)):
                if self.__isEmpty(self[i]):
                    out.append(other)
                elif self.__isEmpty(other):
                    out.append(self[i])
                else:
                    out.append(self[i]+other)
        else:
            for i in range(l):
                if self.__isEmpty(self[i]):
                    out.append(other[i])
                elif self.__isEmpty(other[i]):
                    out.append(self[i])
                else:
                    out.append(self[i]+other[i])
        return ArithmeticTuple(*tuple(out))

    def __sub__(self,other):
        """
        Subtracts ``other`` from self.

        :param other: decrement
        :type other: ``ArithmeticTuple``
        """
        l=self.__matchingLength(other)
        out=[]
        if l is None:
            for i in range(len(self)):
                if self.__isEmpty(other):
                    out.append(self[i])
                elif self.__isEmpty(self[i]):
                    out.append(-other)
                else:
                    out.append(self[i]-other)
        else:
            for i in range(l):
                if self.__isEmpty(other[i]):
                    out.append(self[i])
                elif self.__isEmpty(self[i]):
                    out.append(-other[i])
                else:
                    out.append(self[i]-other[i])
        return ArithmeticTuple(*tuple(out))

    def __isub__(self,other):
        """
        Inplace subtraction of ``other`` from self.

        :param other: decrement
        :type other: ``ArithmeticTuple``
        :raise ValueError: if the lengths do not match
        """
        if len(self) != len(other):
            raise ValueError("tuple length must match.")
        for i in range(len(self)):
            if not self.__isEmpty(other[i]):
                if self.__isEmpty(self.__items[i]):
                    self.__items[i]=-other[i]
                else:
                    # previously ``=other[i]``, which overwrote instead of subtracting
                    self.__items[i]-=other[i]
        return self

    def __neg__(self):
        """
        Negates values.
        """
        out=[]
        for i in range(len(self)):
            if self.__isEmpty(self[i]):
                out.append(escore.Data())
            else:
                out.append(-self[i])

        return ArithmeticTuple(*tuple(out))

    def __isEmpty(self, d):
        """
        Returns ``True`` if ``d`` is an empty `escore.Data` object.
        """
        if isinstance(d, escore.Data):
            return d.isEmpty()
        else:
            return False

    def __str__(self):
        """
        Returns a string of the form ``(item0, item1, )``.
        """
        s="("
        for i in self:
            s=s+str(i)+", "
        s=s+")"
        return s
1598 artak 1550
1599 gross 1414 class HomogeneousSaddlePointProblem(object):
1600     """
1601 caltinay 2169 This class provides a framework for solving linear homogeneous saddle
1602     point problems of the form::
1603 gross 1414
1604 jfenwick 2625 *Av+B^*p=f*
1605     *Bv =0*
1606 gross 1414
1607 jfenwick 2625 for the unknowns *v* and *p* and given operators *A* and *B* and
1608     given right hand side *f*. *B^** is the adjoint operator of *B*.
1609 gross 2719 *A* may depend weakly on *v* and *p*.
1610 gross 1414 """
1611 gross 2719 def __init__(self, **kwargs):
1612 jfenwick 3852 """
1613     initializes the saddle point problem
1614     """
1615 gross 2719 self.resetControlParameters()
1616 gross 1414 self.setTolerance()
1617 gross 2100 self.setAbsoluteTolerance()
1618 gross 2862 def resetControlParameters(self, K_p=1., K_v=1., rtol_max=0.01, rtol_min = 1.e-7, chi_max=0.5, reduction_factor=0.3, theta = 0.1):
1619 gross 2719 """
1620     sets a control parameter
1621    
1622 gross 2843 :param K_p: initial value for constant to adjust pressure tolerance
1623     :type K_p: ``float``
1624     :param K_v: initial value for constant to adjust velocity tolerance
1625     :type K_v: ``float``
1626     :param rtol_max: maximuim relative tolerance used to calculate presssure and velocity increment.
1627     :type rtol_max: ``float``
1628 gross 2719 :param chi_max: maximum tolerable converegence rate.
1629     :type chi_max: ``float``
1630 gross 2843 :param reduction_factor: reduction factor for adjustment factors.
1631     :type reduction_factor: ``float``
1632 gross 2719 """
1633 gross 2843 self.setControlParameter(K_p, K_v, rtol_max, rtol_min, chi_max, reduction_factor, theta)
1634 gross 2719
def setControlParameter(self, K_p=None, K_v=None, rtol_max=None, rtol_min=None, chi_max=None, reduction_factor=None, theta=None):
    """
    Sets the control parameters of the adaptive tolerance scheme used by
    the saddle point solver. A parameter given as ``None`` keeps its
    current value. All values are validated before any of them is stored,
    so a ``ValueError`` leaves the current settings unchanged.

    :param K_p: initial value for constant to adjust pressure tolerance
    :type K_p: ``float`` (at least 1)
    :param K_v: initial value for constant to adjust velocity tolerance
    :type K_v: ``float`` (at least 1)
    :param rtol_max: maximum relative tolerance used to calculate pressure
                     and velocity increment.
    :type rtol_max: ``float`` (strictly between 0 and 1)
    :param rtol_min: minimum relative tolerance used to calculate pressure
                     and velocity increment.
    :type rtol_min: ``float`` (strictly between 0 and 1 and less than
                    ``rtol_max``)
    :param chi_max: maximum tolerable convergence rate.
    :type chi_max: ``float`` (strictly between 0 and 1)
    :param reduction_factor: reduction factor for adjustment factors.
    :type reduction_factor: ``float`` (greater than 0, at most 1)
    :param theta: a pressure increment is only computed in a correction step
                  if the norm of ``Bv`` exceeds ``theta`` times the norm of
                  the velocity increment.
    :type theta: ``float`` (greater than 0, at most 1)
    :raise ValueError: if a value is out of range or if the resulting
                       ``rtol_min`` is not less than ``rtol_max``.
    """
    # Resolve every parameter first: keep the current value if None,
    # otherwise validate the new one. Nothing is stored until all pass.
    if K_p is None:
        K_p = self.__K_p
    elif K_p < 1:
        raise ValueError("K_p need to be greater or equal to 1.")

    if K_v is None:
        K_v = self.__K_v
    elif K_v < 1:
        raise ValueError("K_v need to be greater or equal to 1.")

    if rtol_max is None:
        rtol_max = self.__rtol_max
    elif rtol_max <= 0 or rtol_max >= 1:
        raise ValueError("rtol_max needs to be positive and less than 1.")

    if rtol_min is None:
        rtol_min = self.__rtol_min
    elif rtol_min <= 0 or rtol_min >= 1:
        raise ValueError("rtol_min needs to be positive and less than 1.")

    if chi_max is None:
        chi_max = self.__chi_max
    elif chi_max <= 0 or chi_max >= 1:
        raise ValueError("chi_max needs to be positive and less than 1.")

    if reduction_factor is None:
        reduction_factor = self.__reduction_factor
    elif reduction_factor <= 0 or reduction_factor > 1:
        raise ValueError("reduction_factor need to be between zero and one.")

    if theta is None:
        theta = self.__theta
    elif theta <= 0 or theta > 1:
        raise ValueError("theta need to be between zero and one.")

    # the tolerance window must be non-empty
    if rtol_min >= rtol_max:
        raise ValueError("rtol_max = %e needs to be greater than rtol_min = %e"%(rtol_max,rtol_min))

    self.__chi_max = chi_max
    self.__rtol_max = rtol_max
    self.__K_p = K_p
    self.__K_v = K_v
    self.__reduction_factor = reduction_factor
    self.__theta = theta
    self.__rtol_min = rtol_min
1701 gross 2719
1702 gross 2100 #=============================================================
def inner_pBv(self, p, Bv):
    """
    Computes the inner product of a pressure increment ``p`` with a
    residual ``Bv``. Subclasses must overwrite this; it is required when
    the PCG scheme is used.

    :param p: a pressure increment
    :param Bv: a residual
    :return: inner product of element p and Bv
    :rtype: ``float``
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no inner product for p and Bv implemented."
    raise NotImplementedError(message)
1714 gross 1414
def inner_p(self, p0, p1):
    """
    Computes the inner product of two pressures. Subclasses must
    overwrite this.

    :param p0: a pressure
    :param p1: a pressure
    :return: inner product of p0 and p1
    :rtype: ``float``
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no inner product for p implemented."
    raise NotImplementedError(message)
1725 gross 2251
def norm_v(self, v):
    """
    Computes the norm of a velocity. Subclasses must overwrite this.

    :param v: a velocity
    :return: norm of v
    :rtype: non-negative ``float``
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no norm of v implemented."
    raise NotImplementedError(message)
def getDV(self, p, v, tol):
    """
    Returns a correction to the velocity for a given velocity ``v`` and a
    given pressure ``p``, computed with accuracy ``tol`` (overwrite).

    :param p: pressure
    :param v: velocity
    :param tol: accuracy to be used when computing the correction
    :type tol: ``float``
    :return: dv given as *dv= A^{-1} (f-A v-B^*p)*
    :note: Only *A* may depend on *v* and *p*
    :raise NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError("no dv calculation implemented.")
1745 gross 2445
1746 gross 2251
def Bv(self, v, tol):
    """
    Applies the operator *B* to the velocity ``v`` with accuracy ``tol``.
    Subclasses must overwrite this.

    :param v: a velocity
    :param tol: accuracy to be used
    :type tol: ``float``
    :rtype: equal to the type of p
    :note: boundary conditions on p should be zero!
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no operator B implemented."
    raise NotImplementedError(message)
1755 gross 2251
def norm_Bv(self, Bv):
    """
    Returns the norm of Bv (overwrite).

    :param Bv: a value as returned by the ``Bv`` method
    :return: norm of ``Bv``
    :rtype: non-negative ``float``
    :raise NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError("no norm of Bv implemented.")
1764 gross 2445
def solve_AinvBt(self, dp, tol):
    """
    Solves *A dv=B^*dp* for ``dv`` with accuracy ``tol``. Subclasses must
    overwrite this.

    :param dp: a pressure increment
    :param tol: accuracy to be used
    :type tol: ``float``
    :return: the solution of *A dv=B^*dp*
    :note: boundary conditions on dv should be zero! *A* is the operator
           used in ``getDV`` and must not be altered.
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no operator A implemented."
    raise NotImplementedError(message)
1774 gross 1414
def solve_prec(self, Bv, tol):
    """
    Applies a preconditioner for the Schur complement *(BA^{-1}B^*)* to
    ``Bv`` with accuracy ``tol``. Subclasses must overwrite this.

    :param Bv: a residual
    :param tol: accuracy to be used
    :type tol: ``float``
    :rtype: equal to the type of p
    :note: boundary conditions on p should be zero!
    :raise NotImplementedError: always, in this base implementation
    """
    message = "no preconditioner for Schur complement implemented."
    raise NotImplementedError(message)
1783 gross 2100 #=============================================================
def __Aprod_PCG(self, dp):
    # PCG operator: maps a pressure increment dp to the tuple (dv, B dv)
    # where dv solves A dv = B^* dp; both solves use the current sub-tolerance.
    subtol = self.__subtol
    dv = self.solve_AinvBt(dp, subtol)
    B_dv = self.Bv(dv, subtol)
    return ArithmeticTuple(dv, B_dv)
1787 gross 1414
def __inner_PCG(self, p, r):
    # PCG inner product: pairs p with the Bv component (index 1) of the
    # residual tuple r.
    Bv_part = r[1]
    return self.inner_pBv(p, Bv_part)
1790 gross 1414
def __Msolve_PCG(self, r):
    # PCG preconditioner: applies solve_prec to the Bv component (index 1)
    # of the residual tuple r using the current sub-tolerance.
    Bv_part = r[1]
    return self.solve_prec(Bv_part, self.__subtol)
1793 gross 2100 #=============================================================
def __Aprod_GMRES(self, p):
    # GMRES operator: preconditioned Schur complement
    # solve_prec(B A^{-1} B^* p), every step using the current sub-tolerance.
    subtol = self.__subtol
    dv = self.solve_AinvBt(p, subtol)
    B_dv = self.Bv(dv, subtol)
    return self.solve_prec(B_dv, subtol)
def __inner_GMRES(self, p0, p1):
    # GMRES works in pressure space, so it uses the pressure inner product.
    result = self.inner_p(p0, p1)
    return result
1798 gross 2474
1799 gross 2100 #=============================================================
def norm_p(self, p):
    """
    Computes the norm of ``p`` induced by the pressure inner product
    ``inner_p``.

    :param p: a pressure
    :return: square root of ``inner_p(p,p)``
    :rtype: ``float``
    :raise ValueError: if ``inner_p(p,p)`` is negative
    """
    squared_norm = self.inner_p(p, p)
    if squared_norm < 0:
        raise ValueError("negative pressure norm.")
    return math.sqrt(squared_norm)
1811 jfenwick 4237
def solve(self, v, p, max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
    """
    Solves the saddle point problem using initial guesses v and p.

    :param v: initial guess for velocity
    :param p: initial guess for pressure
    :type v: `Data`
    :type p: `Data`
    :param usePCG: indicates the usage of the PCG rather than GMRES scheme.
    :param max_iter: maximum number of iteration steps per correction
                     attempt
    :param verbose: if True, shows information on the progress of the
                    saddlepoint problem solver.
    :param iter_restart: restart the iteration after ``iter_restart`` steps
                         (only used if usePCG=False)
    :param max_correction_steps: maximum number of correction steps before
                                 giving up
    :type usePCG: ``bool``
    :type max_iter: ``int``
    :type verbose: ``bool``
    :type iter_restart: ``int``
    :type max_correction_steps: ``int``
    :rtype: ``tuple`` of `Data` objects
    :note: typically this method is overwritten by a subclass. It provides
           a wrapper for the ``_solve`` method.
    """
    return self._solve(v=v, p=p, max_iter=max_iter, verbose=verbose, usePCG=usePCG, iter_restart=iter_restart, max_correction_steps=max_correction_steps)
1835    
def _solve(self, v, p, max_iter=20, verbose=False, usePCG=True, iter_restart=20, max_correction_steps=10):
    """
    Implementation of the saddle point solver; see the `solve` method for
    the meaning of the arguments.

    Each correction step computes a velocity increment via ``getDV``. If
    the norm of ``Bv`` exceeds ``theta`` times the norm of that increment,
    a pressure increment is computed from *B A^{-1} B^* dp = Bv* by PCG
    (or GMRES if ``usePCG`` is False). The tolerances of these inner solves
    are adapted from the observed convergence rate ``chi`` and the
    adjustment factors ``K_p`` (pressure) and ``K_v`` (velocity).

    :return: velocity and pressure
    :rtype: ``tuple`` of `Data` objects
    :raise CorrectionFailed: if the tolerance is not reached within
                             ``max_correction_steps`` correction steps
    """
    self.verbose=verbose
    rtol=self.getTolerance()
    atol=self.getAbsoluteTolerance()

    K_p=self.__K_p
    K_v=self.__K_v
    correction_step=0
    converged=False
    chi=None   # estimated convergence rate (None until two corrections exist)
    eps=None   # size of the most recent correction

    if self.verbose: print("HomogeneousSaddlePointProblem: start iteration: rtol= %e, atol=%e"%(rtol, atol))
    while not converged:
        # get tolerance for velocity increment:
        if chi is None:
            rtol_v=self.__rtol_max
        else:
            rtol_v=min(chi/K_v,self.__rtol_max)
        rtol_v=max(rtol_v, self.__rtol_min)
        if self.verbose: print("HomogeneousSaddlePointProblem: step %s: rtol_v= %e"%(correction_step,rtol_v))
        # get velocity increment:
        dv1=self.getDV(p,v,rtol_v)
        v1=v+dv1
        Bv1=self.Bv(v1, rtol_v)
        norm_Bv1=self.norm_Bv(Bv1)
        norm_dv1=self.norm_v(dv1)
        if self.verbose: print("HomogeneousSaddlePointProblem: step %s: norm_Bv1 = %e, norm_dv1 = %e"%(correction_step, norm_Bv1, norm_dv1))
        if norm_dv1*self.__theta < norm_Bv1:
            # get tolerance for pressure increment:
            if chi is None or eps is None:
                rtol_p=self.__rtol_max
            else:
                rtol_p=min(chi**2*eps/K_p/norm_Bv1, self.__rtol_max)
            # tolerance for the inner A/preconditioner solves used by PCG/GMRES
            self.__subtol=max(rtol_p**2, self.__rtol_min)
            if self.verbose: print("HomogeneousSaddlePointProblem: step %s: rtol_p= %e"%(correction_step,rtol_p))
            # now we solve for the pressure increment dp from B*A^{-1}B^* dp = Bv1
            if usePCG:
                dp,r,_=PCG(ArithmeticTuple(v1,Bv1),self.__Aprod_PCG,0*p,self.__Msolve_PCG,self.__inner_PCG,atol=0, rtol=rtol_p,iter_max=max_iter, verbose=self.verbose)
                v2=r[0]
            else:
                # GMRES variant: marked "don't use!!!!" by the original author.
                dp=GMRES(self.solve_prec(Bv1,self.__subtol),self.__Aprod_GMRES, 0*p, self.__inner_GMRES,atol=0, rtol=rtol_p,iter_max=max_iter, iter_restart=iter_restart, verbose=self.verbose)
                dv2=self.solve_AinvBt(dp, self.__subtol)
                v2=v1-dv2
            p2=p+dp
        else:
            # Bv is small relative to the velocity increment: skip the pressure update
            v2=v1
            p2=p
        # update business:
        norm_dv2=self.norm_v(v2-v)
        norm_v2=self.norm_v(v2)
        if self.verbose: print("HomogeneousSaddlePointProblem: step %s: v2 = %e, norm_dv2 = %e"%(correction_step, norm_v2, norm_dv2))
        eps, eps_old = max(norm_Bv1, norm_dv2), eps
        if eps_old is None:
            chi, chi_old = None, chi
        else:
            chi, chi_old = min(eps/ eps_old, self.__chi_max), chi
        if self.verbose:
            if chi is not None:
                print("HomogeneousSaddlePointProblem: step %s: convergence rate = %e, correction = %e"%(correction_step,chi, eps))
            else:
                print("HomogeneousSaddlePointProblem: step %s: correction = %e"%(correction_step, eps))
        if eps <= rtol*norm_v2+atol :
            converged = True
        else:
            if correction_step>=max_correction_steps:
                raise CorrectionFailed("Given up after %d correction steps."%correction_step)
            if chi_old is not None:
                # Adapt each adjustment factor from the change in convergence
                # rate. Each factor is scaled by its own previous value
                # (K_v previously used K_p by copy-paste).
                K_p=max(1,self.__reduction_factor*K_p,(chi-chi_old)/chi_old**2*K_p)
                K_v=max(1,self.__reduction_factor*K_v,(chi-chi_old)/chi_old**2*K_v)
                if self.verbose: print("HomogeneousSaddlePointProblem: step %s: new adjustment factor K = %e"%(correction_step,K_p))
            correction_step+=1
        v,p =v2, p2
    if self.verbose: print("HomogeneousSaddlePointProblem: tolerance reached after %s steps."%correction_step)
    return v,p
1920 caltinay 2169 #========================================================================
def setTolerance(self, tolerance=1.e-4):
    """
    Sets the relative tolerance for (v,p).

    :param tolerance: tolerance to be used
    :type tolerance: non-negative ``float``
    :raise ValueError: if ``tolerance`` is negative
    """
    # zero is accepted, so the message says "non-negative" (matching
    # setAbsoluteTolerance) rather than "positive"
    if tolerance < 0:
        raise ValueError("tolerance must be non-negative.")
    self.__rtol = tolerance
1931 gross 2156
def getTolerance(self):
    """
    Gives the relative tolerance currently set for (v,p).

    :return: relative tolerance
    :rtype: ``float``
    """
    relative_tolerance = self.__rtol
    return relative_tolerance
1940 caltinay 2169
def setAbsoluteTolerance(self, tolerance=0.):
    """
    Stores the absolute tolerance used in the convergence test.

    :param tolerance: tolerance to be used
    :type tolerance: non-negative ``float``
    :raise ValueError: if ``tolerance`` is negative
    """
    is_negative = tolerance < 0
    if is_negative:
        raise ValueError("tolerance must be non-negative.")
    self.__atol = tolerance
1951 caltinay 2169
def getAbsoluteTolerance(self):
    """
    Gives the absolute tolerance currently set.

    :return: absolute tolerance
    :rtype: ``float``
    """
    absolute_tolerance = self.__atol
    return absolute_tolerance
1960 gross 1469
1961 caltinay 2169
def MaskFromBoundaryTag(domain, *tags):
    """
    Builds a mask on the Solution(domain) function space that is one for
    every sample touching a boundary region carrying one of ``tags``.

    Usage: m=MaskFromBoundaryTag(domain, "left", "right")

    :param domain: domain to be used
    :type domain: `escript.Domain`
    :param tags: boundary tags
    :type tags: ``str``
    :return: a mask which marks samples that are touching the boundary
             tagged by any of the given tags
    :rtype: `escript.Data` of rank 0
    """
    mask_pde = linearPDEs.LinearPDE(domain, numEquations=1, numSolutions=1)
    boundary_indicator = escore.Scalar(0., escore.FunctionOnBoundary(domain))
    for tag in tags:
        boundary_indicator.setTaggedValue(tag, 1.)
    # the tagged indicator enters the PDE as boundary right hand side y;
    # the non-zero entries of the assembled right hand side form the mask
    mask_pde.setValue(y=boundary_indicator)
    rhs = mask_pde.getRightHandSide()
    return util.whereNonZero(rhs)
1982 caltinay 2169
def MaskFromTag(domain, *tags):
    """
    Creates a mask on the Solution(domain) function space where the value is
    one for samples that touch regions tagged by tags.

    Usage: m=MaskFromTag(domain, "ham")

    :param domain: domain to be used
    :type domain: `escript.Domain`
    :param tags: region tags
    :type tags: ``str``
    :return: a mask which marks samples that are touching any of the regions
             tagged by the given tags
    :rtype: `escript.Data` of rank 0
    """
    pde=linearPDEs.LinearPDE(domain,numEquations=1, numSolutions=1)
    # indicator of the tagged regions on the Function function space
    d=escore.Scalar(0.,escore.Function(domain))
    for t in tags: d.setTaggedValue(t,1.)
    # the indicator enters the PDE as right hand side Y; the non-zero
    # entries of the assembled right hand side form the mask
    pde.setValue(Y=d)
    return util.whereNonZero(pde.getRightHandSide())
2003    
2004    

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26