/[escript]/trunk/escript/py_src/generatediff
ViewVC logotype

Diff of /trunk/escript/py_src/generatediff

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 291 by gross, Fri Dec 2 03:10:06 2005 UTC revision 313 by gross, Mon Dec 5 07:01:36 2005 UTC
# Line 33  def wherepos(arg): Line 33  def wherepos(arg):
33     else:     else:
34        return 0.        return 0.
35    
36    
37  class OPERATOR:  class OPERATOR:
38      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,
39                  numarray_expr="",symbol_expr=None,diff=None,name=""):                  numarray_expr="",symbol_expr=None,diff=None,name=""):
# Line 589  def testTensorMult(arg0,arg1,sh_s): Line 590  def testTensorMult(arg0,arg1,sh_s):
590               for i3 in range(arg0.shape[3]):               for i3 in range(arg0.shape[3]):
591                       out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]                       out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
592          return out          return out
593    
594    def testReduce(arg0,init_val,test_expr,post_expr):
595         out=init_val
596         if isinstance(arg0,float):
597              out=eval(test_expr.replace("%a1%","arg0"))
598         elif arg0.rank==0:
599              out=eval(test_expr.replace("%a1%","arg0"))
600         elif arg0.rank==1:
601            for i0 in range(arg0.shape[0]):
602                   out=eval(test_expr.replace("%a1%","arg0[i0]"))
603         elif arg0.rank==2:
604            for i0 in range(arg0.shape[0]):
605             for i1 in range(arg0.shape[1]):
606                   out=eval(test_expr.replace("%a1%","arg0[i0,i1]"))
607         elif arg0.rank==3:
608            for i0 in range(arg0.shape[0]):
609             for i1 in range(arg0.shape[1]):
610               for i2 in range(arg0.shape[2]):
611                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]"))
612         elif arg0.rank==4:
613            for i0 in range(arg0.shape[0]):
614             for i1 in range(arg0.shape[1]):
615               for i2 in range(arg0.shape[2]):
616                 for i3 in range(arg0.shape[3]):
617                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]"))          
618         return eval(post_expr)
619    #=======================================================================================================
620    # local reduction
621    #=======================================================================================================
622    for oper in [["length",0.,"out+%a1%**2","math.sqrt(out)"],
623                 ["maxval",-1.e99,"max(out,%a1%)","out"],
624                 ["minval",1.e99,"min(out,%a1%)","out"] ]:
625      for case in case_set:
626         for sh in shape_set:
627           if not case=="float" or len(sh)==0:
628             text=""
629             text+="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
630             tname="def test_%s_%s_rank%s"%(oper[0],case,len(sh))
631             text+="   %s(self):\n"%tname
632             a=makeArray(sh,[-1.,1.])            
633             a1=makeArray(sh,[-1.,1.])
634             r1=testReduce(a1,oper[1],oper[2],oper[3])
635             r=testReduce(a,oper[1],oper[2],oper[3])
636            
637             text+=mkText(case,"arg",a,a1)
638             text+="      res=%s(arg)\n"%oper[0]
639             if case=="Symbol":        
640                 text+=mkText("array","s",a,a1)
641                 text+="      sub=res.substitute({arg:s})\n"        
642                 text+=mkText("array","ref",r,r1)
643                 res="sub"
644             else:
645                 text+=mkText(case,"ref",r,r1)
646                 res="res"
647             if oper[0]=="length":              
648                   text+=mkTypeAndShapeTest(case,(),"res")
649             else:            
650                if case=="float" or case=="array":        
651                   text+=mkTypeAndShapeTest("float",(),"res")
652                else:          
653                   text+=mkTypeAndShapeTest(case,(),"res")
654             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
655             if case == "taggedData":
656               t_prog_with_tags+=text
657             else:
658               t_prog+=text
659    print test_header
660    # print t_prog
661    print t_prog_with_tags
662    print test_tail          
663    1/0
664              
665  #=======================================================================================================  #=======================================================================================================
666  # tensor multiply  # tensor multiply
667  #=======================================================================================================  #=======================================================================================================

Legend:
Removed from v.291  
changed lines
  Added in v.313

  ViewVC Help
Powered by ViewVC 1.1.26