/[escript]/trunk/escript/py_src/generateutil
ViewVC logotype

Diff of /trunk/escript/py_src/generateutil

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 536 by gross, Fri Feb 17 03:20:53 2006 UTC revision 550 by gross, Wed Feb 22 02:14:38 2006 UTC
# Line 179  def makeArray(shape,rng): Line 179  def makeArray(shape,rng):
179               for i2 in range(shape[2]):               for i2 in range(shape[2]):
180                  for i3 in range(shape[3]):                  for i3 in range(shape[3]):
181                    for i4 in range(shape[4]):                    for i4 in range(shape[4]):
182                     out[i0,i1,i2,i3,i4]=l*random.random()+rng[0]                     out[i0,i1,i2,i3,i4]=l*ranm.random()+rng[0]
183     else:     else:
184         raise SystemError,"rank is restricted to 5"         raise SystemError,"rank is restricted to 5"
185     return out             return out        
# Line 190  def makeNumberedArray(shape,s=1.): Line 190  def makeNumberedArray(shape,s=1.):
190         out=s*1.         out=s*1.
191     elif len(shape)==1:     elif len(shape)==1:
192         for i0 in range(shape[0]):         for i0 in range(shape[0]):
193                     out[i0]=s*i0                     out[i0]=s*int(8*random.random()+1)
194     elif len(shape)==2:     elif len(shape)==2:
195         for i0 in range(shape[0]):         for i0 in range(shape[0]):
196            for i1 in range(shape[1]):            for i1 in range(shape[1]):
197                     out[i0,i1]=s*(i1+shape[1]*i0)                     out[i0,i1]=s*int(8*random.random()+1)
198     elif len(shape)==3:     elif len(shape)==3:
199         for i0 in range(shape[0]):         for i0 in range(shape[0]):
200            for i1 in range(shape[1]):            for i1 in range(shape[1]):
201               for i2 in range(shape[2]):               for i2 in range(shape[2]):
202                     out[i0,i1,i2]=s*(i2+shape[2]*i1+shape[2]*shape[1]*i0)                     out[i0,i1,i2]=s*int(8*random.random()+1)
203     elif len(shape)==4:     elif len(shape)==4:
204         for i0 in range(shape[0]):         for i0 in range(shape[0]):
205            for i1 in range(shape[1]):            for i1 in range(shape[1]):
206               for i2 in range(shape[2]):               for i2 in range(shape[2]):
207                  for i3 in range(shape[3]):                  for i3 in range(shape[3]):
208                     out[i0,i1,i2,i3]=s*(i3+shape[3]*i2+shape[3]*shape[2]*i1+shape[3]*shape[2]*shape[1]*i0)                     out[i0,i1,i2,i3]=s*int(8*random.random()+1)
209     else:     else:
210         raise SystemError,"rank is restricted to 4"         raise SystemError,"rank is restricted to 4"
211     return out             return out        
# Line 505  def mkText(case,name,a,a1=None,use_taggi Line 505  def mkText(case,name,a,a1=None,use_taggi
505                   t_out+="      %s.setTaggedValue(1,numarray.array(%s))\n"%(name,a1.tolist())                   t_out+="      %s.setTaggedValue(1,numarray.array(%s))\n"%(name,a1.tolist())
506                t_out+="      %s.expand()\n"%name                          t_out+="      %s.expand()\n"%name          
507             else:             else:
508                t_out+="      msk_%s=whereNegative(self.functionspace.getX()[0]-0.5)\n"%name                t_out+="      msk_%s=whereZero(self.functionspace.getX()[0],1.e-8)\n"%name
509                if isinstance(a,float):                if isinstance(a,float):
510                     t_out+="      %s=msk_%s*(%s)+(1.-msk_%s)*(%s)\n"%(name,name,a,name,a1)                     t_out+="      %s=(1.-msk_%s)*(%s)+msk_%s*(%s)\n"%(name,name,a,name,a1)
511                elif a.rank==0:                elif a.rank==0:
512                     t_out+="      %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a,name,a1)                     t_out+="      %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a,name,a1)
513                else:                else:
# Line 550  def mkCode(txt,args=[],intend=""): Line 550  def mkCode(txt,args=[],intend=""):
550        out=out.replace("%%a%s%%"%c,r)        out=out.replace("%%a%s%%"%c,r)
551      return out        return out  
552  #=======================================================================================================  #=======================================================================================================
553  # nonsymmetric part  # get slices
554    #=======================================================================================================
555    from esys.escript import *
556    for case0 in ["constData","taggedData","expandedData"]:
557       for sh0 in [ (3,), (3,4), (3,4,3) ,(4,3,3,3)]:
558        # get perm:
559        if len(sh0)==2:
560            check=[[1,0]]
561        elif len(sh0)==3:
562            check=[[1,0,2],
563                   [1,2,0],
564                   [2,1,0],
565                   [2,0,2],
566                   [0,2,1]]
567        elif len(sh0)==4:
568            check=[[0,1,3,2],
569                   [0,2,1,3],
570                   [0,2,3,1],
571                   [0,3,2,1],
572                   [0,3,1,2] ,          
573                   [1,0,2,3],
574                   [1,0,3,2],
575                   [1,2,0,3],
576                   [1,2,3,0],
577                   [1,3,2,0],
578                   [1,3,0,2],
579                   [2,0,1,3],
580                   [2,0,3,1],
581                   [2,1,0,3],
582                   [2,1,3,0],
583                   [2,3,1,0],
584                   [2,3,0,1],
585                   [3,0,1,2],
586                   [3,0,2,1],
587                   [3,1,0,2],
588                   [3,1,2,0],
589                   [3,2,1,0],
590                   [3,2,0,1]]
591        else:
592             check=[]
593        
594        # create the test cases:
595        processed=[]
596        l=["R","U","L","P","C","N"]
597        c=[""]
598        for i in range(len(sh0)):
599           tmp=[]
600           for ci in c:
601              tmp+=[ci+li for li in l]
602           c=tmp
603        # SHUFFLE
604        c2=[]
605        while len(c)>0:
606            i=int(random.random()*len(c))
607            c2.append(c[i])
608            del c[i]
609        c=c2
610        for ci in c:
611          t=""
612          sh=()
613          sl=()
614          for i in range(len(ci)):
615              if ci[i]=="R":
616                 s="%s:%s"%(1,sh0[i]-1)
617                 sl=sl+(slice(1,sh0[i]-1),)
618                 sh=sh+(sh0[i]-2,)            
619              if ci[i]=="U":
620                  s=":%s"%(sh0[i]-1)
621                  sh=sh+(sh0[i]-1,)
622                  sl=sl+(slice(0,sh0[i]-1),)            
623              if ci[i]=="L":
624                  s="2:"
625                  sh=sh+(sh0[i]-2,)            
626                  sl=sl+(slice(2,sh0[i]),)            
627              if ci[i]=="P":
628                  s="%s"%(int(sh0[i]/2))
629                  sl=sl+(int(sh0[i]/2),)            
630              if ci[i]=="C":
631                  s=":"
632                  sh=sh+(sh0[i],)            
633                  sl=sl+(slice(0,sh0[i]),)            
634              if ci[i]=="N":
635                  s=""
636                  sh=sh+(sh0[i],)
637              if len(s)>0:
638                 if not t=="": t+=","
639                 t+=s
640          if len(sl)==1: sl=sl[0]
641          N_found=False
642          noN_found=False
643          process=len(t)>0
644          for i in ci:
645             if i=="N":
646                if not noN_found and N_found: process=False
647                N_found=True
648             else:
649               if N_found: process=False
650               noNfound=True
651          # is there a similar one processed allready
652          if process and ci.find("N")==-1:
653             for ci2 in processed:
654               for chi in check:
655                   is_perm=True
656                   for i in range(len(chi)):
657                       if not ci[i]==ci2[chi[i]]: is_perm=False
658                   if is_perm: process=False
659          # if not process: print ci," rejected"
660          if process:
661           processed.append(ci)
662           for case1 in ["array","constData","taggedData","expandedData"]:
663              text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
664              tname="test_setslice_%s_rank%s_%s_%s"%(case0,len(sh0),case1,ci)
665              text+="   def %s(self):\n"%tname
666              a_0=makeNumberedArray(sh0)
667              if case0 in ["taggedData", "expandedData"]:
668                  a1_0=makeNumberedArray(sh0)
669              else:
670                  a1_0=a_0*1.
671    
672              a_1=makeNumberedArray(sh)
673              if case1 in ["taggedData", "expandedData"]:
674                  a1_1=makeNumberedArray(sh)
675              else:
676                  a1_1=a_1*1.
677    
678              text+=mkText(case0,"arg",a_0,a1_0)                                  
679              text+=mkText(case1,"val",a_1,a1_1)                                  
680              text+="      arg[%s]=val\n"%t
681              a_0.__setitem__(sl,a_1)
682              a1_0.__setitem__(sl,a1_1)
683              if Lsup(a_0-a1_0)==0.:
684                 text+=mkText("constData","ref",a_0,a1_0)
685              else:
686                 text+=mkText("expandedData","ref",a_0,a1_0)
687              text+="      self.failUnless(Lsup(arg-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"
688                  
689              if case0 == "taggedData" or case1 == "taggedData":
690                t_prog_with_tags+=text
691              else:              
692                t_prog+=text
693    
694    print test_header
695    # print t_prog
696    print t_prog_with_tags
697    print test_tail          
698    1/0
699    
700    #=======================================================================================================
701    # (non)symmetric part
702  #=======================================================================================================  #=======================================================================================================
703  from esys.escript import *  from esys.escript import *
704  for name in ["symmetric", "nonsymmetric"]:  for name in ["symmetric", "nonsymmetric"]:
# Line 638  print test_tail Line 786  print test_tail
786  1/0  1/0
787    
788  #=======================================================================================================  #=======================================================================================================
789  # slicing  # get slices
790  #=======================================================================================================  #=======================================================================================================
791  for case0 in ["constData","taggedData","expandedData","Symbol"]:  for case0 in ["constData","taggedData","expandedData","Symbol"]:
792    for sh0 in [ (3,), (3,4), (3,4,3) ,(4,3,5,3)]:    for sh0 in [ (3,), (3,4), (3,4,3) ,(4,3,5,3)]:

Legend:
Removed from v.536  
changed lines
  Added in v.550

  ViewVC Help
Powered by ViewVC 1.1.26