/[escript]/trunk/escript/py_src/generateutil
ViewVC logotype

Diff of /trunk/escript/py_src/generateutil

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 157 by gross, Wed Nov 9 10:01:06 2005 UTC revision 313 by gross, Mon Dec 5 07:01:36 2005 UTC
# Line 1  Line 1 
1  #!/usr/bin/python  #!/usr/bin/python
2  # $Id:$  # $Id$
3    
4  """  """
5  program generates parts of the util.py and the test_util.py script  program generates parts of the util.py and the test_util.py script
# Line 33  def wherepos(arg): Line 33  def wherepos(arg):
33     else:     else:
34        return 0.        return 0.
35    
36    
37  class OPERATOR:  class OPERATOR:
38      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,
39                  numarray_expr="",symbol_expr=None,diff=None,name=""):                  numarray_expr="",symbol_expr=None,diff=None,name=""):
# Line 431  def mkText(case,name,a,a1=None,use_taggi Line 432  def mkText(case,name,a,a1=None,use_taggi
432           if case=="float":           if case=="float":
433             if isinstance(a,float):             if isinstance(a,float):
434                  t_out+="      %s=%s\n"%(name,a)                  t_out+="      %s=%s\n"%(name,a)
435             elif len(a)==1:             elif a.rank==0:
436                  t_out+="      %s=%s\n"%(name,a)                  t_out+="      %s=%s\n"%(name,a)
437             else:             else:
438                  t_out+="      %s=numarray.array(%s)\n"%(name,a.tolist())                  t_out+="      %s=numarray.array(%s)\n"%(name,a.tolist())
439           elif case=="array":           elif case=="array":
440             if isinstance(a,float):             if isinstance(a,float):
441                  t_out+="      %s=numarray.array(%s)\n"%(name,a)                  t_out+="      %s=numarray.array(%s)\n"%(name,a)
442             elif len(a)==1:             elif a.rank==0:
443                  t_out+="      %s=numarray.array(%s)\n"%(name,a)                  t_out+="      %s=numarray.array(%s)\n"%(name,a)
444             else:             else:
445                  t_out+="      %s=numarray.array(%s)\n"%(name,a.tolist())                  t_out+="      %s=numarray.array(%s)\n"%(name,a.tolist())
446           elif case=="constData":           elif case=="constData":
447             if isinstance(a,float):             if isinstance(a,float):
448                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)
449             elif len(a)==1:             elif a.rank==0:
450                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)
451             else:             else:
452                t_out+="      %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist())                t_out+="      %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist())
# Line 453  def mkText(case,name,a,a1=None,use_taggi Line 454  def mkText(case,name,a,a1=None,use_taggi
454             if isinstance(a,float):             if isinstance(a,float):
455                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)
456                t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)                t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)
457             elif len(a)==1:             elif a.rank==0:
458                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)                t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)
459                t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)                t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)
460             else:             else:
# Line 464  def mkText(case,name,a,a1=None,use_taggi Line 465  def mkText(case,name,a,a1=None,use_taggi
465                if isinstance(a,float):                if isinstance(a,float):
466                   t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)                   t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)
467                   t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)                   t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)
468                elif len(a)==1:                elif a.rank==0:
469                   t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)                   t_out+="      %s=Data(%s,self.functionspace)\n"%(name,a)
470                   t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)                   t_out+="      %s.setTaggedValue(1,%s)\n"%(name,a1)
471                else:                else:
# Line 475  def mkText(case,name,a,a1=None,use_taggi Line 476  def mkText(case,name,a,a1=None,use_taggi
476                t_out+="      msk_%s=whereNegative(self.functionspace.getX()[0]-0.5)\n"%name                t_out+="      msk_%s=whereNegative(self.functionspace.getX()[0]-0.5)\n"%name
477                if isinstance(a,float):                if isinstance(a,float):
478                     t_out+="      %s=msk_%s*(%s)+(1.-msk_%s)*(%s)\n"%(name,name,a,name,a1)                     t_out+="      %s=msk_%s*(%s)+(1.-msk_%s)*(%s)\n"%(name,name,a,name,a1)
479                elif len(a)==1:                elif a.rank==0:
480                     t_out+="      %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a,name,a1)                     t_out+="      %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a,name,a1)
481                else:                else:
482                     t_out+="      %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a.tolist(),name,a1.tolist())                     t_out+="      %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a.tolist(),name,a1.tolist())
483           elif case=="Symbol":           elif case=="Symbol":
484             if isinstance(a,float):             if isinstance(a,float):
485                t_out+="      %s=Symbol(shape=())\n"%(name)                t_out+="      %s=Symbol(shape=())\n"%(name)
486             elif len(a)==1:             elif a.rank==0:
487                t_out+="      %s=Symbol(shape=())\n"%(name)                t_out+="      %s=Symbol(shape=())\n"%(name)
488             else:             else:
489                t_out+="      %s=Symbol(shape=%s)\n"%(name,str(a.shape))                t_out+="      %s=Symbol(shape=%s)\n"%(name,str(a.shape))
# Line 517  def mkCode(txt,args=[],intend=""): Line 518  def mkCode(txt,args=[],intend=""):
518        out=out.replace("%%a%s%%"%c,r)        out=out.replace("%%a%s%%"%c,r)
519      return out        return out  
520    
521    def innerTEST(arg0,arg1):
522        if isinstance(arg0,float):
523           out=numarray.array(arg0*arg1)
524        else:
525           out=(arg0*arg1).sum()
526        return out
527    
528    def outerTEST(arg0,arg1):
529        if isinstance(arg0,float):
530           out=numarray.array(arg0*arg1)
531        elif isinstance(arg1,float):
532           out=numarray.array(arg0*arg1)
533        else:      
534           out=numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)
535        return out
536    
537    def tensorProductTest(arg0,arg1,sh_s):
538        if isinstance(arg0,float):
539           out=numarray.array(arg0*arg1)
540        elif isinstance(arg1,float):
541           out=numarray.array(arg0*arg1)
542        elif len(sh_s)==0:
543           out=numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)
544        else:
545           l=len(sh_s)
546           sh0=arg0.shape[:arg0.rank-l]
547           sh1=arg1.shape[l:]
548           ls,l0,l1=1,1,1
549           for i in sh_s: ls*=i
550           for i in sh0: l0*=i
551           for i in sh1: l1*=i
552           out1=numarray.outerproduct(arg0,arg1).resize((l0,ls,ls,l1))
553           out2=numarray.zeros((l0,l1),numarray.Float)
554           for i0 in range(l0):
555              for i1 in range(l1):
556                  for i in range(ls): out2[i0,i1]+=out1[i0,i,i,i1]
557           out=out2.resize(sh0+sh1)
558        return out
559          
560    def testMatrixMult(arg0,arg1,sh_s):
561         return numarray.matrixmultiply(arg0,arg1)
562    
563    
564    def testTensorMult(arg0,arg1,sh_s):
565         if len(arg0)==2:
566            return numarray.matrixmultiply(arg0,arg1)
567         else:
568            if arg1.rank==4:
569              out=numarray.zeros((arg0.shape[0],arg0.shape[1],arg1.shape[2],arg1.shape[3]),numarray.Float)
570              for i0 in range(arg0.shape[0]):
571               for i1 in range(arg0.shape[1]):
572                for i2 in range(arg0.shape[2]):
573                 for i3 in range(arg0.shape[3]):
574                  for j2 in range(arg1.shape[2]):
575                   for j3 in range(arg1.shape[3]):
576                         out[i0,i1,j2,j3]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2,j3]
577            elif arg1.rank==3:
578              out=numarray.zeros((arg0.shape[0],arg0.shape[1],arg1.shape[2]),numarray.Float)
579              for i0 in range(arg0.shape[0]):
580               for i1 in range(arg0.shape[1]):
581                for i2 in range(arg0.shape[2]):
582                 for i3 in range(arg0.shape[3]):
583                  for j2 in range(arg1.shape[2]):
584                         out[i0,i1,j2]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2]
585            elif arg1.rank==2:
586              out=numarray.zeros((arg0.shape[0],arg0.shape[1]),numarray.Float)
587              for i0 in range(arg0.shape[0]):
588               for i1 in range(arg0.shape[1]):
589                for i2 in range(arg0.shape[2]):
590                 for i3 in range(arg0.shape[3]):
591                         out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
592            return out
593    
594    def testReduce(arg0,init_val,test_expr,post_expr):
595         out=init_val
596         if isinstance(arg0,float):
597              out=eval(test_expr.replace("%a1%","arg0"))
598         elif arg0.rank==0:
599              out=eval(test_expr.replace("%a1%","arg0"))
600         elif arg0.rank==1:
601            for i0 in range(arg0.shape[0]):
602                   out=eval(test_expr.replace("%a1%","arg0[i0]"))
603         elif arg0.rank==2:
604            for i0 in range(arg0.shape[0]):
605             for i1 in range(arg0.shape[1]):
606                   out=eval(test_expr.replace("%a1%","arg0[i0,i1]"))
607         elif arg0.rank==3:
608            for i0 in range(arg0.shape[0]):
609             for i1 in range(arg0.shape[1]):
610               for i2 in range(arg0.shape[2]):
611                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]"))
612         elif arg0.rank==4:
613            for i0 in range(arg0.shape[0]):
614             for i1 in range(arg0.shape[1]):
615               for i2 in range(arg0.shape[2]):
616                 for i3 in range(arg0.shape[3]):
617                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]"))          
618         return eval(post_expr)
619  #=======================================================================================================  #=======================================================================================================
620  # basic binary operations (tests only!)  # local reduction
621  #=======================================================================================================  #=======================================================================================================
622  oper_range=[-5.,5.]  for oper in [["length",0.,"out+%a1%**2","math.sqrt(out)"],
623  for oper in [["add" ,"+",[-5.,5.]],               ["maxval",-1.e99,"max(out,%a1%)","out"],
624               ["mult","*",[-5.,5.]],               ["minval",1.e99,"min(out,%a1%)","out"] ]:
625               ["quotient" ,"/",[-5.,5.]],    for case in case_set:
626               ["power" ,"**",[0.01,5.]]]:       for sh in shape_set:
627     for case0 in case_set:         if not case=="float" or len(sh)==0:
628       for case1 in case_set:           text=""
629         for sh in shape_set:           text+="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
630           for sh_p in shape_set:           tname="def test_%s_%s_rank%s"%(oper[0],case,len(sh))
631             if len(sh_p)>0:           text+="   %s(self):\n"%tname
632                resource=[-1,1]           a=makeArray(sh,[-1.,1.])            
633             else:           a1=makeArray(sh,[-1.,1.])
634                resource=[1]           r1=testReduce(a1,oper[1],oper[2],oper[3])
635             for sh_d in resource:           r=testReduce(a,oper[1],oper[2],oper[3])
636              if sh_d>0:          
637                 sh0=sh           text+=mkText(case,"arg",a,a1)
638                 sh1=sh+sh_p           text+="      res=%s(arg)\n"%oper[0]
639              else:           if case=="Symbol":        
640                 sh1=sh               text+=mkText("array","s",a,a1)
641                 sh0=sh+sh_p               text+="      sub=res.substitute({arg:s})\n"        
642                             text+=mkText("array","ref",r,r1)
643              if (not case0=="float" or len(sh0)==0) and (not case1=="float" or len(sh1)==0) and \               res="sub"
644                 len(sh0)<5 and len(sh1)<5:           else:
645                 text+=mkText(case,"ref",r,r1)
646                 res="res"
647             if oper[0]=="length":              
648                   text+=mkTypeAndShapeTest(case,(),"res")
649             else:            
650                if case=="float" or case=="array":        
651                   text+=mkTypeAndShapeTest("float",(),"res")
652                else:          
653                   text+=mkTypeAndShapeTest(case,(),"res")
654             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
655             if case == "taggedData":
656               t_prog_with_tags+=text
657             else:
658               t_prog+=text
659    print test_header
660    # print t_prog
661    print t_prog_with_tags
662    print test_tail          
663    1/0
664              
665    #=======================================================================================================
666    # tensor multiply
667    #=======================================================================================================
668    # oper=["generalTensorProduct",tensorProductTest]
669    # oper=["matrixmult",testMatrixMult]
670    oper=["tensormult",testTensorMult]
671    
672    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
673      for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
674       for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
675         for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
676           for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:
677              if (len(sh0+sh_s)==0 or not case0=="float") and (len(sh1+sh_s)==0 or not case1=="float") \
678                   and len(sh0+sh1)<5 and len(sh0+sh_s)<5 and len(sh1+sh_s)<5:
679                # if len(sh_s)==1 and len(sh0+sh_s)==2 and (len(sh_s+sh1)==1 or len(sh_s+sh1)==2)): # test for matrixmult
680                if ( len(sh_s)==1 and len(sh0+sh_s)==2 and ( len(sh1+sh_s)==2 or len(sh1+sh_s)==1 )) or (len(sh_s)==2 and len(sh0+sh_s)==4 and (len(sh1+sh_s)==2 or len(sh1+sh_s)==3 or len(sh1+sh_s)==4)):  # test for tensormult
681                  case=getResultCaseForBin(case0,case1)  
682                use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"                use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
683                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
684                  # tname="test_generalTensorProduct_%s_rank%s_%s_rank%s_offset%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1),len(sh_s))
685                  #tname="test_matrixmult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))
686                  tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))
687                  # if tname=="test_generalTensorProduct_array_rank1_array_rank2_offset1":
688                  # print tnametest_generalTensorProduct_Symbol_rank1_Symbol_rank3_offset1
689                  text+="   def %s(self):\n"%tname
690                  a_0=makeArray(sh0+sh_s,[-1.,1])
691                  if case0 in ["taggedData", "expandedData"]:
692                      a1_0=makeArray(sh0+sh_s,[-1.,1])
693                  else:
694                      a1_0=a_0
695    
696                  a_1=makeArray(sh_s+sh1,[-1.,1])
697                  if case1 in ["taggedData", "expandedData"]:
698                      a1_1=makeArray(sh_s+sh1,[-1.,1])
699                  else:
700                      a1_1=a_1
701                  r=oper[1](a_0,a_1,sh_s)
702                  r1=oper[1](a1_0,a1_1,sh_s)
703                  text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
704                  text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
705                  #text+="      res=matrixmult(arg0,arg1)\n"
706                  text+="      res=tensormult(arg0,arg1)\n"
707                  #text+="      res=generalTensorProduct(arg0,arg1,offset=%s)\n"%(len(sh_s))
708                  if case=="Symbol":
709                     c0_res,c1_res=case0,case1
710                     subs="{"
711                     if case0=="Symbol":        
712                        text+=mkText("array","s0",a_0,a1_0)
713                        subs+="arg0:s0"
714                        c0_res="array"
715                     if case1=="Symbol":        
716                        text+=mkText("array","s1",a_1,a1_1)
717                        if not subs.endswith("{"): subs+=","
718                        subs+="arg1:s1"
719                        c1_res="array"
720                     subs+="}"  
721                     text+="      sub=res.substitute(%s)\n"%subs
722                     res="sub"
723                     text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
724                  else:
725                     res="res"
726                     text+=mkText(case,"ref",r,r1)
727                  text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
728                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
729                  if case0 == "taggedData" or case1 == "taggedData":
730                      t_prog_with_tags+=text
731                  else:              
732                      t_prog+=text
733    print test_header
734    # print t_prog
735    print t_prog_with_tags
736    print test_tail          
737    1/0
738    #=======================================================================================================
739    # outer/inner
740    #=======================================================================================================
741    oper=["inner",innerTEST]
742    # oper=["outer",outerTEST]
743    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
744      for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
745       for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
746         for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
747            if (len(sh0)==0 or not case0=="float") and (len(sh1)==0 or not case1=="float") \
748               and len(sh0+sh1)<5:  
749                  use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
750    
751                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
752                tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))                tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
753                text+="   def %s(self):\n"%tname                text+="   def %s(self):\n"%tname
754                a_0=makeArray(sh0,oper[2])                a_0=makeArray(sh0,[-1.,1])
755                if case0 in ["taggedData", "expandedData"]:                if case0 in ["taggedData", "expandedData"]:
756                    a1_0=makeArray(sh0,oper[2])                    a1_0=makeArray(sh0,[-1.,1])
757                else:                else:
758                    a1_0=a_0                    a1_0=a_0
759    
760                a_1=makeArray(sh1,oper[2])                a_1=makeArray(sh1,[-1.,1])
761                if case1 in ["taggedData", "expandedData"]:                if case1 in ["taggedData", "expandedData"]:
762                    a1_1=makeArray(sh1,oper[2])                    a1_1=makeArray(sh1,[-1.,1])
763                else:                else:
764                    a1_1=a_1                    a1_1=a_1
765                r1=makeResult2(a1_0,a1_1,"%a1%"+oper[1]+"%a2%")                r=oper[1](a_0,a_1)
766                r=makeResult2(a_0,a_1,"%a1%"+oper[1]+"%a2%")                r1=oper[1](a1_0,a1_1)
767                text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)                text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
768                text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)                text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
769                text+="      res=%s(arg0,arg1)\n"%oper[0]                text+="      res=%s(arg0,arg1)\n"%oper[0]
                 
770                case=getResultCaseForBin(case0,case1)                              case=getResultCaseForBin(case0,case1)              
771                if case=="Symbol":                if case=="Symbol":
772                   c0_res,c1_res=case0,case1                   c0_res,c1_res=case0,case1
# Line 584  for oper in [["add" ,"+",[-5.,5.]], Line 787  for oper in [["add" ,"+",[-5.,5.]],
787                else:                else:
788                   res="res"                   res="res"
789                   text+=mkText(case,"ref",r,r1)                   text+=mkText(case,"ref",r,r1)
790                if isinstance(r,float):                              text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
                  text+=mkTypeAndShapeTest(case,(),"res")  
               else:  
                  text+=mkTypeAndShapeTest(case,r.shape,"res")  
791                text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res                text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
792                                
793                if case0 == "taggedData" or case1 == "taggedData":                if case0 == "taggedData" or case1 == "taggedData":
794                    t_prog_with_tags+=text                    t_prog_with_tags+=text
795                else:                              else:              
796                    t_prog+=text                    t_prog+=text
797    
798  print test_header  print test_header
799  # print t_prog  # print t_prog
800  print t_prog_with_tags  print t_prog_with_tags
801  print test_tail  print test_tail          
802  1/0  1/0
   
803  #=======================================================================================================  #=======================================================================================================
804  # basic binary operation overloading (tests only!)  # basic binary operation overloading (tests only!)
805  #=======================================================================================================  #=======================================================================================================
# Line 613  for oper in [["add" ,"+",[-5.,5.]], Line 813  for oper in [["add" ,"+",[-5.,5.]],
813       for sh0 in shape_set:       for sh0 in shape_set:
814         for case1 in case_set:         for case1 in case_set:
815           for sh1 in shape_set:           for sh1 in shape_set:
816             if (not case0=="float" or len(sh0)==0) and (not case1=="float" or len(sh1)==0) and \             if not case0=="array" and \
817                   (not case0=="float" or len(sh0)==0) and (not case1=="float" or len(sh1)==0) and \
818                 (sh0==() or sh1==() or sh1==sh0) and \                 (sh0==() or sh1==() or sh1==sh0) and \
819                 not (case0 in ["float","array"] and  case1 in ["float","array"]):                 not (case0 in ["float","array"] and  case1 in ["float","array"]):
820                  use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
821                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
822                tname="test_%s_overloaded_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))                tname="test_%s_overloaded_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
823                text+="   def %s(self):\n"%tname                text+="   def %s(self):\n"%tname
# Line 671  for oper in [["add" ,"+",[-5.,5.]], Line 873  for oper in [["add" ,"+",[-5.,5.]],
873                      t_prog+=text                      t_prog+=text
874    
875            
 # print u_prog  
 # 1/0  
876  print test_header  print test_header
877  print t_prog  # print t_prog
878    # print t_prog_with_tags
879    print t_prog_failing
880    print test_tail          
881  1/0  1/0
882    #=======================================================================================================
883    # basic binary operations (tests only!)
884    #=======================================================================================================
885    oper_range=[-5.,5.]
886    for oper in [["add" ,"+",[-5.,5.]],
887                 ["mult","*",[-5.,5.]],
888                 ["quotient" ,"/",[-5.,5.]],
889                 ["power" ,"**",[0.01,5.]]]:
890       for case0 in case_set:
891         for case1 in case_set:
892           for sh in shape_set:
893             for sh_p in shape_set:
894               if len(sh_p)>0:
895                  resource=[-1,1]
896               else:
897                  resource=[1]
898               for sh_d in resource:
899                if sh_d>0:
900                   sh0=sh
901                   sh1=sh+sh_p
902                else:
903                   sh1=sh
904                   sh0=sh+sh_p
905                
906                if (not case0=="float" or len(sh0)==0) and (not case1=="float" or len(sh1)==0) and \
907                   len(sh0)<5 and len(sh1)<5:
908                  use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
909                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
910                  tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
911                  text+="   def %s(self):\n"%tname
912                  a_0=makeArray(sh0,oper[2])
913                  if case0 in ["taggedData", "expandedData"]:
914                      a1_0=makeArray(sh0,oper[2])
915                  else:
916                      a1_0=a_0
917    
918                  a_1=makeArray(sh1,oper[2])
919                  if case1 in ["taggedData", "expandedData"]:
920                      a1_1=makeArray(sh1,oper[2])
921                  else:
922                      a1_1=a_1
923                  r1=makeResult2(a1_0,a1_1,"%a1%"+oper[1]+"%a2%")
924                  r=makeResult2(a_0,a_1,"%a1%"+oper[1]+"%a2%")
925                  text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
926                  text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
927                  text+="      res=%s(arg0,arg1)\n"%oper[0]
928                  
929                  case=getResultCaseForBin(case0,case1)              
930                  if case=="Symbol":
931                     c0_res,c1_res=case0,case1
932                     subs="{"
933                     if case0=="Symbol":        
934                        text+=mkText("array","s0",a_0,a1_0)
935                        subs+="arg0:s0"
936                        c0_res="array"
937                     if case1=="Symbol":        
938                        text+=mkText("array","s1",a_1,a1_1)
939                        if not subs.endswith("{"): subs+=","
940                        subs+="arg1:s1"
941                        c1_res="array"
942                     subs+="}"  
943                     text+="      sub=res.substitute(%s)\n"%subs
944                     res="sub"
945                     text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
946                  else:
947                     res="res"
948                     text+=mkText(case,"ref",r,r1)
949                  if isinstance(r,float):              
950                     text+=mkTypeAndShapeTest(case,(),"res")
951                  else:
952                     text+=mkTypeAndShapeTest(case,r.shape,"res")
953                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
954                  
955                  if case0 == "taggedData" or case1 == "taggedData":
956                      t_prog_with_tags+=text
957                  else:              
958                      t_prog+=text
959    print test_header
960    # print t_prog
961    print t_prog_with_tags
962    print test_tail
963    1/0
964    
965  # print t_prog_with_tagsoper_range=[-5.,5.]  # print t_prog_with_tagsoper_range=[-5.,5.]
966  for oper in [["add" ,"+",[-5.,5.]],  for oper in [["add" ,"+",[-5.,5.]],
967               ["sub" ,"-",[-5.,5.]],               ["sub" ,"-",[-5.,5.]],
# Line 1375  def X(): Line 1661  def X():
1661                                ref_diff=(makeResult(trafo[j0,j1,j2,j3]*a_in+finc,f)-makeResult(trafo[j0,j1,j2,j3]*a_in,f))/finc                                ref_diff=(makeResult(trafo[j0,j1,j2,j3]*a_in+finc,f)-makeResult(trafo[j0,j1,j2,j3]*a_in,f))/finc
1662                                t_prog+="      self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,j2,j3,ref_diff,str(sh_in),str((j0,j1,j2,j3)))                                t_prog+="      self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,j2,j3,ref_diff,str(sh_in),str((j0,j1,j2,j3)))
1663    
1664  #==================  #
 case="inner"  
 for arg0 in ["float","array","Symbol","constData","taggedData","expandedData"]:  
    for arg1 in ["float","array","Symbol","constData","taggedData","expandedData"]:  
      for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:  
         sh1=sh0  
         if (len(sh0)==0 or not arg0=="float") and (len(sh1)==0 or not arg1=="float"):    
           tname="test_%s_%s_rank%s_%s_rank%s"%(case,arg0,len(sh0),arg1,len(sh1))  
           t_prog+="   def %s(self):\n"%tname  
           a0=makeArray(sh0,[-1,1])  
           a0_1=makeArray(sh0,[-1,1])  
           a1=makeArray(sh1,[-1,1])  
           a1_1=makeArray(sh1,[-1,1])  
           t_prog+=mkText(arg0,"arg0",a0,a0_1)  
           t_prog+=mkText(arg1,"arg1",a1,a1_1)  
           t_prog+="      res=%s(arg0,arg1)\n"%case  
   
 print t_prog              
 1/0  
1665    
1666  #==================  #==================
1667  cases=["Scalar","Vector","Tensor", "Tensor3","Tensor4"]  cases=["Scalar","Vector","Tensor", "Tensor3","Tensor4"]

Legend:
Removed from v.157  
changed lines
  Added in v.313

  ViewVC Help
Powered by ViewVC 1.1.26