/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 82 by jgs, Tue Oct 26 06:53:54 2004 UTC revision 153 by jgs, Tue Oct 25 01:51:20 2005 UTC
# Line 3  Line 3 
3  ## @file util.py  ## @file util.py
4    
5  """  """
6  @brief Utility functions for escript  Utility functions for escript
7    
8    @todo:
9    
10      - binary operations @ (@=+,-,*,/,**)::
11          (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
12          (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
13          (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
14      - implementation of outer::
15          outer(a,b)[:,*]=a[:]*b[*]
16      - trace::
17          trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
18  """  """
19    
20  import numarray  import numarray
21  #  import escript
22  #   escript constants:  import symbols
23  #  import os
24  FALSE=0  
25  TRUE=1  #=========================================================
26  UNKNOWN=-1  #   some little helpers
27  EPSILON=1.e-15  #=========================================================
28  Pi=3.1415926535897931  def _testForZero(arg):
29  # matrix types     """
30  CSC=0       Returns True is arg is considered to be zero.
31  CSR=1     """
32  LUMPED=10     if isinstance(arg,int):
33  # some solver options:        return not arg>0
34  NO_REORDERING=0     elif isinstance(arg,float):
35  MINIMUM_FILL_IN=1        return not arg>0.
36  NESTED_DISSECTION=2     elif isinstance(arg,numarray.NumArray):
37  DEFAULT_METHOD=0        a=abs(arg)
38  PCG=1        while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
39  CR=2        return not a>0
40  CGS=3     else:
41  BICGSTAB=4        return False
42  SSOR=5  
43  ILU0=6  #=========================================================
44  ILUT=7  def saveVTK(filename,domain=None,**data):
45  JACOBI=8      """
46  # supported file formats:      writes a L{Data} objects into a files using the the VTK XML file format.
47  VRML=1  
48  PNG=2      Example:
49  JPEG=3  
50  JPG=3         tmp=Scalar(..)
51  PS=4         v=Vector(..)
52  OOGL=5         saveVTK("solution.xml",temperature=tmp,velovity=v)
53  BMP=6  
54  TIFF=7      tmp and v are written into "solution.dx" where tmp is named "temperature" and v is named "velovity"
55  OPENINVENTOR=8  
56  RENDERMAN=9      @param filename: file name of the output file
57  PNM=10      @type filename: C(str}
58  #      @param domain: domain of the L{Data} object. If not specified, the domain of the given L{Data} objects is used.
59  # wrapper for various functions: if the argument has attribute the function name      @type domain: L{escript.Domain}
60  # as an argument it calls the correspong methods. Otherwise the coresponsing numarray      @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
61  # function is called.      @type <name>: L{Data} object.
62  #      @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
63  def L2(arg):      """
64        if domain==None:
65           for i in data.keys():
66              if not data[i].isEmpty(): domain=data[i].getFunctionSpace().getDomain()
67        if domain==None:
68            raise ValueError,"no domain detected."
69        else:
70            domain.saveVTK(filename,data)
71    #=========================================================
72    def saveDX(filename,domain=None,**data):
73        """
74        writes a L{Data} objects into a files using the the DX file format.
75    
76        Example:
77    
78           tmp=Scalar(..)
79           v=Vector(..)
80           saveDX("solution.dx",temperature=tmp,velovity=v)
81    
82        tmp and v are written into "solution.dx" where tmp is named "temperature" and v is named "velovity".
83    
84        @param filename: file name of the output file
85        @type filename: C(str}
86        @param domain: domain of the L{Data} object. If not specified, the domain of the given L{Data} objects is used.
87        @type domain: L{escript.Domain}
88        @keyword <name>: writes the assigned value to the DX file using <name> as identifier. The identifier can be used to select the data set when data are imported into DX.
89        @type <name>: L{Data} object.
90        @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
91        """
92        if domain==None:
93           for i in data.keys():
94              if not data[i].isEmpty(): domain=data[i].getFunctionSpace().getDomain()
95        if domain==None:
96            raise ValueError,"no domain detected."
97        else:
98            domain.saveDX(filename,data)
99    #=========================================================
100    
101    def exp(arg):
102      """      """
103      @brief      Applies the exponential function to arg.
104    
105      @param arg      @param arg: argument
106      """      """
107      return arg.L2()      if isinstance(arg,symbols.Symbol):
108           return symbols.Exp_Symbol(arg)
109        elif hasattr(arg,"exp"):
110           return arg.exp()
111        else:
112           return numarray.exp(arg)
113    
114  def grad(arg,where=None):  def sqrt(arg):
115      """      """
116      @brief      Applies the squre root function to arg.
117    
118      @param arg      @param arg: argument
     @param where  
119      """      """
120      if where==None:      if isinstance(arg,symbols.Symbol):
121         return arg.grad()         return symbols.Sqrt_Symbol(arg)
122        elif hasattr(arg,"sqrt"):
123           return arg.sqrt()
124      else:      else:
125         return arg.grad(where)         return numarray.sqrt(arg)      
126    
127  def integrate(arg):  def log(arg):
128      """      """
129      @brief      Applies the logarithmic function base 10 to arg.
130    
131      @param arg      @param arg: argument
132      """      """
133      return arg.integrate()      if isinstance(arg,symbols.Symbol):
134           return symbols.Log_Symbol(arg)
135        elif hasattr(arg,"log"):
136           return arg.log()
137        else:
138           return numarray.log10(arg)
139    
140  def interpolate(arg,where):  def ln(arg):
141      """      """
142      @brief      Applies the natural logarithmic function to arg.
143    
144      @param arg      @param arg: argument
     @param where  
145      """      """
146      return arg.interpolate(where)      if isinstance(arg,symbols.Symbol):
147           return symbols.Ln_Symbol(arg)
148        elif hasattr(arg,"ln"):
149           return arg.ln()
150        else:
151           return numarray.log(arg)
152    
153  def transpose(arg):  def sin(arg):
154      """      """
155      @brief      Applies the sin function to arg.
156    
157      @param arg      @param arg: argument
158      """      """
159      if hasattr(arg,"transpose"):      if isinstance(arg,symbols.Symbol):
160         return arg.transpose()         return symbols.Sin_Symbol(arg)
161        elif hasattr(arg,"sin"):
162           return arg.sin()
163      else:      else:
164         return numarray.transpose(arg,axis=None)         return numarray.sin(arg)
165    
166  def trace(arg):  def cos(arg):
167      """      """
168      @brief      Applies the cos function to arg.
169    
170      @param arg      @param arg: argument
171      """      """
172      if hasattr(arg,"trace"):      if isinstance(arg,symbols.Symbol):
173         return arg.trace()         return symbols.Cos_Symbol(arg)
174        elif hasattr(arg,"cos"):
175           return arg.cos()
176      else:      else:
177         return numarray.trace(arg,k=0)         return numarray.cos(arg)
178    
179  def exp(arg):  def tan(arg):
180      """      """
181      @brief      Applies the tan function to arg.
182    
183      @param arg      @param arg: argument
184      """      """
185      if hasattr(arg,"exp"):      if isinstance(arg,symbols.Symbol):
186         return arg.exp()         return symbols.Tan_Symbol(arg)
187        elif hasattr(arg,"tan"):
188           return arg.tan()
189      else:      else:
190         return numarray.exp(arg)         return numarray.tan(arg)
191    
192  def sqrt(arg):  def asin(arg):
193      """      """
194      @brief      Applies the asin function to arg.
195    
196      @param arg      @param arg: argument
197      """      """
198       if hasattr(arg,"sqrt"):      if isinstance(arg,symbols.Symbol):
199          return arg.sqrt()         return symbols.Asin_Symbol(arg)
200       else:      elif hasattr(arg,"asin"):
201         return numarray.sqrt(arg)         return arg.asin()
202        else:
203           return numarray.asin(arg)
204    
205  def sin(arg):  def acos(arg):
206      """      """
207      @brief      Applies the acos function to arg.
208    
209      @param arg      @param arg: argument
210      """      """
211      if hasattr(arg,"sin"):      if isinstance(arg,symbols.Symbol):
212         return arg.sin()         return symbols.Acos_Symbol(arg)
213        elif hasattr(arg,"acos"):
214           return arg.acos()
215      else:      else:
216         return numarray.sin(arg)         return numarray.acos(arg)
217    
218  def cos(arg):  def atan(arg):
219      """      """
220      @brief      Applies the atan function to arg.
221    
222      @param arg      @param arg: argument
223      """      """
224      if hasattr(arg,"cos"):      if isinstance(arg,symbols.Symbol):
225         return arg.cos()         return symbols.Atan_Symbol(arg)
226        elif hasattr(arg,"atan"):
227           return arg.atan()
228      else:      else:
229         return numarray.cos(arg)         return numarray.atan(arg)
230    
231    def sinh(arg):
232        """
233        Applies the sinh function to arg.
234    
235        @param arg: argument
236        """
237        if isinstance(arg,symbols.Symbol):
238           return symbols.Sinh_Symbol(arg)
239        elif hasattr(arg,"sinh"):
240           return arg.sinh()
241        else:
242           return numarray.sinh(arg)
243    
244    def cosh(arg):
245        """
246        Applies the cosh function to arg.
247    
248        @param arg: argument
249        """
250        if isinstance(arg,symbols.Symbol):
251           return symbols.Cosh_Symbol(arg)
252        elif hasattr(arg,"cosh"):
253           return arg.cosh()
254        else:
255           return numarray.cosh(arg)
256    
257    def tanh(arg):
258        """
259        Applies the tanh function to arg.
260    
261        @param arg: argument
262        """
263        if isinstance(arg,symbols.Symbol):
264           return symbols.Tanh_Symbol(arg)
265        elif hasattr(arg,"tanh"):
266           return arg.tanh()
267        else:
268           return numarray.tanh(arg)
269    
270    def asinh(arg):
271        """
272        Applies the asinh function to arg.
273    
274        @param arg: argument
275        """
276        if isinstance(arg,symbols.Symbol):
277           return symbols.Asinh_Symbol(arg)
278        elif hasattr(arg,"asinh"):
279           return arg.asinh()
280        else:
281           return numarray.asinh(arg)
282    
283    def acosh(arg):
284        """
285        Applies the acosh function to arg.
286    
287        @param arg: argument
288        """
289        if isinstance(arg,symbols.Symbol):
290           return symbols.Acosh_Symbol(arg)
291        elif hasattr(arg,"acosh"):
292           return arg.acosh()
293        else:
294           return numarray.acosh(arg)
295    
296    def atanh(arg):
297        """
298        Applies the atanh function to arg.
299    
300        @param arg: argument
301        """
302        if isinstance(arg,symbols.Symbol):
303           return symbols.Atanh_Symbol(arg)
304        elif hasattr(arg,"atanh"):
305           return arg.atanh()
306        else:
307           return numarray.atanh(arg)
308    
309    def sign(arg):
310        """
311        Applies the sign function to arg.
312    
313        @param arg: argument
314        """
315        if isinstance(arg,symbols.Symbol):
316           return symbols.Sign_Symbol(arg)
317        elif hasattr(arg,"sign"):
318           return arg.sign()
319        else:
320           return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
321                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
322    
323  def maxval(arg):  def maxval(arg):
324      """      """
325      @brief      Returns the maximum value of argument arg.
326    
327      @param arg      @param arg: argument
328      """      """
329      return arg.maxval()      if isinstance(arg,symbols.Symbol):
330           return symbols.Max_Symbol(arg)
331        elif hasattr(arg,"maxval"):
332           return arg.maxval()
333        elif hasattr(arg,"max"):
334           return arg.max()
335        else:
336           return arg
337    
338  def minval(arg):  def minval(arg):
339      """      """
340      @brief      Returns the minimum value of argument arg.
341    
342      @param arg      @param arg: argument
343      """      """
344      return arg.minval()      if isinstance(arg,symbols.Symbol):
345           return symbols.Min_Symbol(arg)
346        elif hasattr(arg,"maxval"):
347           return arg.minval()
348        elif hasattr(arg,"min"):
349           return arg.min()
350        else:
351           return arg
352    
353  def sup(arg):  def wherePositive(arg):
354      """      """
355      @brief      Returns the positive values of argument arg.
356    
357      @param arg      @param arg: argument
358      """      """
359      return arg.sup()      if _testForZero(arg):
360          return 0
361        elif isinstance(arg,symbols.Symbol):
362           return symbols.WherePositive_Symbol(arg)
363        elif hasattr(arg,"wherePositive"):
364           return arg.minval()
365        elif hasattr(arg,"wherePositive"):
366           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
367        else:
368           if arg>0:
369              return 1.
370           else:
371              return 0.
372    
373    def whereNegative(arg):
374        """
375        Returns the negative values of argument arg.
376    
377        @param arg: argument
378        """
379        if _testForZero(arg):
380          return 0
381        elif isinstance(arg,symbols.Symbol):
382           return symbols.WhereNegative_Symbol(arg)
383        elif hasattr(arg,"whereNegative"):
384           return arg.whereNegative()
385        elif hasattr(arg,"shape"):
386           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
387        else:
388           if arg<0:
389              return 1.
390           else:
391              return 0.
392    
393    def maximum(arg0,arg1):
394        """
395        Return arg1 where arg1 is bigger then arg0 otherwise arg0 is returned.
396        """
397        m=whereNegative(arg0-arg1)
398        return m*arg1+(1.-m)*arg0
399      
400    def minimum(arg0,arg1):
401        """
402        Return arg0 where arg1 is bigger then arg0 otherwise arg1 is returned.
403        """
404        m=whereNegative(arg0-arg1)
405        return m*arg0+(1.-m)*arg1
406      
407    def outer(arg0,arg1):
408       if _testForZero(arg0) or _testForZero(arg1):
409          return 0
410       else:
411          if isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
412            return symbols.Outer_Symbol(arg0,arg1)
413          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
414            return arg0*arg1
415          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
416            return numarray.outer(arg0,arg1)
417          else:
418            if arg0.getRank()==1 and arg1.getRank()==1:
419              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
420              for i in range(arg0.getShape()[0]):
421                for j in range(arg1.getShape()[0]):
422                    out[i,j]=arg0[i]*arg1[j]
423              return out
424            else:
425              raise ValueError,"outer is not fully implemented yet."
426    
427  def inf(arg):  def interpolate(arg,where):
428      """      """
429      @brief      Interpolates the function into the FunctionSpace where.
430    
431      @param arg      @param arg:    interpolant
432        @param where:  FunctionSpace to interpolate to
433      """      """
434      return arg.inf()      if _testForZero(arg):
435          return 0
436        elif isinstance(arg,symbols.Symbol):
437           return symbols.Interpolated_Symbol(arg,where)
438        else:
439           return escript.Data(arg,where)
440    
441  def Lsup(arg):  def div(arg,where=None):
442      """      """
443      @brief      Returns the divergence of arg at where.
444    
445      @param arg      @param arg:   Data object representing the function which gradient to
446                      be calculated.
447        @param where: FunctionSpace in which the gradient will be calculated.
448                      If not present or C{None} an appropriate default is used.
449      """      """
450      return arg.Lsup()      return trace(grad(arg,where))
451    
452  def length(arg):  def jump(arg):
453      """      """
454      @brief      Returns the jump of arg across a continuity.
455    
456      @param arg      @param arg:   Data object representing the function which gradient
457                      to be calculated.
458      """      """
459      return arg.length()      d=arg.getDomain()
460        return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())
461      
462    
463  def sign(arg):  def grad(arg,where=None):
464      """      """
465      @brief      Returns the spatial gradient of arg at where.
466    
467        @param arg:   Data object representing the function which gradient
468                      to be calculated.
469        @param where: FunctionSpace in which the gradient will be calculated.
470                      If not present or C{None} an appropriate default is used.
471        """
472        if _testForZero(arg):
473          return 0
474        elif isinstance(arg,symbols.Symbol):
475           return symbols.Grad_Symbol(arg,where)
476        elif hasattr(arg,"grad"):
477           if where==None:
478              return arg.grad()
479           else:
480              return arg.grad(where)
481        else:
482           return arg*0.
483    
484      @param arg  def integrate(arg,where=None):
485      """      """
486      return arg.sign()      Return the integral if the function represented by Data object arg over
487  #      its domain.
488  # $Log$  
489  # Revision 1.1  2004/10/26 06:53:56  jgs      @param arg:   Data object representing the function which is integrated.
490  # Initial revision      @param where: FunctionSpace in which the integral is calculated.
491  #                    If not present or C{None} an appropriate default is used.
492  # Revision 1.1.2.3  2004/10/26 06:43:48  jgs      """
493  # committing Lutz's and Paul's changes to brach jgs      if _testForZero(arg):
494  #        return 0
495  # Revision 1.1.4.1  2004/10/20 05:32:51  cochrane      elif isinstance(arg,symbols.Symbol):
496  # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.         return symbols.Integral_Symbol(arg,where)
497  #      else:    
498  # Revision 1.1  2004/08/05 03:58:27  gross         if not where==None: arg=escript.Data(arg,where)
499  # Bug in Assemble_NodeCoordinates fixed         if arg.getRank()==0:
500  #           return arg.integrate()[0]
501           else:
502             return arg.integrate()
503    
504    #=============================
505  #  #
506    # wrapper for various functions: if the argument has attribute the function name
507    # as an argument it calls the corresponding methods. Otherwise the corresponding
508    # numarray function is called.
509    
510    # functions involving the underlying Domain:
511    
512    
513    # functions returning Data objects:
514    
515    def transpose(arg,axis=None):
516        """
517        Returns the transpose of the Data object arg.
518    
519        @param arg:
520        """
521        if axis==None:
522           r=0
523           if hasattr(arg,"getRank"): r=arg.getRank()
524           if hasattr(arg,"rank"): r=arg.rank
525           axis=r/2
526        if isinstance(arg,symbols.Symbol):
527           return symbols.Transpose_Symbol(arg,axis=r)
528        if isinstance(arg,escript.Data):
529           # hack for transpose
530           r=arg.getRank()
531           if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
532           s=arg.getShape()
533           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
534           for i in range(s[0]):
535              for j in range(s[1]):
536                 out[j,i]=arg[i,j]
537           return out
538           # end hack for transpose
539           return arg.transpose(axis)
540        else:
541           return numarray.transpose(arg,axis=axis)
542    
543    def trace(arg,axis0=0,axis1=1):
544        """
545        Return
546    
547        @param arg:
548        """
549        if isinstance(arg,symbols.Symbol):
550           s=list(arg.getShape())        
551           s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
552           return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
553        elif isinstance(arg,escript.Data):
554           # hack for trace
555           s=arg.getShape()
556           if s[axis0]!=s[axis1]:
557               raise ValueError,"illegal axis in trace"
558           out=escript.Scalar(0.,arg.getFunctionSpace())
559           for i in range(s[axis0]):
560              out+=arg[i,i]
561           return out
562           # end hack for trace
563        else:
564           return numarray.trace(arg,axis0=axis0,axis1=axis1)
565    
566    def length(arg):
567        """
568    
569        @param arg:
570        """
571        if isinstance(arg,escript.Data):
572           if arg.isEmpty(): return escript.Data()
573           if arg.getRank()==0:
574              return abs(arg)
575           elif arg.getRank()==1:
576              out=escript.Scalar(0,arg.getFunctionSpace())
577              for i in range(arg.getShape()[0]):
578                 out+=arg[i]**2
579              return sqrt(out)
580           elif arg.getRank()==2:
581              out=escript.Scalar(0,arg.getFunctionSpace())
582              for i in range(arg.getShape()[0]):
583                 for j in range(arg.getShape()[1]):
584                    out+=arg[i,j]**2
585              return sqrt(out)
586           elif arg.getRank()==3:
587              out=escript.Scalar(0,arg.getFunctionSpace())
588              for i in range(arg.getShape()[0]):
589                 for j in range(arg.getShape()[1]):
590                    for k in range(arg.getShape()[2]):
591                       out+=arg[i,j,k]**2
592              return sqrt(out)
593           elif arg.getRank()==4:
594              out=escript.Scalar(0,arg.getFunctionSpace())
595              for i in range(arg.getShape()[0]):
596                 for j in range(arg.getShape()[1]):
597                    for k in range(arg.getShape()[2]):
598                       for l in range(arg.getShape()[3]):
599                          out+=arg[i,j,k,l]**2
600              return sqrt(out)
601           else:
602              raise SystemError,"length is not been fully implemented yet"
603              # return arg.length()
604        elif isinstance(arg,float):
605           return abs(arg)
606        else:
607           return sqrt((arg**2).sum())
608    
609    def deviator(arg):
610        """
611        @param arg:
612        """
613        if isinstance(arg,escript.Data):
614            shape=arg.getShape()
615        else:
616            shape=arg.shape
617        if len(shape)!=2:
618              raise ValueError,"Deviator requires rank 2 object"
619        if shape[0]!=shape[1]:
620              raise ValueError,"Deviator requires a square matrix"
621        return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
622    
623    def inner(arg0,arg1):
624        """
625        @param arg0:
626        @param arg1:
627        """
628        if isinstance(arg0,escript.Data):
629           arg=arg0
630        else:
631           arg=arg1
632    
633        out=escript.Scalar(0,arg.getFunctionSpace())
634        if arg.getRank()==0:
635              return arg0*arg1
636        elif arg.getRank()==1:
637             out=escript.Scalar(0,arg.getFunctionSpace())
638             for i in range(arg.getShape()[0]):
639                out+=arg0[i]*arg1[i]
640        elif arg.getRank()==2:
641            out=escript.Scalar(0,arg.getFunctionSpace())
642            for i in range(arg.getShape()[0]):
643               for j in range(arg.getShape()[1]):
644                  out+=arg0[i,j]*arg1[i,j]
645        elif arg.getRank()==3:
646            out=escript.Scalar(0,arg.getFunctionSpace())
647            for i in range(arg.getShape()[0]):
648                for j in range(arg.getShape()[1]):
649                   for k in range(arg.getShape()[2]):
650                      out+=arg0[i,j,k]*arg1[i,j,k]
651        elif arg.getRank()==4:
652            out=escript.Scalar(0,arg.getFunctionSpace())
653            for i in range(arg.getShape()[0]):
654               for j in range(arg.getShape()[1]):
655                  for k in range(arg.getShape()[2]):
656                     for l in range(arg.getShape()[3]):
657                        out+=arg0[i,j,k,l]*arg1[i,j,k,l]
658        else:
659              raise SystemError,"inner is not been implemented yet"
660        return out
661    
662    def tensormult(arg0,arg1):
663        # check LinearPDE!!!!
664        raise SystemError,"tensormult is not implemented yet!"
665    
666    def matrixmult(arg0,arg1):
667    
668        if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
669            numarray.matrixmult(arg0,arg1)
670        else:
671          # escript.matmult(arg0,arg1)
672          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
673            arg0=escript.Data(arg0,arg1.getFunctionSpace())
674          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
675            arg1=escript.Data(arg1,arg0.getFunctionSpace())
676          if arg0.getRank()==2 and arg1.getRank()==1:
677              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
678              for i in range(arg0.getShape()[0]):
679                 for j in range(arg0.getShape()[1]):
680                   # uses Data object slicing, plus Data * and += operators
681                   out[i]+=arg0[i,j]*arg1[j]
682              return out
683          elif arg0.getRank()==1 and arg1.getRank()==1:
684              return inner(arg0,arg1)
685          else:
686              raise SystemError,"matrixmult is not fully implemented yet!"
687    
688    #=========================================================
689    # reduction operations:
690    #=========================================================
691    def sum(arg):
692        """
693        @param arg:
694        """
695        return arg.sum()
696    
697    def sup(arg):
698        """
699        @param arg:
700        """
701        if isinstance(arg,escript.Data):
702           return arg.sup()
703        elif isinstance(arg,float) or isinstance(arg,int):
704           return arg
705        else:
706           return arg.max()
707    
708    def inf(arg):
709        """
710        @param arg:
711        """
712        if isinstance(arg,escript.Data):
713           return arg.inf()
714        elif isinstance(arg,float) or isinstance(arg,int):
715           return arg
716        else:
717           return arg.min()
718    
719    def L2(arg):
720        """
721        Returns the L2-norm of the argument
722    
723        @param arg:
724        """
725        if isinstance(arg,escript.Data):
726           return arg.L2()
727        elif isinstance(arg,float) or isinstance(arg,int):
728           return abs(arg)
729        else:
730           return numarry.sqrt(dot(arg,arg))
731    
732    def Lsup(arg):
733        """
734        @param arg:
735        """
736        if isinstance(arg,escript.Data):
737           return arg.Lsup()
738        elif isinstance(arg,float) or isinstance(arg,int):
739           return abs(arg)
740        else:
741           return numarray.abs(arg).max()
742    
743    def dot(arg0,arg1):
744        """
745        @param arg0:
746        @param arg1:
747        """
748        if isinstance(arg0,escript.Data):
749           return arg0.dot(arg1)
750        elif isinstance(arg1,escript.Data):
751           return arg1.dot(arg0)
752        else:
753           return numarray.dot(arg0,arg1)
754    
755    def kronecker(d):
756       if hasattr(d,"getDim"):
757          return numarray.identity(d.getDim())*1.
758       else:
759          return numarray.identity(d)*1.
760    
761    def unit(i,d):
762       """
763       Return a unit vector of dimension d with nonzero index i.
764    
765       @param d: dimension
766       @param i: index
767       """
768       e = numarray.zeros((d,),numarray.Float)
769       e[i] = 1.0
770       return e

Legend:
Removed from v.82  
changed lines
  Added in v.153

  ViewVC Help
Powered by ViewVC 1.1.26