/[escript]/trunk/escript/py_src/generatediff
ViewVC logotype

Diff of /trunk/escript/py_src/generatediff

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 438 by gross, Fri Jan 20 00:39:00 2006 UTC revision 493 by gross, Fri Feb 3 02:18:45 2006 UTC
# Line 696  def minimumTEST(arg0,arg1): Line 696  def minimumTEST(arg0,arg1):
696                else:                else:
697                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]
698       return out       return out
699  def unrollLoops(a,b,o,arg,tap=""):      
700    def unrollLoops(a,b,o,arg,tap="",x="x"):
701      out=""      out=""
702      if a.rank==1:      if a.rank==1:
703               z=""               z=""
704               for i99 in range(a.shape[0]):               for i99 in range(a.shape[0]):
705                 if not z=="": z+="+"                 if not z=="": z+="+"
706                 if o=="1":                 if o=="1":
707                    z+="(%s)*x[%s]"%(a[i99]+b[i99],i99)                    z+="(%s)*%s[%s]"%(a[i99]+b[i99],x,i99)
708                 else:                 else:
709                    z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i99],i99,b[i99],i99)                    z+="(%s)*%s[%s]**o+(%s)*%s[%s]"%(a[i99],x,i99,b[i99],x,i99)
710    
711               out+=tap+"%s=%s\n"%(arg,z)               out+=tap+"%s=%s\n"%(arg,z)
712    
# Line 717  def unrollLoops(a,b,o,arg,tap=""): Line 718  def unrollLoops(a,b,o,arg,tap=""):
718                 if o=="1":                 if o=="1":
719                    z+="(%s)*x[%s]"%(a[i0,i99]+b[i0,i99],i99)                    z+="(%s)*x[%s]"%(a[i0,i99]+b[i0,i99],i99)
720                 else:                 else:
721                    z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i99],i99,b[i0,i99],i99)                    z+="(%s)*%s[%s]**o+(%s)*%s[%s]"%(a[i0,i99],x,i99,b[i0,i99],x,i99)
722    
723               out+=tap+"%s[%s]=%s\n"%(arg,i0,z)               out+=tap+"%s[%s]=%s\n"%(arg,i0,z)
724      elif a.rank==3:      elif a.rank==3:
# Line 727  def unrollLoops(a,b,o,arg,tap=""): Line 728  def unrollLoops(a,b,o,arg,tap=""):
728               for i99 in range(a.shape[2]):               for i99 in range(a.shape[2]):
729                 if not z=="": z+="+"                 if not z=="": z+="+"
730                 if o=="1":                 if o=="1":
731                    z+="(%s)*x[%s]"%(a[i0,i1,i99]+b[i0,i1,i99],i99)                    z+="(%s)*%s[%s]"%(a[i0,i1,i99]+b[i0,i1,i99],x,i99)
732                 else:                 else:
733                    z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i1,i99],i99,b[i0,i1,i99],i99)                    z+="(%s)*%s[%s]**o+(%s)*%s[%s]"%(a[i0,i1,i99],x,i99,b[i0,i1,i99],x,i99)
734    
735               out+=tap+"%s[%s,%s]=%s\n"%(arg,i0,i1,z)               out+=tap+"%s[%s,%s]=%s\n"%(arg,i0,i1,z)
736      elif a.rank==4:      elif a.rank==4:
# Line 740  def unrollLoops(a,b,o,arg,tap=""): Line 741  def unrollLoops(a,b,o,arg,tap=""):
741               for i99 in range(a.shape[3]):               for i99 in range(a.shape[3]):
742                 if not z=="": z+="+"                 if not z=="": z+="+"
743                 if o=="1":                 if o=="1":
744                    z+="(%s)*x[%s]"%(a[i0,i1,i2,i99]+b[i0,i1,i2,i99],i99)                    z+="(%s)*%s[%s]"%(a[i0,i1,i2,i99]+b[i0,i1,i2,i99],x,i99)
745                 else:                 else:
746                    z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99],i99)                    z+="(%s)*%s[%s]**o+(%s)*%s[%s]"%(a[i0,i1,i2,i99],x,i99,b[i0,i1,i2,i99],x,i99)
747    
748               out+=tap+"%s[%s,%s,%s]=%s\n"%(arg,i0,i1,i2,z)               out+=tap+"%s[%s,%s,%s]=%s\n"%(arg,i0,i1,i2,z)
749      elif a.rank==5:      elif a.rank==5:
# Line 754  def unrollLoops(a,b,o,arg,tap=""): Line 755  def unrollLoops(a,b,o,arg,tap=""):
755               for i99 in range(a.shape[4]):               for i99 in range(a.shape[4]):
756                 if not z=="": z+="+"                 if not z=="": z+="+"
757                 if o=="1":                 if o=="1":
758                    z+="(%s)*x[%s]"%(a[i0,i1,i2,i3,i99]+b[i0,i1,i2,i3,i99],i99)                    z+="(%s)*%s[%s]"%(a[i0,i1,i2,i3,i99]+b[i0,i1,i2,i3,i99],x,i99)
759                 else:                 else:
760                    z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i1,i2,i3,i99],i99,b[i0,i1,i2,i3,i99],i99)                    z+="(%s)*%s[%s]**o+(%s)*%s[%s]"%(a[i0,i1,i2,i3,i99],x,i99,b[i0,i1,i2,i3,i99],x,i99)
761    
762               out+=tap+"%s[%s,%s,%s,%s]=%s\n"%(arg,i0,i1,i2,i3,z)               out+=tap+"%s[%s,%s,%s,%s]=%s\n"%(arg,i0,i1,i2,i3,z)
763      return out      return out
# Line 797  def unrollLoopsOfGrad(a,b,o,arg,tap=""): Line 798  def unrollLoopsOfGrad(a,b,o,arg,tap=""):
798                 else:                 else:
799                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])
800      return out      return out
801    def unrollLoopsOfDiv(a,b,o,arg,tap=""):
802        out=tap+arg+"="
803        if o=="1":
804           z=0.
805           for i99 in range(a.shape[0]):
806                z+=b[i99,i99]+a[i99,i99]
807           out+="(%s)"%z    
808        else:
809           z=0.
810           for i99 in range(a.shape[0]):
811                z+=b[i99,i99]
812                if i99>0: out+="+"
813                out+="o*(%s)*x_ref[%s]**(o-1)"%(a[i99,i99],i99)
814           out+="+(%s)"%z    
815        return out
816    
817  def unrollLoopsOfInteriorIntegral(a,b,where,arg,tap=""):  def unrollLoopsOfInteriorIntegral(a,b,where,arg,tap=""):
818      if where=="Function":      if where=="Function":
819         xfac_o=1.         xfac_o=1.
# Line 916  def unrollLoopsOfInteriorIntegral(a,b,wh Line 933  def unrollLoopsOfInteriorIntegral(a,b,wh
933                 out+="+(%s)*0.5**o\n"%zop                 out+="+(%s)*0.5**o\n"%zop
934    
935      return out      return out
936    def unrollLoopsSimplified(b,arg,tap=""):
937        out=""
938        if isinstance(b,float) or b.rank==0:
939                 out+=tap+"%s=(%s)*x[0]\n"%(arg,str(b))
940    
941        elif b.rank==1:
942            for i0 in range(b.shape[0]):
943                 out+=tap+"%s[%s]=(%s)*x[%s]\n"%(arg,i0,b[i0],i0)
944        elif b.rank==2:
945            for i0 in range(b.shape[0]):
946             for i1 in range(b.shape[1]):
947                 out+=tap+"%s[%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,b[i0,i1],i1)
948        elif b.rank==3:
949            for i0 in range(b.shape[0]):
950             for i1 in range(b.shape[1]):
951               for i2 in range(b.shape[2]):
952                 out+=tap+"%s[%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,b[i0,i1,i2],i2)
953        elif b.rank==4:
954            for i0 in range(b.shape[0]):
955             for i1 in range(b.shape[1]):
956               for i2 in range(b.shape[2]):
957                for i3 in range(b.shape[3]):
958                 out+=tap+"%s[%s,%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,i3,b[i0,i1,i2,i3],i3)
959        return out
960    
961    def unrollLoopsOfL2(b,where,arg,tap=""):
962        out=""
963        z=[]
964        if isinstance(b,float) or b.rank==0:
965           z.append(b**2)
966        elif b.rank==1:
967            for i0 in range(b.shape[0]):
968                 z.append(b[i0]**2)
969        elif b.rank==2:
970            for i1 in range(b.shape[1]):
971               s=0
972               for i0 in range(b.shape[0]):
973                  s+=b[i0,i1]**2
974               z.append(s)
975        elif b.rank==3:
976            for i2 in range(b.shape[2]):
977              s=0
978              for i0 in range(b.shape[0]):
979                 for i1 in range(b.shape[1]):
980                    s+=b[i0,i1,i2]**2
981              z.append(s)
982    
983        elif b.rank==4:
984          for i3 in range(b.shape[3]):
985             s=0
986             for i0 in range(b.shape[0]):
987               for i1 in range(b.shape[1]):
988                  for i2 in range(b.shape[2]):
989                     s+=b[i0,i1,i2,i3]**2
990             z.append(s)        
991        if where=="Function":
992           xfac_o=1.
993           xfac_op=0.
994           z_fac_s=""
995           zo_fac_s=""
996           zo_fac=1./3.
997        elif where=="FunctionOnBoundary":
998           xfac_o=1.
999           xfac_op=0.
1000           z_fac_s="*dim"
1001           zo_fac_s="*(2.*dim+1.)/3."
1002           zo_fac=1.
1003        elif where in ["FunctionOnContactZero","FunctionOnContactOne"]:
1004           xfac_o=0.
1005           xfac_op=1.
1006           z_fac_s=""
1007           zo_fac_s=""    
1008           zo_fac=1./3.    
1009        zo=0.
1010        zop=0.
1011        for i99 in range(len(z)):
1012               if i99==0:
1013                   zo+=xfac_o*z[i99]
1014                   zop+=xfac_op*z[i99]
1015               else:
1016                   zo+=z[i99]
1017        out+=tap+"%s=sqrt((%s)%s"%(arg,zo*zo_fac,zo_fac_s)
1018        if zop==0.:
1019           out+=")\n"
1020        else:
1021           out+="+(%s))\n"%(zop*0.5**2)
1022        return out
1023    #=======================================================================================================
1024    # transpose
1025    #=======================================================================================================
1026    def transposeTest(r,offset):
1027        if isinstance(r,float): return r
1028        s=r.shape
1029        s1=1
1030        for i in s[:offset]: s1*=i
1031        s2=1
1032        for i in s[offset:]: s2*=i
1033        out=numarray.reshape(r,(s1,s2))
1034        out.transpose()
1035        return numarray.resize(out,s[offset:]+s[:offset])
1036    
1037    name,tt="transpose",transposeTest
1038    for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
1039      for sh0 in [ (), (3,), (4,5), (6,2,2),(3,2,3,4)]:
1040        for offset in range(len(sh0)+1):
1041                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1042                  tname="test_%s_%s_rank%s_offset%s"%(name,case0,len(sh0),offset)
1043                  text+="   def %s(self):\n"%tname
1044                  sh_t=sh0[offset:]+sh0[:offset]
1045    
1046    #              sh_t=list(sh0)
1047    #              sh_t[offset+1]=sh_t[offset]
1048    #              sh_t=tuple(sh_t)
1049    #              sh_r=[]
1050    #              for i in range(offset): sh_r.append(sh0[i])
1051    #              for i in range(offset+2,len(sh0)): sh_r.append(sh0[i])              
1052    #              sh_r=tuple(sh_r)
1053    
1054                  a_0=makeArray(sh0,[-1.,1])
1055                  if case0 in ["taggedData", "expandedData"]:
1056                      a1_0=makeArray(sh0,[-1.,1])
1057                  else:
1058                      a1_0=a_0
1059                  r=tt(a_0,offset)
1060                  r1=tt(a1_0,offset)
1061                  text+=mkText(case0,"arg",a_0,a1_0)
1062                  text+="      res=%s(arg,%s)\n"%(name,offset)
1063                  if case0=="Symbol":
1064                     text+=mkText("array","s",a_0,a1_0)
1065                     text+="      sub=res.substitute({arg:s})\n"
1066                     res="sub"
1067                     text+=mkText("array","ref",r,r1)
1068                  else:
1069                     res="res"
1070                     text+=mkText(case0,"ref",r,r1)
1071                  text+=mkTypeAndShapeTest(case0,sh_t,"res")
1072                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1073                  
1074                  if case0 == "taggedData" :
1075                      t_prog_with_tags+=text
1076                  else:              
1077                      t_prog+=text
1078    
1079    print test_header
1080    # print t_prog
1081    print t_prog_with_tags
1082    print test_tail          
1083    1/0
1084    #=======================================================================================================
1085    # L2
1086    #=======================================================================================================
1087    for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
1088      for data in ["Data","Symbol"]:
1089        for sh in [ (),(2,), (4,5), (6,2,2),(4,5,3,2)]:
1090             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1091             tname="test_L2_on%s_from%s_rank%s"%(where,data,len(sh))
1092             text+="   def %s(self):\n"%tname
1093             text+="      \"\"\"\n"
1094             text+="      tests L2-norm of %s on the %s\n\n"%(data,where)
1095             text+="      assumptions: self.domain supports integration on %s\n"%where
1096             text+="      \"\"\"\n"
1097             text+="      dim=self.domain.getDim()\n"
1098             text+="      w=%s(self.domain)\n"%where
1099             text+="      x=w.getX()\n"
1100             o="1"
1101             if len(sh)>0:
1102                sh_2=sh[:len(sh)-1]+(2,)
1103                sh_3=sh[:len(sh)-1]+(3,)            
1104                b_2=makeArray(sh[:len(sh)-1]+(2,),[-1.,1])
1105                b_3=makeArray(sh[:len(sh)-1]+(3,),[-1.,1])
1106             else:
1107                sh_2=()
1108                sh_3=()
1109                b_2=makeArray(sh,[-1.,1])
1110                b_3=makeArray(sh,[-1.,1])
1111    
1112             if data=="Symbol":
1113                val="s"
1114                res="sub"
1115             else:
1116                val="arg"
1117                res="res"
1118             text+="      if dim==2:\n"
1119             if data=="Symbol":
1120                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_2)
1121    
1122             text+="        %s=Data(0,%s,w)\n"%(val,sh_2)
1123             text+=unrollLoopsSimplified(b_2,val,tap="        ")
1124             text+=unrollLoopsOfL2(b_2,where,"ref",tap="        ")
1125             text+="\n      else:\n"
1126             if data=="Symbol":
1127                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_3)
1128             text+="        %s=Data(0,%s,w)\n"%(val,sh_3)        
1129             text+=unrollLoopsSimplified(b_3,val,tap="        ")
1130             text+=unrollLoopsOfL2(b_3,where,"ref",tap="        ")
1131             text+="\n      res=L2(arg)\n"
1132             if data=="Symbol":
1133                text+="      sub=res.substitute({arg:s})\n"
1134                text+="      self.failUnless(isinstance(res,Symbol),\"wrong type of result.\")\n"
1135                text+="      self.failUnlessEqual(res.getShape(),(),\"wrong shape of result.\")\n"
1136             else:
1137                text+="      self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1138             text+="      self.failUnlessAlmostEqual(%s,ref,int(-log10(self.RES_TOL)),\"wrong result\")\n"%res
1139             t_prog+=text
1140    print t_prog
1141    1/0
1142    
1143    #=======================================================================================================
1144    # div
1145    #=======================================================================================================
1146    for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
1147      for data in ["Data","Symbol"]:
1148         for case in ["ContinuousFunction","Solution","ReducedSolution"]:
1149             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1150             tname="test_div_on%s_from%s_%s"%(where,data,case)
1151             text+="   def %s(self):\n"%tname
1152             text+="      \"\"\"\n"
1153             text+="      tests divergence of %s on the %s\n\n"%(data,where)
1154             text+="      assumptions: %s(self.domain) exists\n"%case
1155             text+="                   self.domain supports div on %s\n"%where
1156             text+="      \"\"\"\n"
1157             if case=="ReducedSolution":
1158                text+="      o=1\n"
1159                o="1"
1160             else:
1161                text+="      o=self.order\n"
1162                o="o"
1163             text+="      dim=self.domain.getDim()\n"
1164             text+="      w_ref=%s(self.domain)\n"%where
1165             text+="      x_ref=w_ref.getX()\n"
1166             text+="      w=%s(self.domain)\n"%case
1167             text+="      x=w.getX()\n"
1168             a_2=makeArray((2,2),[-1.,1])
1169             b_2=makeArray((2,2),[-1.,1])
1170             a_3=makeArray((3,3),[-1.,1])
1171             b_3=makeArray((3,3),[-1.,1])
1172             if data=="Symbol":
1173                text+="      arg=Symbol(shape=(dim,),dim=dim)\n"
1174                val="s"
1175                res="sub"
1176             else:
1177                val="arg"
1178                res="res"
1179             text+="      %s=Vector(0,w)\n"%val
1180             text+="      if dim==2:\n"
1181             text+=unrollLoops(a_2,b_2,o,val,tap="        ")
1182             text+=unrollLoopsOfDiv(a_2,b_2,o,"ref",tap="        ")
1183             text+="\n      else:\n"
1184            
1185             text+=unrollLoops(a_3,b_3,o,val,tap="        ")
1186             text+=unrollLoopsOfDiv(a_3,b_3,o,"ref",tap="        ")
1187             text+="\n      res=div(arg,where=w_ref)\n"
1188             if data=="Symbol":
1189                text+="      sub=res.substitute({arg:s})\n"
1190             text+="      self.failUnless(isinstance(res,%s),\"wrong type of result.\")\n"%data
1191             text+="      self.failUnlessEqual(res.getShape(),(),\"wrong shape of result.\")\n"
1192             text+="      self.failUnlessEqual(%s.getFunctionSpace(),w_ref,\"wrong function space of result.\")\n"%res
1193             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1194    
1195    
1196             t_prog+=text
1197    print t_prog
1198    1/0
1199    
1200    #=======================================================================================================
1201    # interpolation
1202    #=======================================================================================================
1203    for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne","Solution","ReducedSolution"]:
1204      for data in ["Data","Symbol"]:
1205         for case in ["ContinuousFunction","Solution","ReducedSolution","Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
1206          for sh in [ (),(2,), (4,5), (6,2,2),(4,5,3,2)]:
1207            if  where==case or \
1208                ( case in ["ContinuousFunction","Solution","ReducedSolution"] and where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"] ) or \
1209                ( case in ["FunctionOnContactZero","FunctionOnContactOne"] and where in ["FunctionOnContactZero","FunctionOnContactOne"] ) or \
1210                (case=="ContinuousFunction" and  where in ["Solution","ReducedSolution"]) or \
1211                (case=="Solution" and  where=="ReducedSolution") :
1212                
1213    
1214             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1215             tname="test_interpolation_on%s_from%s_%s_rank%s"%(where,data,case,len(sh))
1216             text+="   def %s(self):\n"%tname
1217             text+="      \"\"\"\n"
1218             text+="      tests interpolation for rank %s %s onto the %s\n\n"%(len(sh),data,where)
1219             text+="      assumptions: self.domain supports inpterpolation from %s onto %s\n"%(case,where)
1220             text+="      \"\"\"\n"
1221             if case=="ReducedSolution" or where=="ReducedSolution":
1222                text+="      o=1\n"
1223                o="1"
1224             else:
1225                text+="      o=self.order\n"
1226                o="o"
1227             text+="      dim=self.domain.getDim()\n"
1228             text+="      w_ref=%s(self.domain)\n"%where
1229             text+="      x_ref=w_ref.getX()\n"
1230             text+="      w=%s(self.domain)\n"%case
1231             text+="      x=w.getX()\n"
1232             a_2=makeArray(sh+(2,),[-1.,1])
1233             b_2=makeArray(sh+(2,),[-1.,1])
1234             a_3=makeArray(sh+(3,),[-1.,1])
1235             b_3=makeArray(sh+(3,),[-1.,1])
1236             if data=="Symbol":
1237                text+="      arg=Symbol(shape=%s,dim=dim)\n"%str(sh)
1238                val="s"
1239                res="sub"
1240             else:
1241                val="arg"
1242                res="res"
1243             text+="      %s=Data(0,%s,w)\n"%(val,str(sh))
1244             text+="      ref=Data(0,%s,w_ref)\n"%str(sh)
1245             text+="      if dim==2:\n"
1246             text+=unrollLoops(a_2,b_2,o,val,tap="        ")
1247             text+=unrollLoops(a_2,b_2,o,"ref",tap="        ",x="x_ref")
1248             text+="      else:\n"
1249            
1250             text+=unrollLoops(a_3,b_3,o,val,tap="        ")
1251             text+=unrollLoops(a_3,b_3,o,"ref",tap="        ",x="x_ref")
1252             text+="      res=interpolate(arg,where=w_ref)\n"
1253             if data=="Symbol":
1254                text+="      sub=res.substitute({arg:s})\n"
1255             text+="      self.failUnless(isinstance(res,%s),\"wrong type of result.\")\n"%data
1256             text+="      self.failUnlessEqual(%s.getFunctionSpace(),w_ref,\"wrong functionspace of result.\")\n"%res
1257             text+="      self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1258             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1259             t_prog+=text
1260    print test_header
1261    print t_prog
1262    print test_tail          
1263    1/0
1264    
1265  #=======================================================================================================  #=======================================================================================================
1266  # grad  # grad

Legend:
Removed from v.438  
changed lines
  Added in v.493

  ViewVC Help
Powered by ViewVC 1.1.26