/[escript]/trunk/escript/py_src/generatediff
ViewVC logotype

Diff of /trunk/escript/py_src/generatediff

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 441 by gross, Fri Jan 20 03:40:39 2006 UTC revision 536 by gross, Fri Feb 17 03:20:53 2006 UTC
# Line 184  def makeArray(shape,rng): Line 184  def makeArray(shape,rng):
184         raise SystemError,"rank is restricted to 5"         raise SystemError,"rank is restricted to 5"
185     return out             return out        
186    
187    def makeNumberedArray(shape,s=1.):
188       out=numarray.zeros(shape,numarray.Float64)
189       if len(shape)==0:
190           out=s*1.
191       elif len(shape)==1:
192           for i0 in range(shape[0]):
193                       out[i0]=s*i0
194       elif len(shape)==2:
195           for i0 in range(shape[0]):
196              for i1 in range(shape[1]):
197                       out[i0,i1]=s*(i1+shape[1]*i0)
198       elif len(shape)==3:
199           for i0 in range(shape[0]):
200              for i1 in range(shape[1]):
201                 for i2 in range(shape[2]):
202                       out[i0,i1,i2]=s*(i2+shape[2]*i1+shape[2]*shape[1]*i0)
203       elif len(shape)==4:
204           for i0 in range(shape[0]):
205              for i1 in range(shape[1]):
206                 for i2 in range(shape[2]):
207                    for i3 in range(shape[3]):
208                       out[i0,i1,i2,i3]=s*(i3+shape[3]*i2+shape[3]*shape[2]*i1+shape[3]*shape[2]*shape[1]*i0)
209       else:
210           raise SystemError,"rank is restricted to 4"
211       return out        
212    
213  def makeResult(val,test_expr):  def makeResult(val,test_expr):
214     if isinstance(val,float):     if isinstance(val,float):
# Line 524  def mkCode(txt,args=[],intend=""): Line 549  def mkCode(txt,args=[],intend=""):
549      for r in args:      for r in args:
550        out=out.replace("%%a%s%%"%c,r)        out=out.replace("%%a%s%%"%c,r)
551      return out        return out  
552    #=======================================================================================================
553    # nonsymmetric part
554    #=======================================================================================================
555    from esys.escript import *
556    for name in ["symmetric", "nonsymmetric"]:
557     f=1.
558     if name=="nonsymmetric": f=-1
559     for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
560      for sh0 in [ (3,3), (2,3,2,3)]:
561                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
562                  tname="test_%s_%s_rank%s"%(name,case0,len(sh0))
563                  text+="   def %s(self):\n"%tname
564                  a_0=makeNumberedArray(sh0,s=1.)
565                  r_0=(a_0+f*transpose(a_0))/2.
566                  if case0 in ["taggedData", "expandedData"]:
567                     a1_0=makeNumberedArray(sh0,s=-1.)
568                     r1_0=(a1_0+f*transpose(a1_0))/2.
569                  else:
570                      a1_0=a_0                  
571                      r1_0=r_0
572                  text+=mkText(case0,"arg",a_0,a1_0)
573                  text+="      res=%s(arg)\n"%name
574                  if case0=="Symbol":
575                     text+=mkText("array","s",a_0,a1_0)
576                     text+="      sub=res.substitute({arg:s})\n"
577                     res="sub"
578                     text+=mkText("array","ref",r_0,r1_0)
579                  else:
580                     res="res"
581                     text+=mkText(case0,"ref",r_0,r1_0)  
582                  text+=mkTypeAndShapeTest(case0,sh0,"res")
583                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
584                  
585                  if case0 == "taggedData" :
586                      t_prog_with_tags+=text
587                  else:              
588                      t_prog+=text
589    print test_header
590    print t_prog
591    # print t_prog_with_tags
592    print test_tail          
593    1/0
594    
595    #=======================================================================================================
596    # eigenvalues
597    #=======================================================================================================
598    import numarray.linear_algebra
599    name="eigenvalues"
600    for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
601      for sh0 in [ (1,1), (2,2), (3,3)]:
602                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
603                  tname="test_%s_%s_dim%s"%(name,case0,sh0[0])
604                  text+="   def %s(self):\n"%tname
605                  a_0=makeArray(sh0,[-1.,1])
606                  a_0=(a_0+numarray.transpose(a_0))/2.
607                  ev=numarray.linear_algebra.eigenvalues(a_0)
608                  ev.sort()
609                  if case0 in ["taggedData", "expandedData"]:
610                      a1_0=makeArray(sh0,[-1.,1])
611                      a1_0=(a1_0+numarray.transpose(a1_0))/2.
612                      ev1=numarray.linear_algebra.eigenvalues(a1_0)
613                      ev1.sort()
614                  else:
615                      a1_0=a_0                  
616                      ev1=ev
617                  text+=mkText(case0,"arg",a_0,a1_0)
618                  text+="      res=%s(arg)\n"%name
619                  if case0=="Symbol":
620                     text+=mkText("array","s",a_0,a1_0)
621                     text+="      sub=res.substitute({arg:s})\n"
622                     res="sub"
623                     text+=mkText("array","ref",ev,ev1)
624                  else:
625                     res="res"
626                     text+=mkText(case0,"ref",ev,ev1)  
627                  text+=mkTypeAndShapeTest(case0,(sh0[0],),"res")
628                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
629                  
630                  if case0 == "taggedData" :
631                      t_prog_with_tags+=text
632                  else:              
633                      t_prog+=text
634    print test_header
635    # print t_prog
636    print t_prog_with_tags
637    print test_tail          
638    1/0
639    
640    #=======================================================================================================
641    # slicing
642    #=======================================================================================================
643    for case0 in ["constData","taggedData","expandedData","Symbol"]:
644      for sh0 in [ (3,), (3,4), (3,4,3) ,(4,3,5,3)]:
645        # get perm:
646        if len(sh0)==2:
647            check=[[1,0]]
648        elif len(sh0)==3:
649            check=[[1,0,2],
650                   [1,2,0],
651                   [2,1,0],
652                   [2,0,2],
653                   [0,2,1]]
654        elif len(sh0)==4:
655            check=[[0,1,3,2],
656                   [0,2,1,3],
657                   [0,2,3,1],
658                   [0,3,2,1],
659                   [0,3,1,2] ,          
660                   [1,0,2,3],
661                   [1,0,3,2],
662                   [1,2,0,3],
663                   [1,2,3,0],
664                   [1,3,2,0],
665                   [1,3,0,2],
666                   [2,0,1,3],
667                   [2,0,3,1],
668                   [2,1,0,3],
669                   [2,1,3,0],
670                   [2,3,1,0],
671                   [2,3,0,1],
672                   [3,0,1,2],
673                   [3,0,2,1],
674                   [3,1,0,2],
675                   [3,1,2,0],
676                   [3,2,1,0],
677                   [3,2,0,1]]
678        else:
679             check=[]
680        
681        # create the test cases:
682        processed=[]
683        l=["R","U","L","P","C","N"]
684        c=[""]
685        for i in range(len(sh0)):
686           tmp=[]
687           for ci in c:
688              tmp+=[ci+li for li in l]
689           c=tmp
690        # SHUFFLE
691        c2=[]
692        while len(c)>0:
693            i=int(random.random()*len(c))
694            c2.append(c[i])
695            del c[i]
696        c=c2
697        for ci in c:
698          t=""
699          sh=()
700          for i in range(len(ci)):
701              if ci[i]=="R":
702                 s="%s:%s"%(1,sh0[i]-1)
703                 sh=sh+(sh0[i]-2,)            
704              if ci[i]=="U":
705                  s=":%s"%(sh0[i]-1)
706                  sh=sh+(sh0[i]-1,)            
707              if ci[i]=="L":
708                  s="2:"
709                  sh=sh+(sh0[i]-2,)            
710              if ci[i]=="P":
711                  s="%s"%(int(sh0[i]/2))
712              if ci[i]=="C":
713                  s=":"
714                  sh=sh+(sh0[i],)            
715              if ci[i]=="N":
716                  s=""
717                  sh=sh+(sh0[i],)
718              if len(s)>0:
719                 if not t=="": t+=","
720                 t+=s
721          N_found=False
722          noN_found=False
723          process=len(t)>0
724          for i in ci:
725             if i=="N":
726                if not noN_found and N_found: process=False
727                N_found=True
728             else:
729               if N_found: process=False
730               noNfound=True
731          # is there a similar one processed allready
732          if process and ci.find("N")==-1:
733             for ci2 in processed:
734               for chi in check:
735                   is_perm=True
736                   for i in range(len(chi)):
737                       if not ci[i]==ci2[chi[i]]: is_perm=False
738                   if is_perm: process=False
739          # if not process: print ci," rejected"
740          if process:
741           processed.append(ci)
742           text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
743           tname="test_getslice_%s_rank%s_%s"%(case0,len(sh0),ci)
744           text+="   def %s(self):\n"%tname
745           a_0=makeNumberedArray(sh0,s=1)
746           if case0 in ["taggedData", "expandedData"]:
747                a1_0=makeNumberedArray(sh0,s=-1.)
748           else:
749                a1_0=a_0
750           r=eval("a_0[%s]"%t)
751           r1=eval("a1_0[%s]"%t)
752           text+=mkText(case0,"arg",a_0,a1_0)
753           text+="      res=arg[%s]\n"%t
754           if case0=="Symbol":
755               text+=mkText("array","s",a_0,a1_0)
756               text+="      sub=res.substitute({arg:s})\n"
757               res="sub"
758               text+=mkText("array","ref",r,r1)
759           else:
760               res="res"
761               text+=mkText(case0,"ref",r,r1)
762           text+=mkTypeAndShapeTest(case0,sh,"res")
763           text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
764                  
765           if case0 == "taggedData" :
766                t_prog_with_tags+=text
767           else:              
768                t_prog+=text
769    
770    print test_header
771    # print t_prog
772    print t_prog_with_tags
773    print test_tail          
774    1/0
775    #============================================================================================
776  def innerTEST(arg0,arg1):  def innerTEST(arg0,arg1):
777      if isinstance(arg0,float):      if isinstance(arg0,float):
778         out=numarray.array(arg0*arg1)         out=numarray.array(arg0*arg1)
# Line 696  def minimumTEST(arg0,arg1): Line 944  def minimumTEST(arg0,arg1):
944                else:                else:
945                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]                 out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]
946       return out       return out
947        
948  def unrollLoops(a,b,o,arg,tap="",x="x"):  def unrollLoops(a,b,o,arg,tap="",x="x"):
949      out=""      out=""
950      if a.rank==1:      if a.rank==1:
# Line 798  def unrollLoopsOfGrad(a,b,o,arg,tap=""): Line 1047  def unrollLoopsOfGrad(a,b,o,arg,tap=""):
1047                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])                   out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])
1048      return out      return out
1049  def unrollLoopsOfDiv(a,b,o,arg,tap=""):  def unrollLoopsOfDiv(a,b,o,arg,tap=""):
   
   
       
1050      out=tap+arg+"="      out=tap+arg+"="
1051      if o=="1":      if o=="1":
1052         z=0.         z=0.
# Line 935  def unrollLoopsOfInteriorIntegral(a,b,wh Line 1181  def unrollLoopsOfInteriorIntegral(a,b,wh
1181                 out+="+(%s)*0.5**o\n"%zop                 out+="+(%s)*0.5**o\n"%zop
1182    
1183      return out      return out
1184    def unrollLoopsSimplified(b,arg,tap=""):
1185        out=""
1186        if isinstance(b,float) or b.rank==0:
1187                 out+=tap+"%s=(%s)*x[0]\n"%(arg,str(b))
1188    
1189        elif b.rank==1:
1190            for i0 in range(b.shape[0]):
1191                 out+=tap+"%s[%s]=(%s)*x[%s]\n"%(arg,i0,b[i0],i0)
1192        elif b.rank==2:
1193            for i0 in range(b.shape[0]):
1194             for i1 in range(b.shape[1]):
1195                 out+=tap+"%s[%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,b[i0,i1],i1)
1196        elif b.rank==3:
1197            for i0 in range(b.shape[0]):
1198             for i1 in range(b.shape[1]):
1199               for i2 in range(b.shape[2]):
1200                 out+=tap+"%s[%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,b[i0,i1,i2],i2)
1201        elif b.rank==4:
1202            for i0 in range(b.shape[0]):
1203             for i1 in range(b.shape[1]):
1204               for i2 in range(b.shape[2]):
1205                for i3 in range(b.shape[3]):
1206                 out+=tap+"%s[%s,%s,%s,%s]=(%s)*x[%s]\n"%(arg,i0,i1,i2,i3,b[i0,i1,i2,i3],i3)
1207        return out
1208    
1209    def unrollLoopsOfL2(b,where,arg,tap=""):
1210        out=""
1211        z=[]
1212        if isinstance(b,float) or b.rank==0:
1213           z.append(b**2)
1214        elif b.rank==1:
1215            for i0 in range(b.shape[0]):
1216                 z.append(b[i0]**2)
1217        elif b.rank==2:
1218            for i1 in range(b.shape[1]):
1219               s=0
1220               for i0 in range(b.shape[0]):
1221                  s+=b[i0,i1]**2
1222               z.append(s)
1223        elif b.rank==3:
1224            for i2 in range(b.shape[2]):
1225              s=0
1226              for i0 in range(b.shape[0]):
1227                 for i1 in range(b.shape[1]):
1228                    s+=b[i0,i1,i2]**2
1229              z.append(s)
1230    
1231        elif b.rank==4:
1232          for i3 in range(b.shape[3]):
1233             s=0
1234             for i0 in range(b.shape[0]):
1235               for i1 in range(b.shape[1]):
1236                  for i2 in range(b.shape[2]):
1237                     s+=b[i0,i1,i2,i3]**2
1238             z.append(s)        
1239        if where=="Function":
1240           xfac_o=1.
1241           xfac_op=0.
1242           z_fac_s=""
1243           zo_fac_s=""
1244           zo_fac=1./3.
1245        elif where=="FunctionOnBoundary":
1246           xfac_o=1.
1247           xfac_op=0.
1248           z_fac_s="*dim"
1249           zo_fac_s="*(2.*dim+1.)/3."
1250           zo_fac=1.
1251        elif where in ["FunctionOnContactZero","FunctionOnContactOne"]:
1252           xfac_o=0.
1253           xfac_op=1.
1254           z_fac_s=""
1255           zo_fac_s=""    
1256           zo_fac=1./3.    
1257        zo=0.
1258        zop=0.
1259        for i99 in range(len(z)):
1260               if i99==0:
1261                   zo+=xfac_o*z[i99]
1262                   zop+=xfac_op*z[i99]
1263               else:
1264                   zo+=z[i99]
1265        out+=tap+"%s=sqrt((%s)%s"%(arg,zo*zo_fac,zo_fac_s)
1266        if zop==0.:
1267           out+=")\n"
1268        else:
1269           out+="+(%s))\n"%(zop*0.5**2)
1270        return out
1271    #=======================================================================================================
1272    # transpose
1273    #=======================================================================================================
1274    def transposeTest(r,offset):
1275        if isinstance(r,float): return r
1276        s=r.shape
1277        s1=1
1278        for i in s[:offset]: s1*=i
1279        s2=1
1280        for i in s[offset:]: s2*=i
1281        out=numarray.reshape(r,(s1,s2))
1282        out.transpose()
1283        return numarray.resize(out,s[offset:]+s[:offset])
1284    
1285    name,tt="transpose",transposeTest
1286    for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
1287      for sh0 in [ (), (3,), (4,5), (6,2,2),(3,2,3,4)]:
1288        for offset in range(len(sh0)+1):
1289                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1290                  tname="test_%s_%s_rank%s_offset%s"%(name,case0,len(sh0),offset)
1291                  text+="   def %s(self):\n"%tname
1292                  sh_t=sh0[offset:]+sh0[:offset]
1293    
1294    #              sh_t=list(sh0)
1295    #              sh_t[offset+1]=sh_t[offset]
1296    #              sh_t=tuple(sh_t)
1297    #              sh_r=[]
1298    #              for i in range(offset): sh_r.append(sh0[i])
1299    #              for i in range(offset+2,len(sh0)): sh_r.append(sh0[i])              
1300    #              sh_r=tuple(sh_r)
1301    
1302                  a_0=makeArray(sh0,[-1.,1])
1303                  if case0 in ["taggedData", "expandedData"]:
1304                      a1_0=makeArray(sh0,[-1.,1])
1305                  else:
1306                      a1_0=a_0
1307                  r=tt(a_0,offset)
1308                  r1=tt(a1_0,offset)
1309                  text+=mkText(case0,"arg",a_0,a1_0)
1310                  text+="      res=%s(arg,%s)\n"%(name,offset)
1311                  if case0=="Symbol":
1312                     text+=mkText("array","s",a_0,a1_0)
1313                     text+="      sub=res.substitute({arg:s})\n"
1314                     res="sub"
1315                     text+=mkText("array","ref",r,r1)
1316                  else:
1317                     res="res"
1318                     text+=mkText(case0,"ref",r,r1)
1319                  text+=mkTypeAndShapeTest(case0,sh_t,"res")
1320                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1321                  
1322                  if case0 == "taggedData" :
1323                      t_prog_with_tags+=text
1324                  else:              
1325                      t_prog+=text
1326    
1327    print test_header
1328    # print t_prog
1329    print t_prog_with_tags
1330    print test_tail          
1331    1/0
1332    #=======================================================================================================
1333    # L2
1334    #=======================================================================================================
1335    for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
1336      for data in ["Data","Symbol"]:
1337        for sh in [ (),(2,), (4,5), (6,2,2),(4,5,3,2)]:
1338             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1339             tname="test_L2_on%s_from%s_rank%s"%(where,data,len(sh))
1340             text+="   def %s(self):\n"%tname
1341             text+="      \"\"\"\n"
1342             text+="      tests L2-norm of %s on the %s\n\n"%(data,where)
1343             text+="      assumptions: self.domain supports integration on %s\n"%where
1344             text+="      \"\"\"\n"
1345             text+="      dim=self.domain.getDim()\n"
1346             text+="      w=%s(self.domain)\n"%where
1347             text+="      x=w.getX()\n"
1348             o="1"
1349             if len(sh)>0:
1350                sh_2=sh[:len(sh)-1]+(2,)
1351                sh_3=sh[:len(sh)-1]+(3,)            
1352                b_2=makeArray(sh[:len(sh)-1]+(2,),[-1.,1])
1353                b_3=makeArray(sh[:len(sh)-1]+(3,),[-1.,1])
1354             else:
1355                sh_2=()
1356                sh_3=()
1357                b_2=makeArray(sh,[-1.,1])
1358                b_3=makeArray(sh,[-1.,1])
1359    
1360             if data=="Symbol":
1361                val="s"
1362                res="sub"
1363             else:
1364                val="arg"
1365                res="res"
1366             text+="      if dim==2:\n"
1367             if data=="Symbol":
1368                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_2)
1369    
1370             text+="        %s=Data(0,%s,w)\n"%(val,sh_2)
1371             text+=unrollLoopsSimplified(b_2,val,tap="        ")
1372             text+=unrollLoopsOfL2(b_2,where,"ref",tap="        ")
1373             text+="\n      else:\n"
1374             if data=="Symbol":
1375                   text+="        arg=Symbol(shape=%s,dim=dim)\n"%str(sh_3)
1376             text+="        %s=Data(0,%s,w)\n"%(val,sh_3)        
1377             text+=unrollLoopsSimplified(b_3,val,tap="        ")
1378             text+=unrollLoopsOfL2(b_3,where,"ref",tap="        ")
1379             text+="\n      res=L2(arg)\n"
1380             if data=="Symbol":
1381                text+="      sub=res.substitute({arg:s})\n"
1382                text+="      self.failUnless(isinstance(res,Symbol),\"wrong type of result.\")\n"
1383                text+="      self.failUnlessEqual(res.getShape(),(),\"wrong shape of result.\")\n"
1384             else:
1385                text+="      self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1386             text+="      self.failUnlessAlmostEqual(%s,ref,int(-log10(self.RES_TOL)),\"wrong result\")\n"%res
1387             t_prog+=text
1388    print t_prog
1389    1/0
1390    
1391  #=======================================================================================================  #=======================================================================================================
1392  # div  # div
1393  #=======================================================================================================  #=======================================================================================================

Legend:
Removed from v.441  
changed lines
  Added in v.536

  ViewVC Help
Powered by ViewVC 1.1.26