/[escript]/trunk/escript/py_src/generatediff
ViewVC logotype

Diff of /trunk/escript/py_src/generatediff

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 441 by gross, Fri Jan 20 03:40:39 2006 UTC revision 493 by gross, Fri Feb 3 02:18:45 2006 UTC
# Line 696  def minimumTEST(arg0,arg1): Line 696  def minimumTEST(arg0,arg1):
696                else:                else:
697                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]
698       return out       return out
699        
700  def unrollLoops(a,b,o,arg,tap="",x="x"):  def unrollLoops(a,b,o,arg,tap="",x="x"):
701      out=""      out=""
702      if a.rank==1:      if a.rank==1:
# Line 798  def unrollLoopsOfGrad(a,b,o,arg,tap=""): Line 799  def unrollLoopsOfGrad(a,b,o,arg,tap=""):
799                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])
800      return out      return out
801  def unrollLoopsOfDiv(a,b,o,arg,tap=""):  def unrollLoopsOfDiv(a,b,o,arg,tap=""):
   
   
       
802      out=tap+arg+"="      out=tap+arg+"="
803      if o=="1":      if o=="1":
804         z=0.         z=0.
# Line 935  def unrollLoopsOfInteriorIntegral(a,b,wh Line 933  def unrollLoopsOfInteriorIntegral(a,b,wh
933                 out+="+(%s)*0.5**o\n"%zop                 out+="+(%s)*0.5**o\n"%zop
934    
935      return out      return out
936    def unrollLoopsSimplified(b,arg,tap=""):
937        out=""
938        if isinstance(b,float) or b.rank==0:
939                 out+=tap+"%s=(%s)*x[0]\n"%(arg,str(b))
940    
941        elif b.rank==1:
942            for i0 in range(b.shape[0]):
943                 out+=tap+"%s[%s]=(%s)*x[%s]\n"%(arg,i0,b[i0],i0)
944        elif b.rank==2:
945            for i0 in range(b.shape[0]):
946             for i1 in range(b.shape[1]):
947                 out+=tap+"%s[%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,b[i0,i1],i1)
948        elif b.rank==3:
949            for i0 in range(b.shape[0]):
950             for i1 in range(b.shape[1]):
951               for i2 in range(b.shape[2]):
952                 out+=tap+"%s[%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,b[i0,i1,i2],i2)
953        elif b.rank==4:
954            for i0 in range(b.shape[0]):
955             for i1 in range(b.shape[1]):
956               for i2 in range(b.shape[2]):
957                for i3 in range(b.shape[3]):
958                 out+=tap+"%s[%s,%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,i3,b[i0,i1,i2,i3],i3)
959        return out
960    
961    def unrollLoopsOfL2(b,where,arg,tap=""):
962        out=""
963        z=[]
964        if isinstance(b,float) or b.rank==0:
965           z.append(b**2)
966        elif b.rank==1:
967            for i0 in range(b.shape[0]):
968                 z.append(b[i0]**2)
969        elif b.rank==2:
970            for i1 in range(b.shape[1]):
971               s=0
972               for i0 in range(b.shape[0]):
973                  s+=b[i0,i1]**2
974               z.append(s)
975        elif b.rank==3:
976            for i2 in range(b.shape[2]):
977              s=0
978              for i0 in range(b.shape[0]):
979                 for i1 in range(b.shape[1]):
980                    s+=b[i0,i1,i2]**2
981              z.append(s)
982    
983        elif b.rank==4:
984          for i3 in range(b.shape[3]):
985             s=0
986             for i0 in range(b.shape[0]):
987               for i1 in range(b.shape[1]):
988                  for i2 in range(b.shape[2]):
989                     s+=b[i0,i1,i2,i3]**2
990             z.append(s)        
991        if where=="Function":
992           xfac_o=1.
993           xfac_op=0.
994           z_fac_s=""
995           zo_fac_s=""
996           zo_fac=1./3.
997        elif where=="FunctionOnBoundary":
998           xfac_o=1.
999           xfac_op=0.
1000           z_fac_s="*dim"
1001           zo_fac_s="*(2.*dim+1.)/3."
1002           zo_fac=1.
1003        elif where in ["FunctionOnContactZero","FunctionOnContactOne"]:
1004           xfac_o=0.
1005           xfac_op=1.
1006           z_fac_s=""
1007           zo_fac_s=""    
1008           zo_fac=1./3.    
1009        zo=0.
1010        zop=0.
1011        for i99 in range(len(z)):
1012               if i99==0:
1013                   zo+=xfac_o*z[i99]
1014                   zop+=xfac_op*z[i99]
1015               else:
1016                   zo+=z[i99]
1017        out+=tap+"%s=sqrt((%s)%s"%(arg,zo*zo_fac,zo_fac_s)
1018        if zop==0.:
1019           out+=")\n"
1020        else:
1021           out+="+(%s))\n"%(zop*0.5**2)
1022        return out
1023    #=======================================================================================================
1024    # transpose
1025    #=======================================================================================================
1026    def transposeTest(r,offset):
1027        if isinstance(r,float): return r
1028        s=r.shape
1029        s1=1
1030        for i in s[:offset]: s1*=i
1031        s2=1
1032        for i in s[offset:]: s2*=i
1033        out=numarray.reshape(r,(s1,s2))
1034        out.transpose()
1035        return numarray.resize(out,s[offset:]+s[:offset])
1036    
1037    name,tt="transpose",transposeTest
1038    for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
1039      for sh0 in [ (), (3,), (4,5), (6,2,2),(3,2,3,4)]:
1040        for offset in range(len(sh0)+1):
1041                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1042                  tname="test_%s_%s_rank%s_offset%s"%(name,case0,len(sh0),offset)
1043                  text+="   def %s(self):\n"%tname
1044                  sh_t=sh0[offset:]+sh0[:offset]
1045    
1046    #              sh_t=list(sh0)
1047    #              sh_t[offset+1]=sh_t[offset]
1048    #              sh_t=tuple(sh_t)
1049    #              sh_r=[]
1050    #              for i in range(offset): sh_r.append(sh0[i])
1051    #              for i in range(offset+2,len(sh0)): sh_r.append(sh0[i])              
1052    #              sh_r=tuple(sh_r)
1053    
1054                  a_0=makeArray(sh0,[-1.,1])
1055                  if case0 in ["taggedData", "expandedData"]:
1056                      a1_0=makeArray(sh0,[-1.,1])
1057                  else:
1058                      a1_0=a_0
1059                  r=tt(a_0,offset)
1060                  r1=tt(a1_0,offset)
1061                  text+=mkText(case0,"arg",a_0,a1_0)
1062                  text+="      res=%s(arg,%s)\n"%(name,offset)
1063                  if case0=="Symbol":
1064                     text+=mkText("array","s",a_0,a1_0)
1065                     text+="      sub=res.substitute({arg:s})\n"
1066                     res="sub"
1067                     text+=mkText("array","ref",r,r1)
1068                  else:
1069                     res="res"
1070                     text+=mkText(case0,"ref",r,r1)
1071                  text+=mkTypeAndShapeTest(case0,sh_t,"res")
1072                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1073                  
1074                  if case0 == "taggedData" :
1075                      t_prog_with_tags+=text
1076                  else:              
1077                      t_prog+=text
1078    
1079    print test_header
1080    # print t_prog
1081    print t_prog_with_tags
1082    print test_tail          
1083    1/0
1084    #=======================================================================================================
1085    # L2
1086    #=======================================================================================================
1087    for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
1088      for data in ["Data","Symbol"]:
1089        for sh in [ (),(2,), (4,5), (6,2,2),(4,5,3,2)]:
1090             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1091             tname="test_L2_on%s_from%s_rank%s"%(where,data,len(sh))
1092             text+="   def %s(self):\n"%tname
1093             text+="      \"\"\"\n"
1094             text+="      tests L2-norm of %s on the %s\n\n"%(data,where)
1095             text+="      assumptions: self.domain supports integration on %s\n"%where
1096             text+="      \"\"\"\n"
1097             text+="      dim=self.domain.getDim()\n"
1098             text+="      w=%s(self.domain)\n"%where
1099             text+="      x=w.getX()\n"
1100             o="1"
1101             if len(sh)>0:
1102                sh_2=sh[:len(sh)-1]+(2,)
1103                sh_3=sh[:len(sh)-1]+(3,)            
1104                b_2=makeArray(sh[:len(sh)-1]+(2,),[-1.,1])
1105                b_3=makeArray(sh[:len(sh)-1]+(3,),[-1.,1])
1106             else:
1107                sh_2=()
1108                sh_3=()
1109                b_2=makeArray(sh,[-1.,1])
1110                b_3=makeArray(sh,[-1.,1])
1111    
1112             if data=="Symbol":
1113                val="s"
1114                res="sub"
1115             else:
1116                val="arg"
1117                res="res"
1118             text+="      if dim==2:\n"
1119             if data=="Symbol":
1120                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_2)
1121    
1122             text+="        %s=Data(0,%s,w)\n"%(val,sh_2)
1123             text+=unrollLoopsSimplified(b_2,val,tap="        ")
1124             text+=unrollLoopsOfL2(b_2,where,"ref",tap="        ")
1125             text+="\n      else:\n"
1126             if data=="Symbol":
1127                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_3)
1128             text+="        %s=Data(0,%s,w)\n"%(val,sh_3)        
1129             text+=unrollLoopsSimplified(b_3,val,tap="        ")
1130             text+=unrollLoopsOfL2(b_3,where,"ref",tap="        ")
1131             text+="\n      res=L2(arg)\n"
1132             if data=="Symbol":
1133                text+="      sub=res.substitute({arg:s})\n"
1134                text+="      self.failUnless(isinstance(res,Symbol),\"wrong type of result.\")\n"
1135                text+="      self.failUnlessEqual(res.getShape(),(),\"wrong shape of result.\")\n"
1136             else:
1137                text+="      self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1138             text+="      self.failUnlessAlmostEqual(%s,ref,int(-log10(self.RES_TOL)),\"wrong result\")\n"%res
1139             t_prog+=text
1140    print t_prog
1141    1/0
1142    
1143  #=======================================================================================================  #=======================================================================================================
1144  # div  # div
1145  #=======================================================================================================  #=======================================================================================================

Legend:
Removed from v.441  
changed lines
  Added in v.493

  ViewVC Help
Powered by ViewVC 1.1.26