/[escript]/trunk/escript/py_src/generateutil
ViewVC logotype

Diff of /trunk/escript/py_src/generateutil

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 291 by gross, Fri Dec 2 03:10:06 2005 UTC revision 396 by gross, Wed Dec 21 05:08:25 2005 UTC
# Line 33  def wherepos(arg): Line 33  def wherepos(arg):
33     else:     else:
34        return 0.        return 0.
35    
36    
37  class OPERATOR:  class OPERATOR:
38      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,
39                  numarray_expr="",symbol_expr=None,diff=None,name=""):                  numarray_expr="",symbol_expr=None,diff=None,name=""):
# Line 589  def testTensorMult(arg0,arg1,sh_s): Line 590  def testTensorMult(arg0,arg1,sh_s):
590               for i3 in range(arg0.shape[3]):               for i3 in range(arg0.shape[3]):
591                       out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]                       out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
592          return out          return out
593    
594    def testReduce(arg0,init_val,test_expr,post_expr):
595         out=init_val
596         if isinstance(arg0,float):
597              out=eval(test_expr.replace("%a1%","arg0"))
598         elif arg0.rank==0:
599              out=eval(test_expr.replace("%a1%","arg0"))
600         elif arg0.rank==1:
601            for i0 in range(arg0.shape[0]):
602                   out=eval(test_expr.replace("%a1%","arg0[i0]"))
603         elif arg0.rank==2:
604            for i0 in range(arg0.shape[0]):
605             for i1 in range(arg0.shape[1]):
606                   out=eval(test_expr.replace("%a1%","arg0[i0,i1]"))
607         elif arg0.rank==3:
608            for i0 in range(arg0.shape[0]):
609             for i1 in range(arg0.shape[1]):
610               for i2 in range(arg0.shape[2]):
611                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]"))
612         elif arg0.rank==4:
613            for i0 in range(arg0.shape[0]):
614             for i1 in range(arg0.shape[1]):
615               for i2 in range(arg0.shape[2]):
616                 for i3 in range(arg0.shape[3]):
617                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]"))          
618         return eval(post_expr)
619        
620    def clipTEST(arg0,mn,mx):
621         if isinstance(arg0,float):
622              return max(min(arg0,mx),mn)
623         out=numarray.zeros(arg0.shape,numarray.Float64)
624         if arg0.rank==1:
625            for i0 in range(arg0.shape[0]):
626                out[i0]=max(min(arg0[i0],mx),mn)
627         elif arg0.rank==2:
628            for i0 in range(arg0.shape[0]):
629             for i1 in range(arg0.shape[1]):
630                out[i0,i1]=max(min(arg0[i0,i1],mx),mn)
631         elif arg0.rank==3:
632            for i0 in range(arg0.shape[0]):
633             for i1 in range(arg0.shape[1]):
634               for i2 in range(arg0.shape[2]):
635                  out[i0,i1,i2]=max(min(arg0[i0,i1,i2],mx),mn)
636         elif arg0.rank==4:
637            for i0 in range(arg0.shape[0]):
638             for i1 in range(arg0.shape[1]):
639               for i2 in range(arg0.shape[2]):
640                 for i3 in range(arg0.shape[3]):
641                    out[i0,i1,i2,i3]=max(min(arg0[i0,i1,i2,i3],mx),mn)
642         return out
643    def minimumTEST(arg0,arg1):
644         if isinstance(arg0,float):
645           if isinstance(arg1,float):
646              if arg0>arg1:
647                  return arg1
648              else:
649                  return arg0
650           else:
651              arg0=numarray.ones(arg1.shape)*arg0
652         else:
653           if isinstance(arg1,float):
654              arg1=numarray.ones(arg0.shape)*arg1
655         out=numarray.zeros(arg0.shape,numarray.Float64)
656         if arg0.rank==0:
657              if arg0>arg1:
658                  out=arg1
659              else:
660                  out=arg0
661         elif arg0.rank==1:
662            for i0 in range(arg0.shape[0]):
663              if arg0[i0]>arg1[i0]:
664                  out[i0]=arg1[i0]
665              else:
666                  out[i0]=arg0[i0]
667         elif arg0.rank==2:
668            for i0 in range(arg0.shape[0]):
669             for i1 in range(arg0.shape[1]):
670              if arg0[i0,i1]>arg1[i0,i1]:
671                  out[i0,i1]=arg1[i0,i1]
672              else:
673                  out[i0,i1]=arg0[i0,i1]
674         elif arg0.rank==3:
675            for i0 in range(arg0.shape[0]):
676             for i1 in range(arg0.shape[1]):
677               for i2 in range(arg0.shape[2]):
678                 if arg0[i0,i1,i2]>arg1[i0,i1,i2]:
679                  out[i0,i1,i2]=arg1[i0,i1,i2]
680                 else:
681                  out[i0,i1,i2]=arg0[i0,i1,i2]
682         elif arg0.rank==4:
683            for i0 in range(arg0.shape[0]):
684             for i1 in range(arg0.shape[1]):
685               for i2 in range(arg0.shape[2]):
686                 for i3 in range(arg0.shape[3]):
687                  if arg0[i0,i1,i2,i3]>arg1[i0,i1,i2,i3]:
688                   out[i0,i1,i2,i3]=arg1[i0,i1,i2,i3]
689                  else:
690                   out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]
691         return out
692    
693    
694  #=======================================================================================================  #=======================================================================================================
695  # tensor multiply  # clip
696  #=======================================================================================================  #=======================================================================================================
697  # oper=["generalTensorProduct",tensorProductTest]  oper_L=[["clip",clipTEST]]
698  # oper=["matrixmult",testMatrixMult]  for oper in oper_L:
699  oper=["tensormult",testTensorMult]   for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
   
 for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:  
700    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
701            if len(sh0)==0 or not case0=="float":
702                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
703                  tname="test_%s_%s_rank%s"%(oper[0],case0,len(sh0))
704                  text+="   def %s(self):\n"%tname
705                  a_0=makeArray(sh0,[-1.,1])
706                  if case0 in ["taggedData", "expandedData"]:
707                      a1_0=makeArray(sh0,[-1.,1])
708                  else:
709                      a1_0=a_0
710    
711                  r=oper[1](a_0,-0.3,0.5)
712                  r1=oper[1](a1_0,-0.3,0.5)
713                  text+=mkText(case0,"arg",a_0,a1_0)
714                  text+="      res=%s(arg,-0.3,0.5)\n"%oper[0]
715                  if case0=="Symbol":
716                     text+=mkText("array","s",a_0,a1_0)
717                     text+="      sub=res.substitute({arg:s})\n"
718                     res="sub"
719                     text+=mkText("array","ref",r,r1)
720                  else:
721                     res="res"
722                     text+=mkText(case0,"ref",r,r1)
723                  text+=mkTypeAndShapeTest(case0,sh0,"res")
724                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
725                  
726                  if case0 == "taggedData" :
727                      t_prog_with_tags+=text
728                  else:              
729                      t_prog+=text
730    
731    print test_header
732    # print t_prog
733    print t_prog_with_tags
734    print test_tail          
735    1/0
736    #=======================================================================================================
737    # maximum, minimum, clipping
738    #=======================================================================================================
739    oper_L=[ ["maximum",maximumTEST],
740             ["minimum",minimumTEST]]
741    for oper in oper_L:
742     for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
743      for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
744     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
745       for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:       for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
746         for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:          if (len(sh0)==0 or not case0=="float") and (len(sh1)==0 or not case1=="float") \
747            if (len(sh0+sh_s)==0 or not case0=="float") and (len(sh1+sh_s)==0 or not case1=="float") \             and (sh0==sh1 or len(sh0)==0 or len(sh1)==0) :
                and len(sh0+sh1)<5 and len(sh0+sh_s)<5 and len(sh1+sh_s)<5:  
             # if len(sh_s)==1 and len(sh0+sh_s)==2 and (len(sh_s+sh1)==1 or len(sh_s+sh1)==2)): # test for matrixmult  
             if ( len(sh_s)==1 and len(sh0+sh_s)==2 and ( len(sh1+sh_s)==2 or len(sh1+sh_s)==1 )) or (len(sh_s)==2 and len(sh0+sh_s)==4 and (len(sh1+sh_s)==2 or len(sh1+sh_s)==3 or len(sh1+sh_s)==4)):  # test for tensormult  
               case=getResultCaseForBin(case0,case1)    
748                use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"                use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
749    
750                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
751                # tname="test_generalTensorProduct_%s_rank%s_%s_rank%s_offset%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1),len(sh_s))                tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
               #tname="test_matrixmult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))  
               tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))  
               # if tname=="test_generalTensorProduct_array_rank1_array_rank2_offset1":  
               # print tnametest_generalTensorProduct_Symbol_rank1_Symbol_rank3_offset1  
752                text+="   def %s(self):\n"%tname                text+="   def %s(self):\n"%tname
753                a_0=makeArray(sh0+sh_s,[-1.,1])                a_0=makeArray(sh0,[-1.,1])
754                if case0 in ["taggedData", "expandedData"]:                if case0 in ["taggedData", "expandedData"]:
755                    a1_0=makeArray(sh0+sh_s,[-1.,1])                    a1_0=makeArray(sh0,[-1.,1])
756                else:                else:
757                    a1_0=a_0                    a1_0=a_0
758    
759                a_1=makeArray(sh_s+sh1,[-1.,1])                a_1=makeArray(sh1,[-1.,1])
760                if case1 in ["taggedData", "expandedData"]:                if case1 in ["taggedData", "expandedData"]:
761                    a1_1=makeArray(sh_s+sh1,[-1.,1])                    a1_1=makeArray(sh1,[-1.,1])
762                else:                else:
763                    a1_1=a_1                    a1_1=a_1
764                r=oper[1](a_0,a_1,sh_s)                r=oper[1](a_0,a_1)
765                r1=oper[1](a1_0,a1_1,sh_s)                r1=oper[1](a1_0,a1_1)
766                text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)                text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
767                text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)                text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
768                #text+="      res=matrixmult(arg0,arg1)\n"                text+="      res=%s(arg0,arg1)\n"%oper[0]
769                text+="      res=tensormult(arg0,arg1)\n"                case=getResultCaseForBin(case0,case1)              
               #text+="      res=generalTensorProduct(arg0,arg1,offset=%s)\n"%(len(sh_s))  
770                if case=="Symbol":                if case=="Symbol":
771                   c0_res,c1_res=case0,case1                   c0_res,c1_res=case0,case1
772                   subs="{"                   subs="{"
# Line 651  for case0 in ["float","array","Symbol"," Line 786  for case0 in ["float","array","Symbol","
786                else:                else:
787                   res="res"                   res="res"
788                   text+=mkText(case,"ref",r,r1)                   text+=mkText(case,"ref",r,r1)
789                text+=mkTypeAndShapeTest(case,sh0+sh1,"res")                if len(sh0)>len(sh1):
790                      text+=mkTypeAndShapeTest(case,sh0,"res")
791                  else:
792                      text+=mkTypeAndShapeTest(case,sh1,"res")
793                text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res                text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
794                  
795                if case0 == "taggedData" or case1 == "taggedData":                if case0 == "taggedData" or case1 == "taggedData":
796                    t_prog_with_tags+=text                    t_prog_with_tags+=text
797                else:                              else:              
798                    t_prog+=text                    t_prog+=text
799    
800  print test_header  print test_header
801  # print t_prog  # print t_prog
802  print t_prog_with_tags  print t_prog_with_tags
803  print test_tail            print test_tail          
804  1/0  1/0
805    
806    
807  #=======================================================================================================  #=======================================================================================================
808  # outer/inner  # outer inner
809  #=======================================================================================================  #=======================================================================================================
810  oper=["inner",innerTEST]  oper=["outer",outerTEST]
811  # oper=["outer",outerTEST]  # oper=["inner",innerTEST]
812  for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:  for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
813    for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:    for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
814     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
# Line 725  for case0 in ["float","array","Symbol"," Line 867  for case0 in ["float","array","Symbol","
867  print test_header  print test_header
868  # print t_prog  # print t_prog
869  print t_prog_with_tags  print t_prog_with_tags
870    print test_tail          
871    1/0
872    
873    #=======================================================================================================
874    # local reduction
875    #=======================================================================================================
876    for oper in [["length",0.,"out+%a1%**2","math.sqrt(out)"],
877                 ["maxval",-1.e99,"max(out,%a1%)","out"],
878                 ["minval",1.e99,"min(out,%a1%)","out"] ]:
879      for case in case_set:
880         for sh in shape_set:
881           if not case=="float" or len(sh)==0:
882             text=""
883             text+="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
884             tname="def test_%s_%s_rank%s"%(oper[0],case,len(sh))
885             text+="   %s(self):\n"%tname
886             a=makeArray(sh,[-1.,1.])            
887             a1=makeArray(sh,[-1.,1.])
888             r1=testReduce(a1,oper[1],oper[2],oper[3])
889             r=testReduce(a,oper[1],oper[2],oper[3])
890            
891             text+=mkText(case,"arg",a,a1)
892             text+="      res=%s(arg)\n"%oper[0]
893             if case=="Symbol":        
894                 text+=mkText("array","s",a,a1)
895                 text+="      sub=res.substitute({arg:s})\n"        
896                 text+=mkText("array","ref",r,r1)
897                 res="sub"
898             else:
899                 text+=mkText(case,"ref",r,r1)
900                 res="res"
901             if oper[0]=="length":              
902                   text+=mkTypeAndShapeTest(case,(),"res")
903             else:            
904                if case=="float" or case=="array":        
905                   text+=mkTypeAndShapeTest("float",(),"res")
906                else:          
907                   text+=mkTypeAndShapeTest(case,(),"res")
908             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
909             if case == "taggedData":
910               t_prog_with_tags+=text
911             else:
912               t_prog+=text
913    print test_header
914    # print t_prog
915    print t_prog_with_tags
916    print test_tail          
917    1/0
918              
919    #=======================================================================================================
920    # tensor multiply
921    #=======================================================================================================
922    # oper=["generalTensorProduct",tensorProductTest]
923    # oper=["matrixmult",testMatrixMult]
924    oper=["tensormult",testTensorMult]
925    
926    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
927      for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
928       for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
929         for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
930           for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:
931              if (len(sh0+sh_s)==0 or not case0=="float") and (len(sh1+sh_s)==0 or not case1=="float") \
932                   and len(sh0+sh1)<5 and len(sh0+sh_s)<5 and len(sh1+sh_s)<5:
933                # if len(sh_s)==1 and len(sh0+sh_s)==2 and (len(sh_s+sh1)==1 or len(sh_s+sh1)==2)): # test for matrixmult
934                if ( len(sh_s)==1 and len(sh0+sh_s)==2 and ( len(sh1+sh_s)==2 or len(sh1+sh_s)==1 )) or (len(sh_s)==2 and len(sh0+sh_s)==4 and (len(sh1+sh_s)==2 or len(sh1+sh_s)==3 or len(sh1+sh_s)==4)):  # test for tensormult
935                  case=getResultCaseForBin(case0,case1)  
936                  use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
937                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
938                  # tname="test_generalTensorProduct_%s_rank%s_%s_rank%s_offset%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1),len(sh_s))
939                  #tname="test_matrixmult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))
940                  tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))
941                  # if tname=="test_generalTensorProduct_array_rank1_array_rank2_offset1":
942                  # print tnametest_generalTensorProduct_Symbol_rank1_Symbol_rank3_offset1
943                  text+="   def %s(self):\n"%tname
944                  a_0=makeArray(sh0+sh_s,[-1.,1])
945                  if case0 in ["taggedData", "expandedData"]:
946                      a1_0=makeArray(sh0+sh_s,[-1.,1])
947                  else:
948                      a1_0=a_0
949    
950                  a_1=makeArray(sh_s+sh1,[-1.,1])
951                  if case1 in ["taggedData", "expandedData"]:
952                      a1_1=makeArray(sh_s+sh1,[-1.,1])
953                  else:
954                      a1_1=a_1
955                  r=oper[1](a_0,a_1,sh_s)
956                  r1=oper[1](a1_0,a1_1,sh_s)
957                  text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
958                  text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
959                  #text+="      res=matrixmult(arg0,arg1)\n"
960                  text+="      res=tensormult(arg0,arg1)\n"
961                  #text+="      res=generalTensorProduct(arg0,arg1,offset=%s)\n"%(len(sh_s))
962                  if case=="Symbol":
963                     c0_res,c1_res=case0,case1
964                     subs="{"
965                     if case0=="Symbol":        
966                        text+=mkText("array","s0",a_0,a1_0)
967                        subs+="arg0:s0"
968                        c0_res="array"
969                     if case1=="Symbol":        
970                        text+=mkText("array","s1",a_1,a1_1)
971                        if not subs.endswith("{"): subs+=","
972                        subs+="arg1:s1"
973                        c1_res="array"
974                     subs+="}"  
975                     text+="      sub=res.substitute(%s)\n"%subs
976                     res="sub"
977                     text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
978                  else:
979                     res="res"
980                     text+=mkText(case,"ref",r,r1)
981                  text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
982                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
983                  if case0 == "taggedData" or case1 == "taggedData":
984                      t_prog_with_tags+=text
985                  else:              
986                      t_prog+=text
987    print test_header
988    # print t_prog
989    print t_prog_with_tags
990  print test_tail            print test_tail          
991  1/0  1/0
992  #=======================================================================================================  #=======================================================================================================

Legend:
Removed from v.291  
changed lines
  Added in v.396

  ViewVC Help
Powered by ViewVC 1.1.26