/[escript]/branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
ViewVC logotype

Diff of /branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3871 by caltinay, Mon Mar 12 05:18:16 2012 UTC revision 3872 by caltinay, Fri Mar 16 00:48:46 2012 UTC
# Line 85  class Symbol(object): Line 85  class Symbol(object):
85          if 'dim' in kwargs:          if 'dim' in kwargs:
86              self._dim=kwargs.pop('dim')              self._dim=kwargs.pop('dim')
87          else:          else:
88              self._dim=2              self._dim=-1 # undefined
89    
90          if 'subs' in kwargs:          if 'subs' in kwargs:
91              self._subs=kwargs.pop('subs')              self._subs=kwargs.pop('subs')
# Line 112  class Symbol(object): Line 112  class Symbol(object):
112                  self._arr=res                  self._arr=res
113                  if isinstance(arg, Symbol):                  if isinstance(arg, Symbol):
114                      self._subs.update(arg._subs)                      self._subs.update(arg._subs)
115                      self._dim=arg._dim                      if self._dim==-1:
116                            self._dim=arg._dim
117              elif isinstance(arg, sympy.Basic):              elif isinstance(arg, sympy.Basic):
118                  self._arr=numpy.array(arg)                  self._arr=numpy.array(arg)
119              else:              else:
# Line 208  class Symbol(object): Line 209  class Symbol(object):
209          """          """
210          Returns the spatial dimensionality of this symbol.          Returns the spatial dimensionality of this symbol.
211    
212          :return: the symbol's spatial dimensionality          :return: the symbol's spatial dimensionality, or -1 if undefined
213          :rtype: ``int``          :rtype: ``int``
214          """          """
215          return self._dim          return self._dim
# Line 294  class Symbol(object): Line 295  class Symbol(object):
295          return self.lambdarepr()          return self.lambdarepr()
296    
297      def lambdarepr(self):      def lambdarepr(self):
298            """
299            """
300          from sympy.printing.lambdarepr import lambdarepr          from sympy.printing.lambdarepr import lambdarepr
301          temp_arr=numpy.empty(self.getShape(), dtype=object)          temp_arr=numpy.empty(self.getShape(), dtype=object)
302          for idx,el in numpy.ndenumerate(self._arr):          for idx,el in numpy.ndenumerate(self._arr):
# Line 432  class Symbol(object): Line 435  class Symbol(object):
435    
436      def grad(self, where=None):      def grad(self, where=None):
437          """          """
438            Returns a symbol which represents the gradient of this symbol.
439          :type where: ``Symbol``, ``FunctionSpace``          :type where: ``Symbol``, ``FunctionSpace``
440          """          """
441            if self._dim < 0:
442                raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
443          subs=self._subs          subs=self._subs
444          if isinstance(where, Symbol):          if isinstance(where, Symbol):
445              if where.getRank()>0:              if where.getRank()>0:
# Line 524  class Symbol(object): Line 530  class Symbol(object):
530          if isinstance(other, Symbol):          if isinstance(other, Symbol):
531              arg1_c=other._arr.copy()              arg1_c=other._arr.copy()
532              sh1=other.getShape()              sh1=other.getShape()
533                dim=other._dim if self._dim < 0 else self._dim
534          else:          else:
535              arg1_c=other.copy()              arg1_c=other.copy()
536              sh1=other.shape              sh1=other.shape
537                dim=self._dim
538          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
539          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
540          for i in sh1[axis_offset:]: d1*=i          for i in sh1[axis_offset:]: d1*=i
# Line 540  class Symbol(object): Line 548  class Symbol(object):
548          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])
549          subs=self._subs.copy()          subs=self._subs.copy()
550          subs.update(other._subs)          subs.update(other._subs)
551          return Symbol(out, dim=self._dim, subs=subs)          return Symbol(out, dim=dim, subs=subs)
552    
553      def transposedTensorProduct(self, other, axis_offset):      def transposedTensorProduct(self, other, axis_offset):
554          """          """
# Line 550  class Symbol(object): Line 558  class Symbol(object):
558          if isinstance(other, Symbol):          if isinstance(other, Symbol):
559              arg1_c=other._arr.copy()              arg1_c=other._arr.copy()
560              sh1=other.getShape()              sh1=other.getShape()
561                dim=other._dim if self._dim < 0 else self._dim
562          else:          else:
563              arg1_c=other.copy()              arg1_c=other.copy()
564              sh1=other.shape              sh1=other.shape
565                dim=self._dim
566          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
567          for i in sh0[axis_offset:]: d0*=i          for i in sh0[axis_offset:]: d0*=i
568          for i in sh1[axis_offset:]: d1*=i          for i in sh1[axis_offset:]: d1*=i
# Line 566  class Symbol(object): Line 576  class Symbol(object):
576          out.resize(sh0[axis_offset:]+sh1[axis_offset:])          out.resize(sh0[axis_offset:]+sh1[axis_offset:])
577          subs=self._subs.copy()          subs=self._subs.copy()
578          subs.update(other._subs)          subs.update(other._subs)
579          return Symbol(out, dim=self._dim, subs=subs)          return Symbol(out, dim=dim, subs=subs)
580    
581      def tensorTransposedProduct(self, other, axis_offset):      def tensorTransposedProduct(self, other, axis_offset):
582          """          """
# Line 577  class Symbol(object): Line 587  class Symbol(object):
587              arg1_c=other._arr.copy()              arg1_c=other._arr.copy()
588              sh1=other.getShape()              sh1=other.getShape()
589              r1=other.getRank()              r1=other.getRank()
590                dim=other._dim if self._dim < 0 else self._dim
591          else:          else:
592              arg1_c=other.copy()              arg1_c=other.copy()
593              sh1=other.shape              sh1=other.shape
594              r1=other.ndim              r1=other.ndim
595                dim=self._dim
596          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
597          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
598          for i in sh1[:r1-axis_offset]: d1*=i          for i in sh1[:r1-axis_offset]: d1*=i
# Line 594  class Symbol(object): Line 606  class Symbol(object):
606          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])
607          subs=self._subs.copy()          subs=self._subs.copy()
608          subs.update(other._subs)          subs.update(other._subs)
609          return Symbol(out, dim=self._dim, subs=subs)          return Symbol(out, dim=dim, subs=subs)
610    
611      def trace(self, axis_offset):      def trace(self, axis_offset):
612          """          """
613            Returns the trace of this Symbol.
614          """          """
615          sh=self.getShape()          sh=self.getShape()
616          s1=1          s1=1
# Line 615  class Symbol(object): Line 628  class Symbol(object):
628    
629      def transpose(self, axis_offset):      def transpose(self, axis_offset):
630          """          """
631            Returns the transpose of this Symbol.
632          """          """
633          if axis_offset is None:          if axis_offset is None:
634              axis_offset=int(self._arr.ndim/2)              axis_offset=int(self._arr.ndim/2)
# Line 623  class Symbol(object): Line 637  class Symbol(object):
637    
638      def applyfunc(self, f, on_type=None):      def applyfunc(self, f, on_type=None):
639          """          """
640            Applies the function `f` to all elements (if on_type is None) or to
641            all elements of type `on_type`.
642          """          """
643          assert callable(f)          assert callable(f)
644          if self._arr.ndim==0:          if self._arr.ndim==0:
# Line 646  class Symbol(object): Line 662  class Symbol(object):
662    
663      def expand(self):      def expand(self):
664          """          """
665            Applies the sympy.expand operation on all elements in this symbol
666          """          """
667          return self.applyfunc(sympy.expand, sympy.Basic)          return self.applyfunc(sympy.expand, sympy.Basic)
668    
669      def simplify(self):      def simplify(self):
670          """          """
671            Applies the sympy.simplify operation on all elements in this symbol
672          """          """
673          return self.applyfunc(sympy.simplify, sympy.Basic)          return self.applyfunc(sympy.simplify, sympy.Basic)
674    
675      # unary/binary operations follow      # unary/binary operators follow
676    
677      def __pos__(self):      def __pos__(self):
678          return self          return self
# Line 685  class Symbol(object): Line 703  class Symbol(object):
703              raise TypeError("Incompatible shapes for operation")              raise TypeError("Incompatible shapes for operation")
704    
705      def __binaryop(self, op, other):      def __binaryop(self, op, other):
706            """
707            Helper for binary operations that checks types, shapes etc.
708            """
709          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
710          if isinstance(other, Symbol):          if isinstance(other, Symbol):
711              subs=self._subs.copy()              subs=self._subs.copy()
712              subs.update(other._subs)              subs.update(other._subs)
713              return Symbol(getattr(self._arr, op)(other._arr), dim=self._dim, subs=subs)              dim=other._dim if self._dim < 0 else self._dim
714                return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=subs)
715          if isinstance(other, Data):          if isinstance(other, Data):
716              name='data'+str(id(other))              name='data'+str(id(other))
717              othersym=Symbol(name, other.getShape(), dim=self._dim)              othersym=Symbol(name, other.getShape(), dim=self._dim)
# Line 728  class Symbol(object): Line 750  class Symbol(object):
750      def __rpow__(self, other):      def __rpow__(self, other):
751          return self.__binaryop('__rpow__', other)          return self.__binaryop('__rpow__', other)
752    
753        # static methods
754    
755      @staticmethod      @staticmethod
756      def _symComp(sym):      def _symComp(sym):
757          """          """

Legend:
Removed from v.3871  
changed lines
  Added in v.3872

  ViewVC Help
Powered by ViewVC 1.1.26