/[escript]/trunk/escript/py_src/generatediff
ViewVC logotype

Diff of /trunk/escript/py_src/generatediff

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 441 by gross, Fri Jan 20 03:40:39 2006 UTC revision 443 by gross, Fri Jan 20 06:22:38 2006 UTC
# Line 696  def minimumTEST(arg0,arg1): Line 696  def minimumTEST(arg0,arg1):
696                else:                else:
697                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]
698       return out       return out
699        
700  def unrollLoops(a,b,o,arg,tap="",x="x"):  def unrollLoops(a,b,o,arg,tap="",x="x"):
701      out=""      out=""
702      if a.rank==1:      if a.rank==1:
# Line 798  def unrollLoopsOfGrad(a,b,o,arg,tap=""): Line 799  def unrollLoopsOfGrad(a,b,o,arg,tap=""):
799                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])
800      return out      return out
801  def unrollLoopsOfDiv(a,b,o,arg,tap=""):  def unrollLoopsOfDiv(a,b,o,arg,tap=""):
   
   
       
802      out=tap+arg+"="      out=tap+arg+"="
803      if o=="1":      if o=="1":
804         z=0.         z=0.
# Line 935  def unrollLoopsOfInteriorIntegral(a,b,wh Line 933  def unrollLoopsOfInteriorIntegral(a,b,wh
933                 out+="+(%s)*0.5**o\n"%zop                 out+="+(%s)*0.5**o\n"%zop
934    
935      return out      return out
936    def unrollLoopsSimplified(b,arg,tap=""):
937        out=""
938        if isinstance(b,float) or b.rank==0:
939                 out+=tap+"%s=(%s)*x[0]\n"%(arg,str(b))
940    
941        elif b.rank==1:
942            for i0 in range(b.shape[0]):
943                 out+=tap+"%s[%s]=(%s)*x[%s]\n"%(arg,i0,b[i0],i0)
944        elif b.rank==2:
945            for i0 in range(b.shape[0]):
946             for i1 in range(b.shape[1]):
947                 out+=tap+"%s[%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,b[i0,i1],i1)
948        elif b.rank==3:
949            for i0 in range(b.shape[0]):
950             for i1 in range(b.shape[1]):
951               for i2 in range(b.shape[2]):
952                 out+=tap+"%s[%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,b[i0,i1,i2],i2)
953        elif b.rank==4:
954            for i0 in range(b.shape[0]):
955             for i1 in range(b.shape[1]):
956               for i2 in range(b.shape[2]):
957                for i3 in range(b.shape[3]):
958                 out+=tap+"%s[%s,%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,i3,b[i0,i1,i2,i3],i3)
959        return out
960    
961    def unrollLoopsOfL2(b,where,arg,tap=""):
962        out=""
963        z=[]
964        if isinstance(b,float) or b.rank==0:
965           z.append(b**2)
966        elif b.rank==1:
967            for i0 in range(b.shape[0]):
968                 z.append(b[i0]**2)
969        elif b.rank==2:
970            for i1 in range(b.shape[1]):
971               s=0
972               for i0 in range(b.shape[0]):
973                  s+=b[i0,i1]**2
974               z.append(s)
975        elif b.rank==3:
976            for i2 in range(b.shape[2]):
977              s=0
978              for i0 in range(b.shape[0]):
979                 for i1 in range(b.shape[1]):
980                    s+=b[i0,i1,i2]**2
981              z.append(s)
982    
983        elif b.rank==4:
984          for i3 in range(b.shape[3]):
985             s=0
986             for i0 in range(b.shape[0]):
987               for i1 in range(b.shape[1]):
988                  for i2 in range(b.shape[2]):
989                     s+=b[i0,i1,i2,i3]**2
990             z.append(s)        
991        if where=="Function":
992           xfac_o=1.
993           xfac_op=0.
994           z_fac_s=""
995           zo_fac_s=""
996           zo_fac=1./3.
997        elif where=="FunctionOnBoundary":
998           xfac_o=1.
999           xfac_op=0.
1000           z_fac_s="*dim"
1001           zo_fac_s="*(2.*dim+1.)/3."
1002           zo_fac=1.
1003        elif where in ["FunctionOnContactZero","FunctionOnContactOne"]:
1004           xfac_o=0.
1005           xfac_op=1.
1006           z_fac_s=""
1007           zo_fac_s=""    
1008           zo_fac=1./3.    
1009        zo=0.
1010        zop=0.
1011        for i99 in range(len(z)):
1012               if i99==0:
1013                   zo+=xfac_o*z[i99]
1014                   zop+=xfac_op*z[i99]
1015               else:
1016                   zo+=z[i99]
1017        out+=tap+"%s=sqrt((%s)%s"%(arg,zo*zo_fac,zo_fac_s)
1018        if zop==0.:
1019           out+=")\n"
1020        else:
1021           out+="+(%s))\n"%(zop*0.5**2)
1022        return out
1023    
1024    #=======================================================================================================
1025    # L2
1026    #=======================================================================================================
1027    for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
1028      for data in ["Data","Symbol"]:
1029        for sh in [ (),(2,), (4,5), (6,2,2),(4,5,3,2)]:
1030             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1031             tname="test_L2_on%s_from%s_rank%s"%(where,data,len(sh))
1032             text+="   def %s(self):\n"%tname
1033             text+="      \"\"\"\n"
1034             text+="      tests L2-norm of %s on the %s\n\n"%(data,where)
1035             text+="      assumptions: self.domain supports integration on %s\n"%where
1036             text+="      \"\"\"\n"
1037             text+="      dim=self.domain.getDim()\n"
1038             text+="      w=%s(self.domain)\n"%where
1039             text+="      x=w.getX()\n"
1040             o="1"
1041             if len(sh)>0:
1042                sh_2=sh[:len(sh)-1]+(2,)
1043                sh_3=sh[:len(sh)-1]+(3,)            
1044                b_2=makeArray(sh[:len(sh)-1]+(2,),[-1.,1])
1045                b_3=makeArray(sh[:len(sh)-1]+(3,),[-1.,1])
1046             else:
1047                sh_2=()
1048                sh_3=()
1049                b_2=makeArray(sh,[-1.,1])
1050                b_3=makeArray(sh,[-1.,1])
1051    
1052             if data=="Symbol":
1053                val="s"
1054                res="sub"
1055             else:
1056                val="arg"
1057                res="res"
1058             text+="      if dim==2:\n"
1059             if data=="Symbol":
1060                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_2)
1061    
1062             text+="        %s=Data(0,%s,w)\n"%(val,sh_2)
1063             text+=unrollLoopsSimplified(b_2,val,tap="        ")
1064             text+=unrollLoopsOfL2(b_2,where,"ref",tap="        ")
1065             text+="\n      else:\n"
1066             if data=="Symbol":
1067                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_3)
1068             text+="        %s=Data(0,%s,w)\n"%(val,sh_3)        
1069             text+=unrollLoopsSimplified(b_3,val,tap="        ")
1070             text+=unrollLoopsOfL2(b_3,where,"ref",tap="        ")
1071             text+="\n      res=L2(arg)\n"
1072             if data=="Symbol":
1073                text+="      sub=res.substitute({arg:s})\n"
1074                text+="      self.failUnless(isinstance(res,Symbol),\"wrong type of result.\")\n"
1075                text+="      self.failUnlessEqual(res.getShape(),(),\"wrong shape of result.\")\n"
1076             else:
1077                text+="      self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1078             text+="      self.failUnlessAlmostEqual(%s,ref,int(-log10(self.RES_TOL)),\"wrong result\")\n"%res
1079             t_prog+=text
1080    print t_prog
1081    1/0
1082    
1083  #=======================================================================================================  #=======================================================================================================
1084  # div  # div
1085  #=======================================================================================================  #=======================================================================================================

Legend:
Removed from v.441  
changed lines
  Added in v.443

  ViewVC Help
Powered by ViewVC 1.1.26