# Diff of /trunk/escript/py_src/util.py

trunk/esys2/escript/py_src/util.py revision 117 by jgs, Fri Apr 1 05:48:57 2005 UTC trunk/escript/py_src/util.py revision 155 by jgs, Wed Nov 9 02:02:19 2005 UTC
# Line 3  Line 3
3  ## @file util.py  ## @file util.py
4
5  """  """
6  @brief Utility functions for escript  Utility functions for escript
7
8    @todo:
9
10      - binary operations @ (@=+,-,*,/,**)::
11          (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
12          (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
13          (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
14      - implementation of outer::
15          outer(a,b)[:,*]=a[:]*b[*]
16      - trace::
17          trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
18  """  """
19
20  import numarray  import numarray
21  import escript  import escript
22  #  import symbols
23  #   escript constants (have to be consistent witj utilC.h  import os
#
UNKNOWN=-1
EPSILON=1.e-15
Pi=numarray.pi
# some solver options:
NO_REORDERING=0
MINIMUM_FILL_IN=1
NESTED_DISSECTION=2
# solver methods
DEFAULT_METHOD=0
DIRECT=1
CHOLEVSKY=2
PCG=3
CR=4
CGS=5
BICGSTAB=6
SSOR=7
ILU0=8
ILUT=9
JACOBI=10
GMRES=11
PRES20=12

METHOD_KEY="method"
SYMMETRY_KEY="symmetric"
TOLERANCE_KEY="tolerance"

# supported file formats:
VRML=1
PNG=2
JPEG=3
JPG=3
PS=4
OOGL=5
BMP=6
TIFF=7
OPENINVENTOR=8
RENDERMAN=9
PNM=10
24
25  #  #=========================================================
26  # wrapper for various functions: if the argument has attribute the function name  #   some little helpers
27  # as an argument it calls the correspong methods. Otherwise the coresponsing numarray  #=========================================================
28  # function is called.  def _testForZero(arg):
29  #     """
30  # functions involving the underlying Domain:     Returns True is arg is considered to be zero.
31  #     """
33          return not arg>0
34       elif isinstance(arg,float):
35          return not arg>0.
36       elif isinstance(arg,numarray.NumArray):
37          a=abs(arg)
38          while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
39          return not a>0
40       else:
41          return False
42
43    #=========================================================
44    def saveVTK(filename,domain=None,**data):
45        """
46        writes a L{Data} objects into a files using the the VTK XML file format.
47
48        Example:
49
50           tmp=Scalar(..)
51           v=Vector(..)
52           saveVTK("solution.xml",temperature=tmp,velovity=v)
53
54        tmp and v are written into "solution.dx" where tmp is named "temperature" and v is named "velovity"
55
56        @param filename: file name of the output file
57        @type filename: C(str}
58        @param domain: domain of the L{Data} object. If not specified, the domain of the given L{Data} objects is used.
59        @type domain: L{escript.Domain}
60        @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
61        @type <name>: L{Data} object.
62        @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
63        """
64        if domain==None:
65           for i in data.keys():
66              if not data[i].isEmpty(): domain=data[i].getFunctionSpace().getDomain()
67        if domain==None:
68            raise ValueError,"no domain detected."
69        else:
70            domain.saveVTK(filename,data)
71
72    #=========================================================
73    def saveDX(filename,domain=None,**data):
74        """
75        writes a L{Data} objects into a files using the the DX file format.
76
77        Example:
78
79           tmp=Scalar(..)
80           v=Vector(..)
81           saveDX("solution.dx",temperature=tmp,velovity=v)
82
83        tmp and v are written into "solution.dx" where tmp is named "temperature" and v is named "velovity".
84
85        @param filename: file name of the output file
86        @type filename: C(str}
87        @param domain: domain of the L{Data} object. If not specified, the domain of the given L{Data} objects is used.
88        @type domain: L{escript.Domain}
89        @keyword <name>: writes the assigned value to the DX file using <name> as identifier. The identifier can be used to select the data set when data are imported into DX.
90        @type <name>: L{Data} object.
91        @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
92        """
93        if domain==None:
94           for i in data.keys():
95              if not data[i].isEmpty(): domain=data[i].getFunctionSpace().getDomain()
96        if domain==None:
97            raise ValueError,"no domain detected."
98        else:
99            domain.saveDX(filename,data)
100
101    #=========================================================
102
103    def exp(arg):
104      """      """
105      @brief returns the spatial gradient of the Data object arg      Applies the exponential function to arg.
106
107      @param arg: Data object representing the function which gradient to be calculated.      @param arg: argument
@param where: FunctionSpace in which the gradient will be. If None Function(dom) where dom is the
domain of the Data object arg.
108      """      """
109      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
110         if where==None:         return symbols.Exp_Symbol(arg)
112         else:         return arg.exp()
113      else:      else:
114         return arg*0.         return numarray.exp(arg)
115
116  def integrate(arg,what=None):  def sqrt(arg):
117      """      """
118      @brief return the integral if the function represented by Data object arg over its domain.      Applies the squre root function to arg.
119
120      @param arg      @param arg: argument
121      """      """
122      if not what==None:      if isinstance(arg,symbols.Symbol):
123         arg2=escript.Data(arg,what)         return symbols.Sqrt_Symbol(arg)
124        elif hasattr(arg,"sqrt"):
125           return arg.sqrt()
126      else:      else:
127         arg2=arg         return numarray.sqrt(arg)
128      if arg2.getRank()==0:
129          return arg2.integrate()[0]  def log(arg):
130        """
131        Applies the logarithmic function base 10 to arg.
132
133        @param arg: argument
134        """
135        if isinstance(arg,symbols.Symbol):
136           return symbols.Log_Symbol(arg)
137        elif hasattr(arg,"log"):
138           return arg.log()
139      else:      else:
140          return arg2.integrate()         return numarray.log10(arg)
141
142  def interpolate(arg,where):  def ln(arg):
143      """      """
144      @brief interpolates the function represented by Data object arg into the FunctionSpace where.      Applies the natural logarithmic function to arg.
145
146      @param arg      @param arg: argument
@param where
147      """      """
148      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
149         return arg.interpolate(where)         return symbols.Ln_Symbol(arg)
150        elif hasattr(arg,"ln"):
151           return arg.ln()
152      else:      else:
153         return arg         return numarray.log(arg)
154
155  # functions returning Data objects:  def sin(arg):
156        """
157        Applies the sin function to arg.
158
159  def transpose(arg,axis=None):      @param arg: argument
160      """      """
161      @brief returns the transpose of the Data object arg.      if isinstance(arg,symbols.Symbol):
162           return symbols.Sin_Symbol(arg)
163        elif hasattr(arg,"sin"):
164           return arg.sin()
165        else:
166           return numarray.sin(arg)
167
168      @param arg  def cos(arg):
169      """      """
170      if isinstance(arg,escript.Data):      Applies the cos function to arg.
171         # hack for transpose
172         r=arg.getRank()      @param arg: argument
173         if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"      """
174         s=arg.getShape()      if isinstance(arg,symbols.Symbol):
175         out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())         return symbols.Cos_Symbol(arg)
176         for i in range(s[0]):      elif hasattr(arg,"cos"):
177            for j in range(s[1]):         return arg.cos()
out[j,i]=arg[i,j]
return out
# end hack for transpose
if axis==None: axis=arg.getRank()/2
return arg.transpose(axis)
178      else:      else:
179         if axis==None: axis=arg.rank/2         return numarray.cos(arg)
return numarray.transpose(arg,axis=axis)
180
181  def trace(arg):  def tan(arg):
182      """      """
183      @brief      Applies the tan function to arg.
184
185      @param arg      @param arg: argument
186      """      """
187      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
188         # hack for trace         return symbols.Tan_Symbol(arg)
189         r=arg.getRank()      elif hasattr(arg,"tan"):
190         if r!=2: raise ValueError,"trace only avalaible for rank 2 objects"         return arg.tan()
s=arg.getShape()
out=escript.Scalar(0,arg.getFunctionSpace())
for i in range(min(s)):
out+=arg[i,i]
return out
# end hack for trace
return arg.trace()
191      else:      else:
192         return numarray.trace(arg)         return numarray.tan(arg)
193
194  def exp(arg):  def asin(arg):
195      """      """
196      @brief      Applies the asin function to arg.
197
198      @param arg      @param arg: argument
199      """      """
200      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
201         return arg.exp()         return symbols.Asin_Symbol(arg)
202        elif hasattr(arg,"asin"):
203           return arg.asin()
204      else:      else:
205         return numarray.exp(arg)         return numarray.asin(arg)
206
207  def sqrt(arg):  def acos(arg):
208      """      """
209      @brief      Applies the acos function to arg.
210
211      @param arg      @param arg: argument
212      """      """
213      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
214         return arg.sqrt()         return symbols.Acos_Symbol(arg)
215        elif hasattr(arg,"acos"):
216           return arg.acos()
217      else:      else:
218         return numarray.sqrt(arg)         return numarray.acos(arg)
219
220  def sin(arg):  def atan(arg):
221      """      """
222      @brief      Applies the atan function to arg.
223
224      @param arg      @param arg: argument
225      """      """
226      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
227         return arg.sin()         return symbols.Atan_Symbol(arg)
228        elif hasattr(arg,"atan"):
229           return arg.atan()
230      else:      else:
231         return numarray.sin(arg)         return numarray.atan(arg)
232
233  def tan(arg):  def sinh(arg):
234      """      """
235      @brief      Applies the sinh function to arg.
236
237      @param arg      @param arg: argument
238      """      """
239      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
240         return arg.tan()         return symbols.Sinh_Symbol(arg)
241        elif hasattr(arg,"sinh"):
242           return arg.sinh()
243      else:      else:
244         return numarray.tan(arg)         return numarray.sinh(arg)
245
246  def cos(arg):  def cosh(arg):
247      """      """
248      @brief      Applies the cosh function to arg.
249
250      @param arg      @param arg: argument
251      """      """
252      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
253         return arg.cos()         return symbols.Cosh_Symbol(arg)
254        elif hasattr(arg,"cosh"):
255           return arg.cosh()
256      else:      else:
257         return numarray.cos(arg)         return numarray.cosh(arg)
258
259    def tanh(arg):
260        """
261        Applies the tanh function to arg.
262
263        @param arg: argument
264        """
265        if isinstance(arg,symbols.Symbol):
266           return symbols.Tanh_Symbol(arg)
267        elif hasattr(arg,"tanh"):
268           return arg.tanh()
269        else:
270           return numarray.tanh(arg)
271
272    def asinh(arg):
273        """
274        Applies the asinh function to arg.
275
276        @param arg: argument
277        """
278        if isinstance(arg,symbols.Symbol):
279           return symbols.Asinh_Symbol(arg)
280        elif hasattr(arg,"asinh"):
281           return arg.asinh()
282        else:
283           return numarray.asinh(arg)
284
285    def acosh(arg):
286        """
287        Applies the acosh function to arg.
288
289        @param arg: argument
290        """
291        if isinstance(arg,symbols.Symbol):
292           return symbols.Acosh_Symbol(arg)
293        elif hasattr(arg,"acosh"):
294           return arg.acosh()
295        else:
296           return numarray.acosh(arg)
297
298    def atanh(arg):
299        """
300        Applies the atanh function to arg.
301
302        @param arg: argument
303        """
304        if isinstance(arg,symbols.Symbol):
305           return symbols.Atanh_Symbol(arg)
306        elif hasattr(arg,"atanh"):
307           return arg.atanh()
308        else:
309           return numarray.atanh(arg)
310
311    def sign(arg):
312        """
313        Applies the sign function to arg.
314
315        @param arg: argument
316        """
317        if isinstance(arg,symbols.Symbol):
318           return symbols.Sign_Symbol(arg)
319        elif hasattr(arg,"sign"):
320           return arg.sign()
321        else:
322           return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
323                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
324
325  def maxval(arg):  def maxval(arg):
326      """      """
327      @brief      Returns the maximum value of argument arg.
328
329      @param arg      @param arg: argument
330      """      """
331      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
332           return symbols.Max_Symbol(arg)
333        elif hasattr(arg,"maxval"):
334         return arg.maxval()         return arg.maxval()
335      elif isinstance(arg,float) or isinstance(arg,int):      elif hasattr(arg,"max"):
return arg
else:
336         return arg.max()         return arg.max()
337        else:
338           return arg
339
340  def minval(arg):  def minval(arg):
341      """      """
342      @brief      Returns the minimum value of argument arg.
343
344      @param arg      @param arg: argument
345      """      """
346      if isinstance(arg,escript.Data):      if isinstance(arg,symbols.Symbol):
347           return symbols.Min_Symbol(arg)
348        elif hasattr(arg,"maxval"):
349         return arg.minval()         return arg.minval()
350      elif isinstance(arg,float) or isinstance(arg,int):      elif hasattr(arg,"min"):
351           return arg.min()
352        else:
353         return arg         return arg
354
355    def wherePositive(arg):
356        """
357        Returns the positive values of argument arg.
358
359        @param arg: argument
360        """
361        if _testForZero(arg):
362          return 0
363        elif isinstance(arg,symbols.Symbol):
364           return symbols.WherePositive_Symbol(arg)
365        elif hasattr(arg,"wherePositive"):
366           return arg.minval()
367        elif hasattr(arg,"wherePositive"):
368           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
369      else:      else:
370         return arg.min()         if arg>0:
371              return 1.
372           else:
373              return 0.
374
375  def length(arg):  def whereNegative(arg):
376        """
377        Returns the negative values of argument arg.
378
379        @param arg: argument
380        """
381        if _testForZero(arg):
382          return 0
383        elif isinstance(arg,symbols.Symbol):
384           return symbols.WhereNegative_Symbol(arg)
385        elif hasattr(arg,"whereNegative"):
386           return arg.whereNegative()
387        elif hasattr(arg,"shape"):
388           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
389        else:
390           if arg<0:
391              return 1.
392           else:
393              return 0.
394
395    def maximum(arg0,arg1):
396        """
397        Return arg1 where arg1 is bigger then arg0 otherwise arg0 is returned.
398        """
399        m=whereNegative(arg0-arg1)
400        return m*arg1+(1.-m)*arg0
401
402    def minimum(arg0,arg1):
403        """
404        Return arg0 where arg1 is bigger then arg0 otherwise arg1 is returned.
405        """
406        m=whereNegative(arg0-arg1)
407        return m*arg0+(1.-m)*arg1
408
409    def outer(arg0,arg1):
410       if _testForZero(arg0) or _testForZero(arg1):
411          return 0
412       else:
413          if isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
414            return symbols.Outer_Symbol(arg0,arg1)
415          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
416            return arg0*arg1
417          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
418            return numarray.outer(arg0,arg1)
419          else:
420            if arg0.getRank()==1 and arg1.getRank()==1:
421              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
422              for i in range(arg0.getShape()[0]):
423                for j in range(arg1.getShape()[0]):
424                    out[i,j]=arg0[i]*arg1[j]
425              return out
426            else:
427              raise ValueError,"outer is not fully implemented yet."
428
429    def interpolate(arg,where):
430        """
431        Interpolates the function into the FunctionSpace where.
432
433        @param arg:    interpolant
434        @param where:  FunctionSpace to interpolate to
435        """
436        if _testForZero(arg):
437          return 0
438        elif isinstance(arg,symbols.Symbol):
439           return symbols.Interpolated_Symbol(arg,where)
440        else:
441           return escript.Data(arg,where)
442
443    def div(arg,where=None):
444        """
445        Returns the divergence of arg at where.
446
447        @param arg:   Data object representing the function which gradient to
448                      be calculated.
449        @param where: FunctionSpace in which the gradient will be calculated.
450                      If not present or C{None} an appropriate default is used.
451        """
453
454    def jump(arg):
455        """
456        Returns the jump of arg across a continuity.
457
458        @param arg:   Data object representing the function which gradient
459                      to be calculated.
460        """
461        d=arg.getDomain()
462        return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())
463
464
466        """
467        Returns the spatial gradient of arg at where.
468
469        @param arg:   Data object representing the function which gradient
470                      to be calculated.
471        @param where: FunctionSpace in which the gradient will be calculated.
472                      If not present or C{None} an appropriate default is used.
473        """
474        if _testForZero(arg):
475          return 0
476        elif isinstance(arg,symbols.Symbol):
479           if where==None:
481           else:
483        else:
484           return arg*0.
485
486    def integrate(arg,where=None):
487        """
488        Return the integral if the function represented by Data object arg over
489        its domain.
490
491        @param arg:   Data object representing the function which is integrated.
492        @param where: FunctionSpace in which the integral is calculated.
493                      If not present or C{None} an appropriate default is used.
494        """
495        if _testForZero(arg):
496          return 0
497        elif isinstance(arg,symbols.Symbol):
498           return symbols.Integral_Symbol(arg,where)
499        else:
500           if not where==None: arg=escript.Data(arg,where)
501           if arg.getRank()==0:
502             return arg.integrate()[0]
503           else:
504             return arg.integrate()
505
506    #=============================
507    #
508    # wrapper for various functions: if the argument has attribute the function name
509    # as an argument it calls the corresponding methods. Otherwise the corresponding
510    # numarray function is called.
511
512    # functions involving the underlying Domain:
513
514
515    # functions returning Data objects:
516
517    def transpose(arg,axis=None):
518      """      """
519      @brief      Returns the transpose of the Data object arg.
520
521        @param arg:
522        """
523        if axis==None:
524           r=0
525           if hasattr(arg,"getRank"): r=arg.getRank()
526           if hasattr(arg,"rank"): r=arg.rank
527           axis=r/2
528        if isinstance(arg,symbols.Symbol):
529           return symbols.Transpose_Symbol(arg,axis=r)
530        if isinstance(arg,escript.Data):
531           # hack for transpose
532           r=arg.getRank()
533           if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
534           s=arg.getShape()
535           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
536           for i in range(s[0]):
537              for j in range(s[1]):
538                 out[j,i]=arg[i,j]
539           return out
540           # end hack for transpose
541           return arg.transpose(axis)
542        else:
543           return numarray.transpose(arg,axis=axis)
544
545    def trace(arg,axis0=0,axis1=1):
546        """
547        Return
548
549        @param arg:
550        """
551        if isinstance(arg,symbols.Symbol):
552           s=list(arg.getShape())
553           s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
554           return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
555        elif isinstance(arg,escript.Data):
556           # hack for trace
557           s=arg.getShape()
558           if s[axis0]!=s[axis1]:
559               raise ValueError,"illegal axis in trace"
560           out=escript.Scalar(0.,arg.getFunctionSpace())
561           for i in range(s[axis0]):
562              out+=arg[i,i]
563           return out
564           # end hack for trace
565        else:
566           return numarray.trace(arg,axis0=axis0,axis1=axis1)
567
568      @param arg  def length(arg):
569        """
570        @param arg:
571      """      """
572      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
573         if arg.isEmpty(): return escript.Data()         if arg.isEmpty(): return escript.Data()
574         if arg.getRank()==0:         if arg.getRank()==0:
575            return abs(arg)            return abs(arg)
576         elif arg.getRank()==1:         elif arg.getRank()==1:
577            sum=escript.Scalar(0,arg.getFunctionSpace())            out=escript.Scalar(0,arg.getFunctionSpace())
578            for i in range(arg.getShape()[0]):            for i in range(arg.getShape()[0]):
579               sum+=arg[i]**2               out+=arg[i]**2
580            return sqrt(sum)            return sqrt(out)
581         elif arg.getRank()==2:         elif arg.getRank()==2:
582            sum=escript.Scalar(0,arg.getFunctionSpace())            out=escript.Scalar(0,arg.getFunctionSpace())
583            for i in range(arg.getShape()[0]):            for i in range(arg.getShape()[0]):
584               for j in range(arg.getShape()[1]):               for j in range(arg.getShape()[1]):
585                  sum+=arg[i,j]**2                  out+=arg[i,j]**2
586            return sqrt(sum)            return sqrt(out)
587         elif arg.getRank()==3:         elif arg.getRank()==3:
588            sum=escript.Scalar(0,arg.getFunctionSpace())            out=escript.Scalar(0,arg.getFunctionSpace())
589            for i in range(arg.getShape()[0]):            for i in range(arg.getShape()[0]):
590               for j in range(arg.getShape()[1]):               for j in range(arg.getShape()[1]):
591                  for k in range(arg.getShape()[2]):                  for k in range(arg.getShape()[2]):
592                     sum+=arg[i,j,k]**2                     out+=arg[i,j,k]**2
593            return sqrt(sum)            return sqrt(out)
594         elif arg.getRank()==4:         elif arg.getRank()==4:
595            sum=escript.Scalar(0,arg.getFunctionSpace())            out=escript.Scalar(0,arg.getFunctionSpace())
596            for i in range(arg.getShape()[0]):            for i in range(arg.getShape()[0]):
597               for j in range(arg.getShape()[1]):               for j in range(arg.getShape()[1]):
598                  for k in range(arg.getShape()[2]):                  for k in range(arg.getShape()[2]):
599                     for l in range(arg.getShape()[3]):                     for l in range(arg.getShape()[3]):
600                        sum+=arg[i,j,k,l]**2                        out+=arg[i,j,k,l]**2
601            return sqrt(sum)            return sqrt(out)
602         else:         else:
603            raise SystemError,"length is not been implemented yet"            raise SystemError,"length is not been fully implemented yet"
604         # return arg.length()            # return arg.length()
605        elif isinstance(arg,float):
606           return abs(arg)
607      else:      else:
608         return sqrt((arg**2).sum())         return sqrt((arg**2).sum())
609
610  def deviator(arg):  def deviator(arg):
611      """      """
612      @brief      @param arg:

@param arg1
613      """      """
614      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
615          shape=arg.getShape()          shape=arg.getShape()
# Line 284  def deviator(arg): Line 621  def deviator(arg):
621            raise ValueError,"Deviator requires a square matrix"            raise ValueError,"Deviator requires a square matrix"
622      return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])      return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
623
624  def inner(arg1,arg2):  def inner(arg0,arg1):
625      """      """
626      @brief      @param arg0:
627        @param arg1:
@param arg1, arg2
628      """      """
629      sum=escript.Scalar(0,arg1.getFunctionSpace())      if isinstance(arg0,escript.Data):
630           arg=arg0
631        else:
632           arg=arg1
633
634        out=escript.Scalar(0,arg.getFunctionSpace())
635      if arg.getRank()==0:      if arg.getRank()==0:
636            return arg1*arg2            return arg0*arg1
637      elif arg.getRank()==1:      elif arg.getRank()==1:
638           sum=escript.Scalar(0,arg.getFunctionSpace())           out=escript.Scalar(0,arg.getFunctionSpace())
639           for i in range(arg.getShape()[0]):           for i in range(arg.getShape()[0]):
640              sum+=arg1[i]*arg2[i]              out+=arg0[i]*arg1[i]
641      elif arg.getRank()==2:      elif arg.getRank()==2:
642          sum=escript.Scalar(0,arg.getFunctionSpace())          out=escript.Scalar(0,arg.getFunctionSpace())
643          for i in range(arg.getShape()[0]):          for i in range(arg.getShape()[0]):
644             for j in range(arg.getShape()[1]):             for j in range(arg.getShape()[1]):
645                sum+=arg1[i,j]*arg2[i,j]                out+=arg0[i,j]*arg1[i,j]
646      elif arg.getRank()==3:      elif arg.getRank()==3:
647          sum=escript.Scalar(0,arg.getFunctionSpace())          out=escript.Scalar(0,arg.getFunctionSpace())
648          for i in range(arg.getShape()[0]):          for i in range(arg.getShape()[0]):
649              for j in range(arg.getShape()[1]):              for j in range(arg.getShape()[1]):
650                 for k in range(arg.getShape()[2]):                 for k in range(arg.getShape()[2]):
651                    sum+=arg1[i,j,k]*arg2[i,j,k]                    out+=arg0[i,j,k]*arg1[i,j,k]
652      elif arg.getRank()==4:      elif arg.getRank()==4:
653          sum=escript.Scalar(0,arg.getFunctionSpace())          out=escript.Scalar(0,arg.getFunctionSpace())
654          for i in range(arg.getShape()[0]):          for i in range(arg.getShape()[0]):
655             for j in range(arg.getShape()[1]):             for j in range(arg.getShape()[1]):
656                for k in range(arg.getShape()[2]):                for k in range(arg.getShape()[2]):
657                   for l in range(arg.getShape()[3]):                   for l in range(arg.getShape()[3]):
658                      sum+=arg1[i,j,k,l]*arg2[i,j,k,l]                      out+=arg0[i,j,k,l]*arg1[i,j,k,l]
659      else:      else:
660            raise SystemError,"inner is not been implemented yet"            raise SystemError,"inner is not been implemented yet"
661      return sum      return out
662
663  def sign(arg):  def tensormult(arg0,arg1):
664      """      # check LinearPDE!!!!
665      @brief      raise SystemError,"tensormult is not implemented yet!"
666
667      @param arg  def matrixmult(arg0,arg1):
668      """
669      if isinstance(arg,escript.Data):      if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
670         return arg.sign()          numarray.matrixmult(arg0,arg1)
671      else:      else:
672         return numarray.greater(arg,numarray.zeros(arg.shape))-numarray.less(arg,numarray.zeros(arg.shape))        # escript.matmult(arg0,arg1)
673          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
674            arg0=escript.Data(arg0,arg1.getFunctionSpace())
675          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
676            arg1=escript.Data(arg1,arg0.getFunctionSpace())
677          if arg0.getRank()==2 and arg1.getRank()==1:
678              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
679              for i in range(arg0.getShape()[0]):
680                 for j in range(arg0.getShape()[1]):
681                   # uses Data object slicing, plus Data * and += operators
682                   out[i]+=arg0[i,j]*arg1[j]
683              return out
684          elif arg0.getRank()==1 and arg1.getRank()==1:
685              return inner(arg0,arg1)
686          else:
687              raise SystemError,"matrixmult is not fully implemented yet!"
688
689    #=========================================================
690  # reduction operations:  # reduction operations:
691    #=========================================================
692  def sum(arg):  def sum(arg):
693      """      """
694      @brief      @param arg:

@param arg
695      """      """
696      return arg.sum()      return arg.sum()
697
698  def sup(arg):  def sup(arg):
699      """      """
700      @brief      @param arg:

@param arg
701      """      """
702      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
703         return arg.sup()         return arg.sup()
# Line 355  def sup(arg): Line 708  def sup(arg):
708
709  def inf(arg):  def inf(arg):
710      """      """
711      @brief      @param arg:

@param arg
712      """      """
713      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
714         return arg.inf()         return arg.inf()
# Line 368  def inf(arg): Line 719  def inf(arg):
719
720  def L2(arg):  def L2(arg):
721      """      """
722      @brief returns the L2-norm of the      Returns the L2-norm of the argument
723
724      @param arg      @param arg:
725      """      """
726      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
727         return arg.L2()         return arg.L2()
# Line 381  def L2(arg): Line 732  def L2(arg):
732
733  def Lsup(arg):  def Lsup(arg):
734      """      """
735      @brief      @param arg:

@param arg
736      """      """
737      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
738         return arg.Lsup()         return arg.Lsup()
739      elif isinstance(arg,float) or isinstance(arg,int):      elif isinstance(arg,float) or isinstance(arg,int):
740         return abs(arg)         return abs(arg)
741      else:      else:
742         return max(numarray.abs(arg))         return numarray.abs(arg).max()

def Linf(arg):
"""
@brief

@param arg
"""
if isinstance(arg,escript.Data):
return arg.Linf()
elif isinstance(arg,float) or isinstance(arg,int):
return abs(arg)
else:
return min(numarray.abs(arg))
743
744  def dot(arg1,arg2):  def dot(arg0,arg1):
745      """      """
746      @brief      @param arg0:
747        @param arg1:
@param arg
748      """      """
749      if isinstance(arg1,escript.Data):      if isinstance(arg0,escript.Data):
750         return arg1.dot(arg2)         return arg0.dot(arg1)
751      elif isinstance(arg1,escript.Data):      elif isinstance(arg1,escript.Data):
752         return arg2.dot(arg1)         return arg1.dot(arg0)
753      else:      else:
754         return numarray.dot(arg1,arg2)         return numarray.dot(arg0,arg1)
755
756  def kronecker(d):  def kronecker(d):
757     return numarray.identity(d)     if hasattr(d,"getDim"):
758          return numarray.identity(d.getDim())*1.
759       else:
760          return numarray.identity(d)*1.
761
762  def unit(i,d):  def unit(i,d):
763     """     """
764     @brief return a unit vector of dimension d with nonzero index i     Return a unit vector of dimension d with nonzero index i.
765     @param d dimension
766     @param i index     @param d: dimension
767       @param i: index
768     """     """
769     e = numarray.zeros((d,))     e = numarray.zeros((d,),numarray.Float)
770     e[i] = 1.0     e[i] = 1.0
771     return e     return e

Legend:
 Removed from v.117 changed lines Added in v.155