/[escript]/trunk/escript/py_src/generatediff
ViewVC logotype

Diff of /trunk/escript/py_src/generatediff

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 291 by gross, Fri Dec 2 03:10:06 2005 UTC revision 438 by gross, Fri Jan 20 00:39:00 2006 UTC
# Line 33  def wherepos(arg): Line 33  def wherepos(arg):
33     else:     else:
34        return 0.        return 0.
35    
36    
37  class OPERATOR:  class OPERATOR:
38      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,      def __init__(self,nickname,rng=[-1000.,1000],test_expr="",math_expr=None,
39                  numarray_expr="",symbol_expr=None,diff=None,name=""):                  numarray_expr="",symbol_expr=None,diff=None,name=""):
# Line 172  def makeArray(shape,rng): Line 173  def makeArray(shape,rng):
173               for i2 in range(shape[2]):               for i2 in range(shape[2]):
174                  for i3 in range(shape[3]):                  for i3 in range(shape[3]):
175                     out[i0,i1,i2,i3]=l*random.random()+rng[0]                     out[i0,i1,i2,i3]=l*random.random()+rng[0]
176       elif len(shape)==5:
177           for i0 in range(shape[0]):
178              for i1 in range(shape[1]):
179                 for i2 in range(shape[2]):
180                    for i3 in range(shape[3]):
181                      for i4 in range(shape[4]):
182                       out[i0,i1,i2,i3,i4]=l*random.random()+rng[0]
183     else:     else:
184         raise SystemError,"rank is restricted to 4"         raise SystemError,"rank is restricted to 5"
185     return out             return out        
186    
187    
# Line 589  def testTensorMult(arg0,arg1,sh_s): Line 597  def testTensorMult(arg0,arg1,sh_s):
597               for i3 in range(arg0.shape[3]):               for i3 in range(arg0.shape[3]):
598                       out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]                       out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
599          return out          return out
600    
601    def testReduce(arg0,init_val,test_expr,post_expr):
602         out=init_val
603         if isinstance(arg0,float):
604              out=eval(test_expr.replace("%a1%","arg0"))
605         elif arg0.rank==0:
606              out=eval(test_expr.replace("%a1%","arg0"))
607         elif arg0.rank==1:
608            for i0 in range(arg0.shape[0]):
609                   out=eval(test_expr.replace("%a1%","arg0[i0]"))
610         elif arg0.rank==2:
611            for i0 in range(arg0.shape[0]):
612             for i1 in range(arg0.shape[1]):
613                   out=eval(test_expr.replace("%a1%","arg0[i0,i1]"))
614         elif arg0.rank==3:
615            for i0 in range(arg0.shape[0]):
616             for i1 in range(arg0.shape[1]):
617               for i2 in range(arg0.shape[2]):
618                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]"))
619         elif arg0.rank==4:
620            for i0 in range(arg0.shape[0]):
621             for i1 in range(arg0.shape[1]):
622               for i2 in range(arg0.shape[2]):
623                 for i3 in range(arg0.shape[3]):
624                   out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]"))          
625         return eval(post_expr)
626        
627    def clipTEST(arg0,mn,mx):
628         if isinstance(arg0,float):
629              return max(min(arg0,mx),mn)
630         out=numarray.zeros(arg0.shape,numarray.Float64)
631         if arg0.rank==1:
632            for i0 in range(arg0.shape[0]):
633                out[i0]=max(min(arg0[i0],mx),mn)
634         elif arg0.rank==2:
635            for i0 in range(arg0.shape[0]):
636             for i1 in range(arg0.shape[1]):
637                out[i0,i1]=max(min(arg0[i0,i1],mx),mn)
638         elif arg0.rank==3:
639            for i0 in range(arg0.shape[0]):
640             for i1 in range(arg0.shape[1]):
641               for i2 in range(arg0.shape[2]):
642                  out[i0,i1,i2]=max(min(arg0[i0,i1,i2],mx),mn)
643         elif arg0.rank==4:
644            for i0 in range(arg0.shape[0]):
645             for i1 in range(arg0.shape[1]):
646               for i2 in range(arg0.shape[2]):
647                 for i3 in range(arg0.shape[3]):
648                    out[i0,i1,i2,i3]=max(min(arg0[i0,i1,i2,i3],mx),mn)
649         return out
650    def minimumTEST(arg0,arg1):
651         if isinstance(arg0,float):
652           if isinstance(arg1,float):
653              if arg0>arg1:
654                  return arg1
655              else:
656                  return arg0
657           else:
658              arg0=numarray.ones(arg1.shape)*arg0
659         else:
660           if isinstance(arg1,float):
661              arg1=numarray.ones(arg0.shape)*arg1
662         out=numarray.zeros(arg0.shape,numarray.Float64)
663         if arg0.rank==0:
664              if arg0>arg1:
665                  out=arg1
666              else:
667                  out=arg0
668         elif arg0.rank==1:
669            for i0 in range(arg0.shape[0]):
670              if arg0[i0]>arg1[i0]:
671                  out[i0]=arg1[i0]
672              else:
673                  out[i0]=arg0[i0]
674         elif arg0.rank==2:
675            for i0 in range(arg0.shape[0]):
676             for i1 in range(arg0.shape[1]):
677              if arg0[i0,i1]>arg1[i0,i1]:
678                  out[i0,i1]=arg1[i0,i1]
679              else:
680                  out[i0,i1]=arg0[i0,i1]
681         elif arg0.rank==3:
682            for i0 in range(arg0.shape[0]):
683             for i1 in range(arg0.shape[1]):
684               for i2 in range(arg0.shape[2]):
685                 if arg0[i0,i1,i2]>arg1[i0,i1,i2]:
686                  out[i0,i1,i2]=arg1[i0,i1,i2]
687                 else:
688                  out[i0,i1,i2]=arg0[i0,i1,i2]
689         elif arg0.rank==4:
690            for i0 in range(arg0.shape[0]):
691             for i1 in range(arg0.shape[1]):
692               for i2 in range(arg0.shape[2]):
693                 for i3 in range(arg0.shape[3]):
694                  if arg0[i0,i1,i2,i3]>arg1[i0,i1,i2,i3]:
695                   out[i0,i1,i2,i3]=arg1[i0,i1,i2,i3]
696                  else:
697                   out[i0,i1,i2,i3]=arg0[i0,i1,i2,i3]
698         return out
699    def unrollLoops(a,b,o,arg,tap=""):
700        out=""
701        if a.rank==1:
702                 z=""
703                 for i99 in range(a.shape[0]):
704                   if not z=="": z+="+"
705                   if o=="1":
706                      z+="(%s)*x[%s]"%(a[i99]+b[i99],i99)
707                   else:
708                      z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i99],i99,b[i99],i99)
709    
710                 out+=tap+"%s=%s\n"%(arg,z)
711    
712        elif a.rank==2:
713            for i0 in range(a.shape[0]):
714                 z=""
715                 for i99 in range(a.shape[1]):
716                   if not z=="": z+="+"
717                   if o=="1":
718                      z+="(%s)*x[%s]"%(a[i0,i99]+b[i0,i99],i99)
719                   else:
720                      z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i99],i99,b[i0,i99],i99)
721    
722                 out+=tap+"%s[%s]=%s\n"%(arg,i0,z)
723        elif a.rank==3:
724            for i0 in range(a.shape[0]):
725             for i1 in range(a.shape[1]):
726                 z=""
727                 for i99 in range(a.shape[2]):
728                   if not z=="": z+="+"
729                   if o=="1":
730                      z+="(%s)*x[%s]"%(a[i0,i1,i99]+b[i0,i1,i99],i99)
731                   else:
732                      z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i1,i99],i99,b[i0,i1,i99],i99)
733    
734                 out+=tap+"%s[%s,%s]=%s\n"%(arg,i0,i1,z)
735        elif a.rank==4:
736            for i0 in range(a.shape[0]):
737             for i1 in range(a.shape[1]):
738               for i2 in range(a.shape[2]):
739                 z=""
740                 for i99 in range(a.shape[3]):
741                   if not z=="": z+="+"
742                   if o=="1":
743                      z+="(%s)*x[%s]"%(a[i0,i1,i2,i99]+b[i0,i1,i2,i99],i99)
744                   else:
745                      z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99],i99)
746    
747                 out+=tap+"%s[%s,%s,%s]=%s\n"%(arg,i0,i1,i2,z)
748        elif a.rank==5:
749            for i0 in range(a.shape[0]):
750             for i1 in range(a.shape[1]):
751               for i2 in range(a.shape[2]):
752                for i3 in range(a.shape[3]):
753                 z=""
754                 for i99 in range(a.shape[4]):
755                   if not z=="": z+="+"
756                   if o=="1":
757                      z+="(%s)*x[%s]"%(a[i0,i1,i2,i3,i99]+b[i0,i1,i2,i3,i99],i99)
758                   else:
759                      z+="(%s)*x[%s]**o+(%s)*x[%s]"%(a[i0,i1,i2,i3,i99],i99,b[i0,i1,i2,i3,i99],i99)
760    
761                 out+=tap+"%s[%s,%s,%s,%s]=%s\n"%(arg,i0,i1,i2,i3,z)
762        return out
763    
764    def unrollLoopsOfGrad(a,b,o,arg,tap=""):
765        out=""
766        if a.rank==1:
767                 for i99 in range(a.shape[0]):
768                   if o=="1":
769                      out+=tap+"%s[%s]=(%s)\n"%(arg,i99,a[i99]+b[i99])
770                   else:
771                      out+=tap+"%s[%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i99,a[i99],i99,b[i99])
772    
773        elif a.rank==2:
774            for i0 in range(a.shape[0]):
775                 for i99 in range(a.shape[1]):
776                   if o=="1":
777                      out+=tap+"%s[%s,%s]=(%s)\n"%(arg,i0,i99,a[i0,i99]+b[i0,i99])
778                   else:
779                      out+=tap+"%s[%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i99,a[i0,i99],i99,b[i0,i99])
780    
781        elif a.rank==3:
782            for i0 in range(a.shape[0]):
783             for i1 in range(a.shape[1]):
784                 for i99 in range(a.shape[2]):
785                   if o=="1":
786                      out+=tap+"%s[%s,%s,%s]=(%s)\n"%(arg,i0,i1,i99,a[i0,i1,i99]+b[i0,i1,i99])
787                   else:
788                      out+=tap+"%s[%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i99,a[i0,i1,i99],i99,b[i0,i1,i99])
789    
790        elif a.rank==4:
791            for i0 in range(a.shape[0]):
792             for i1 in range(a.shape[1]):
793               for i2 in range(a.shape[2]):
794                 for i99 in range(a.shape[3]):
795                   if o=="1":
796                     out+=tap+"%s[%s,%s,%s,%s]=(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99]+b[i0,i1,i2,i99])
797                   else:
798                     out+=tap+"%s[%s,%s,%s,%s]=o*(%s)*x_ref[%s]**(o-1)+(%s)\n"%(arg,i0,i1,i2,i99,a[i0,i1,i2,i99],i99,b[i0,i1,i2,i99])
799        return out
800    def unrollLoopsOfInteriorIntegral(a,b,where,arg,tap=""):
801        if where=="Function":
802           xfac_o=1.
803           xfac_op=0.
804           z_fac=1./2.
805           z_fac_s=""
806           zo_fac_s="/(o+1.)"
807        elif where=="FunctionOnBoundary":
808           xfac_o=1.
809           xfac_op=0.
810           z_fac=1.
811           z_fac_s="*dim"
812           zo_fac_s="*(1+2.*(dim-1.)/(o+1.))"
813        elif where in ["FunctionOnContactZero","FunctionOnContactOne"]:
814           xfac_o=0.
815           xfac_op=1.
816           z_fac=1./2.
817           z_fac_s=""
818           zo_fac_s="/(o+1.)"
819        out=""
820        if a.rank==1:
821                 zo=0.
822                 zop=0.
823                 z=0.
824                 for i99 in range(a.shape[0]):
825                      if i99==0:
826                        zo+=       xfac_o*a[i99]
827                        zop+=       xfac_op*a[i99]
828                      else:
829                        zo+=a[i99]
830                      z+=b[i99]
831    
832                 out+=tap+"%s=(%s)%s+(%s)%s"%(arg,zo,zo_fac_s,z*z_fac,z_fac_s)
833                 if zop==0.:
834                   out+="\n"
835                 else:
836                   out+="+(%s)*0.5**o\n"%zop
837        elif a.rank==2:
838            for i0 in range(a.shape[0]):
839                 zo=0.
840                 zop=0.
841                 z=0.
842                 for i99 in range(a.shape[1]):
843                      if i99==0:
844                        zo+=       xfac_o*a[i0,i99]
845                        zop+=       xfac_op*a[i0,i99]
846                      else:
847                        zo+=a[i0,i99]
848                      z+=b[i0,i99]
849    
850                 out+=tap+"%s[%s]=(%s)%s+(%s)%s"%(arg,i0,zo,zo_fac_s,z*z_fac,z_fac_s)
851                 if zop==0.:
852                   out+="\n"
853                 else:
854                   out+="+(%s)*0.5**o\n"%zop
855        elif a.rank==3:
856            for i0 in range(a.shape[0]):
857             for i1 in range(a.shape[1]):
858                 zo=0.
859                 zop=0.
860                 z=0.
861                 for i99 in range(a.shape[2]):
862                      if i99==0:
863                        zo+=       xfac_o*a[i0,i1,i99]
864                        zop+=       xfac_op*a[i0,i1,i99]
865                      else:
866                        zo+=a[i0,i1,i99]
867                      z+=b[i0,i1,i99]
868    
869                 out+=tap+"%s[%s,%s]=(%s)%s+(%s)%s"%(arg,i0,i1,zo,zo_fac_s,z*z_fac,z_fac_s)
870                 if zop==0.:
871                   out+="\n"
872                 else:
873                   out+="+(%s)*0.5**o\n"%zop
874        elif a.rank==4:
875            for i0 in range(a.shape[0]):
876             for i1 in range(a.shape[1]):
877               for i2 in range(a.shape[2]):
878                 zo=0.
879                 zop=0.
880                 z=0.
881                 for i99 in range(a.shape[3]):
882                      if i99==0:
883                        zo+=       xfac_o*a[i0,i1,i2,i99]
884                        zop+=       xfac_op*a[i0,i1,i2,i99]
885    
886                      else:
887                        zo+=a[i0,i1,i2,i99]
888                      z+=b[i0,i1,i2,i99]
889    
890                 out+=tap+"%s[%s,%s,%s]=(%s)%s+(%s)%s"%(arg,i0,i1,i2,zo,zo_fac_s,z*z_fac,z_fac_s)
891                 if zop==0.:
892                   out+="\n"
893                 else:
894                   out+="+(%s)*0.5**o\n"%zop
895    
896        elif a.rank==5:
897            for i0 in range(a.shape[0]):
898             for i1 in range(a.shape[1]):
899               for i2 in range(a.shape[2]):
900                for i3 in range(a.shape[3]):
901                 zo=0.
902                 zop=0.
903                 z=0.
904                 for i99 in range(a.shape[4]):
905                      if i99==0:
906                        zo+=       xfac_o*a[i0,i1,i2,i3,i99]
907                        zop+=       xfac_op*a[i0,i1,i2,i3,i99]
908    
909                      else:
910                        zo+=a[i0,i1,i2,i3,i99]
911                      z+=b[i0,i1,i2,i3,i99]
912                 out+=tap+"%s[%s,%s,%s,%s]=(%s)%s+(%s)%s"%(arg,i0,i1,i2,i3,zo,zo_fac_s,z*z_fac,z_fac_s)
913                 if zop==0.:
914                   out+="\n"
915                 else:
916                   out+="+(%s)*0.5**o\n"%zop
917    
918        return out
919    
920    
921  #=======================================================================================================  #=======================================================================================================
922  # tensor multiply  # grad
923  #=======================================================================================================  #=======================================================================================================
924  # oper=["generalTensorProduct",tensorProductTest]  for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
925  # oper=["matrixmult",testMatrixMult]    for data in ["Data","Symbol"]:
926  oper=["tensormult",testTensorMult]       for case in ["ContinuousFunction","Solution","ReducedSolution"]:
927           for sh in [ (),(2,), (4,5), (6,2,2)]:
928             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
929             tname="test_grad_on%s_from%s_%s_rank%s"%(where,data,case,len(sh))
930             text+="   def %s(self):\n"%tname
931             text+="      \"\"\"\n"
932             text+="      tests gradient for rank %s %s on the %s\n\n"%(len(sh),data,where)
933             text+="      assumptions: %s(self.domain) exists\n"%case
934             text+="                   self.domain supports gardient on %s\n"%where
935             text+="      \"\"\"\n"
936             if case=="ReducedSolution":
937                text+="      o=1\n"
938                o="1"
939             else:
940                text+="      o=self.order\n"
941                o="o"
942             text+="      dim=self.domain.getDim()\n"
943             text+="      w_ref=%s(self.domain)\n"%where
944             text+="      x_ref=w_ref.getX()\n"
945             text+="      w=%s(self.domain)\n"%case
946             text+="      x=w.getX()\n"
947             a_2=makeArray(sh+(2,),[-1.,1])
948             b_2=makeArray(sh+(2,),[-1.,1])
949             a_3=makeArray(sh+(3,),[-1.,1])
950             b_3=makeArray(sh+(3,),[-1.,1])
951             if data=="Symbol":
952                text+="      arg=Symbol(shape=%s,dim=dim)\n"%str(sh)
953                val="s"
954                res="sub"
955             else:
956                val="arg"
957                res="res"
958             text+="      %s=Data(0,%s,w)\n"%(val,str(sh))
959             text+="      ref=Data(0,%s+(dim,),w_ref)\n"%str(sh)
960             text+="      if dim==2:\n"
961             text+=unrollLoops(a_2,b_2,o,val,tap="        ")
962             text+=unrollLoopsOfGrad(a_2,b_2,o,"ref",tap="        ")
963             text+="      else:\n"
964            
965             text+=unrollLoops(a_3,b_3,o,val,tap="        ")
966             text+=unrollLoopsOfGrad(a_3,b_3,o,"ref",tap="        ")
967             text+="      res=grad(arg,where=w_ref)\n"
968             if data=="Symbol":
969                text+="      sub=res.substitute({arg:s})\n"
970             text+="      self.failUnless(isinstance(res,%s),\"wrong type of result.\")\n"%data
971             text+="      self.failUnlessEqual(res.getShape(),%s+(dim,),\"wrong shape of result.\")\n"%str(sh)
972             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
973    
974  for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:  
975             t_prog+=text
976    print test_header
977    print t_prog
978    print test_tail          
979    1/0
980    
981    
982    #=======================================================================================================
983    # integrate
984    #=======================================================================================================
985    for where in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
986      for data in ["Data","Symbol"]:
987        for case in ["ContinuousFunction","Solution","ReducedSolution","Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]:
988          for sh in [ (),(2,), (4,5), (6,2,2),(4,5,3,2)]:
989            if (not case in ["Function","FunctionOnBoundary","FunctionOnContactZero","FunctionOnContactOne"]) or where==case:  
990             text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
991             tname="test_integrate_on%s_from%s_%s_rank%s"%(where,data,case,len(sh))
992             text+="   def %s(self):\n"%tname
993             text+="      \"\"\"\n"
994             text+="      tests integral of rank %s %s on the %s\n\n"%(len(sh),data,where)
995             text+="      assumptions: %s(self.domain) exists\n"%case
996             text+="                   self.domain supports integral on %s\n"%where
997    
998             text+="      \"\"\"\n"
999             if case=="ReducedSolution":
1000                text+="      o=1\n"
1001                o="1"
1002             else:
1003                text+="      o=self.order\n"
1004                o="o"
1005             text+="      dim=self.domain.getDim()\n"
1006             text+="      w_ref=%s(self.domain)\n"%where
1007             text+="      w=%s(self.domain)\n"%case
1008             text+="      x=w.getX()\n"
1009             a_2=makeArray(sh+(2,),[-1.,1])
1010             b_2=makeArray(sh+(2,),[-1.,1])
1011             a_3=makeArray(sh+(3,),[-1.,1])
1012             b_3=makeArray(sh+(3,),[-1.,1])
1013             if data=="Symbol":
1014                text+="      arg=Symbol(shape=%s)\n"%str(sh)
1015                val="s"
1016                res="sub"
1017             else:
1018                val="arg"
1019                res="res"
1020                
1021             text+="      %s=Data(0,%s,w)\n"%(val,str(sh))
1022             if not len(sh)==0:
1023                text+="      ref=numarray.zeros(%s,numarray.Float)\n"%str(sh)
1024             text+="      if dim==2:\n"
1025             text+=unrollLoops(a_2,b_2,o,val,tap="        ")
1026             text+=unrollLoopsOfInteriorIntegral(a_2,b_2,where,"ref",tap="        ")
1027             text+="      else:\n"
1028            
1029             text+=unrollLoops(a_3,b_3,o,val,tap="        ")
1030             text+=unrollLoopsOfInteriorIntegral(a_3,b_3,where,"ref",tap="        ")
1031             if case in ["ContinuousFunction","Solution","ReducedSolution"]:
1032                 text+="      res=integrate(arg,where=w_ref)\n"
1033             else:
1034                 text+="      res=integrate(arg)\n"
1035    
1036             if data=="Symbol":
1037                text+="      sub=res.substitute({arg:s})\n"
1038             if len(sh)==0 and data=="Data":
1039                text+="      self.failUnless(isinstance(%s,float),\"wrong type of result.\")\n"%res
1040             else:
1041                if data=="Symbol":
1042                   text+="      self.failUnless(isinstance(res,Symbol),\"wrong type of result.\")\n"
1043                   text+="      self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1044                else:
1045                   text+="      self.failUnless(isinstance(res,numarray.NumArray),\"wrong type of result.\")\n"
1046                   text+="      self.failUnlessEqual(res.shape,%s,\"wrong shape of result.\")\n"%str(sh)
1047             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1048    
1049    
1050             t_prog+=text
1051    print test_header
1052    print t_prog
1053    print test_tail          
1054    1/0
1055    #=======================================================================================================
1056    # inverse
1057    #=======================================================================================================
1058    name="inverse"
1059    for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
1060      for sh0 in [ (1,1), (2,2), (3,3)]:
1061                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1062                  tname="test_%s_%s_dim%s"%(name,case0,sh0[0])
1063                  text+="   def %s(self):\n"%tname
1064                  a_0=makeArray(sh0,[-1.,1])
1065                  for i in range(sh0[0]): a_0[i,i]+=2
1066                  if case0 in ["taggedData", "expandedData"]:
1067                      a1_0=makeArray(sh0,[-1.,1])
1068                      for i in range(sh0[0]): a1_0[i,i]+=3
1069                  else:
1070                      a1_0=a_0
1071                      
1072                  text+=mkText(case0,"arg",a_0,a1_0)
1073                  text+="      res=%s(arg)\n"%name
1074                  if case0=="Symbol":
1075                     text+=mkText("array","s",a_0,a1_0)
1076                     text+="      sub=res.substitute({arg:s})\n"
1077                     res="sub"
1078                     ref="s"
1079                  else:
1080                     ref="arg"
1081                     res="res"
1082                  text+=mkTypeAndShapeTest(case0,sh0,"res")
1083                  text+="      self.failUnless(Lsup(matrixmult(%s,%s)-kronecker(%s))<=self.RES_TOL,\"wrong result\")\n"%(res,ref,sh0[0])
1084                  
1085                  if case0 == "taggedData" :
1086                      t_prog_with_tags+=text
1087                  else:              
1088                      t_prog+=text
1089    
1090    print test_header
1091    # print t_prog
1092    print t_prog_with_tags
1093    print test_tail          
1094    1/0
1095    
1096    #=======================================================================================================
1097    # trace
1098    #=======================================================================================================
1099    def traceTest(r,offset):
1100        sh=r.shape
1101        r1=1
1102        for i in range(offset): r1*=sh[i]
1103        r2=1
1104        for i in range(offset+2,len(sh)): r2*=sh[i]
1105        r_s=numarray.reshape(r,(r1,sh[offset],sh[offset],r2))
1106        s=numarray.zeros([r1,r2],numarray.Float)
1107        for i1 in range(r1):
1108            for i2 in range(r2):
1109                for j in range(sh[offset]): s[i1,i2]+=r_s[i1,j,j,i2]
1110        return s.resize(sh[:offset]+sh[offset+2:])
1111    name,tt="trace",traceTest
1112    for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
1113      for sh0 in [ (4,5), (6,2,2),(3,2,3,4)]:
1114        for offset in range(len(sh0)-1):
1115                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1116                  tname="test_%s_%s_rank%s_offset%s"%(name,case0,len(sh0),offset)
1117                  text+="   def %s(self):\n"%tname
1118                  sh_t=list(sh0)
1119                  sh_t[offset+1]=sh_t[offset]
1120                  sh_t=tuple(sh_t)
1121                  sh_r=[]
1122                  for i in range(offset): sh_r.append(sh0[i])
1123                  for i in range(offset+2,len(sh0)): sh_r.append(sh0[i])              
1124                  sh_r=tuple(sh_r)
1125                  a_0=makeArray(sh_t,[-1.,1])
1126                  if case0 in ["taggedData", "expandedData"]:
1127                      a1_0=makeArray(sh_t,[-1.,1])
1128                  else:
1129                      a1_0=a_0
1130                  r=tt(a_0,offset)
1131                  r1=tt(a1_0,offset)
1132                  text+=mkText(case0,"arg",a_0,a1_0)
1133                  text+="      res=%s(arg,%s)\n"%(name,offset)
1134                  if case0=="Symbol":
1135                     text+=mkText("array","s",a_0,a1_0)
1136                     text+="      sub=res.substitute({arg:s})\n"
1137                     res="sub"
1138                     text+=mkText("array","ref",r,r1)
1139                  else:
1140                     res="res"
1141                     text+=mkText(case0,"ref",r,r1)
1142                  text+=mkTypeAndShapeTest(case0,sh_r,"res")
1143                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1144                  
1145                  if case0 == "taggedData" :
1146                      t_prog_with_tags+=text
1147                  else:              
1148                      t_prog+=text
1149    
1150    print test_header
1151    # print t_prog
1152    print t_prog_with_tags
1153    print test_tail          
1154    1/0
1155    
1156    #=======================================================================================================
1157    # clip
1158    #=======================================================================================================
1159    oper_L=[["clip",clipTEST]]
1160    for oper in oper_L:
1161     for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
1162    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1163            if len(sh0)==0 or not case0=="float":
1164                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1165                  tname="test_%s_%s_rank%s"%(oper[0],case0,len(sh0))
1166                  text+="   def %s(self):\n"%tname
1167                  a_0=makeArray(sh0,[-1.,1])
1168                  if case0 in ["taggedData", "expandedData"]:
1169                      a1_0=makeArray(sh0,[-1.,1])
1170                  else:
1171                      a1_0=a_0
1172    
1173                  r=oper[1](a_0,-0.3,0.5)
1174                  r1=oper[1](a1_0,-0.3,0.5)
1175                  text+=mkText(case0,"arg",a_0,a1_0)
1176                  text+="      res=%s(arg,-0.3,0.5)\n"%oper[0]
1177                  if case0=="Symbol":
1178                     text+=mkText("array","s",a_0,a1_0)
1179                     text+="      sub=res.substitute({arg:s})\n"
1180                     res="sub"
1181                     text+=mkText("array","ref",r,r1)
1182                  else:
1183                     res="res"
1184                     text+=mkText(case0,"ref",r,r1)
1185                  text+=mkTypeAndShapeTest(case0,sh0,"res")
1186                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1187                  
1188                  if case0 == "taggedData" :
1189                      t_prog_with_tags+=text
1190                  else:              
1191                      t_prog+=text
1192    
1193    print test_header
1194    # print t_prog
1195    print t_prog_with_tags
1196    print test_tail          
1197    1/0
1198    
1199    #=======================================================================================================
1200    # maximum, minimum, clipping
1201    #=======================================================================================================
1202    oper_L=[ ["maximum",maximumTEST],
1203             ["minimum",minimumTEST]]
1204    for oper in oper_L:
1205     for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
1206      for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1207     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
1208       for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:       for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1209         for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:          if (len(sh0)==0 or not case0=="float") and (len(sh1)==0 or not case1=="float") \
1210            if (len(sh0+sh_s)==0 or not case0=="float") and (len(sh1+sh_s)==0 or not case1=="float") \             and (sh0==sh1 or len(sh0)==0 or len(sh1)==0) :
                and len(sh0+sh1)<5 and len(sh0+sh_s)<5 and len(sh1+sh_s)<5:  
             # if len(sh_s)==1 and len(sh0+sh_s)==2 and (len(sh_s+sh1)==1 or len(sh_s+sh1)==2)): # test for matrixmult  
             if ( len(sh_s)==1 and len(sh0+sh_s)==2 and ( len(sh1+sh_s)==2 or len(sh1+sh_s)==1 )) or (len(sh_s)==2 and len(sh0+sh_s)==4 and (len(sh1+sh_s)==2 or len(sh1+sh_s)==3 or len(sh1+sh_s)==4)):  # test for tensormult  
               case=getResultCaseForBin(case0,case1)    
1211                use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"                use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
1212    
1213                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"                text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1214                # tname="test_generalTensorProduct_%s_rank%s_%s_rank%s_offset%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1),len(sh_s))                tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
               #tname="test_matrixmult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))  
               tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))  
               # if tname=="test_generalTensorProduct_array_rank1_array_rank2_offset1":  
               # print tnametest_generalTensorProduct_Symbol_rank1_Symbol_rank3_offset1  
1215                text+="   def %s(self):\n"%tname                text+="   def %s(self):\n"%tname
1216                a_0=makeArray(sh0+sh_s,[-1.,1])                a_0=makeArray(sh0,[-1.,1])
1217                if case0 in ["taggedData", "expandedData"]:                if case0 in ["taggedData", "expandedData"]:
1218                    a1_0=makeArray(sh0+sh_s,[-1.,1])                    a1_0=makeArray(sh0,[-1.,1])
1219                else:                else:
1220                    a1_0=a_0                    a1_0=a_0
1221    
1222                a_1=makeArray(sh_s+sh1,[-1.,1])                a_1=makeArray(sh1,[-1.,1])
1223                if case1 in ["taggedData", "expandedData"]:                if case1 in ["taggedData", "expandedData"]:
1224                    a1_1=makeArray(sh_s+sh1,[-1.,1])                    a1_1=makeArray(sh1,[-1.,1])
1225                else:                else:
1226                    a1_1=a_1                    a1_1=a_1
1227                r=oper[1](a_0,a_1,sh_s)                r=oper[1](a_0,a_1)
1228                r1=oper[1](a1_0,a1_1,sh_s)                r1=oper[1](a1_0,a1_1)
1229                text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)                text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
1230                text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)                text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
1231                #text+="      res=matrixmult(arg0,arg1)\n"                text+="      res=%s(arg0,arg1)\n"%oper[0]
1232                text+="      res=tensormult(arg0,arg1)\n"                case=getResultCaseForBin(case0,case1)              
               #text+="      res=generalTensorProduct(arg0,arg1,offset=%s)\n"%(len(sh_s))  
1233                if case=="Symbol":                if case=="Symbol":
1234                   c0_res,c1_res=case0,case1                   c0_res,c1_res=case0,case1
1235                   subs="{"                   subs="{"
# Line 651  for case0 in ["float","array","Symbol"," Line 1249  for case0 in ["float","array","Symbol","
1249                else:                else:
1250                   res="res"                   res="res"
1251                   text+=mkText(case,"ref",r,r1)                   text+=mkText(case,"ref",r,r1)
1252                text+=mkTypeAndShapeTest(case,sh0+sh1,"res")                if len(sh0)>len(sh1):
1253                      text+=mkTypeAndShapeTest(case,sh0,"res")
1254                  else:
1255                      text+=mkTypeAndShapeTest(case,sh1,"res")
1256                text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res                text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1257                  
1258                if case0 == "taggedData" or case1 == "taggedData":                if case0 == "taggedData" or case1 == "taggedData":
1259                    t_prog_with_tags+=text                    t_prog_with_tags+=text
1260                else:                              else:              
1261                    t_prog+=text                    t_prog+=text
1262    
1263  print test_header  print test_header
1264  # print t_prog  # print t_prog
1265  print t_prog_with_tags  print t_prog_with_tags
1266  print test_tail            print test_tail          
1267  1/0  1/0
1268    
1269    
1270  #=======================================================================================================  #=======================================================================================================
1271  # outer/inner  # outer inner
1272  #=======================================================================================================  #=======================================================================================================
1273  oper=["inner",innerTEST]  oper=["outer",outerTEST]
1274  # oper=["outer",outerTEST]  # oper=["inner",innerTEST]
1275  for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:  for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
1276    for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:    for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1277     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:     for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
# Line 725  for case0 in ["float","array","Symbol"," Line 1330  for case0 in ["float","array","Symbol","
1330  print test_header  print test_header
1331  # print t_prog  # print t_prog
1332  print t_prog_with_tags  print t_prog_with_tags
1333    print test_tail          
1334    1/0
1335    
1336    #=======================================================================================================
1337    # local reduction
1338    #=======================================================================================================
1339    for oper in [["length",0.,"out+%a1%**2","math.sqrt(out)"],
1340                 ["maxval",-1.e99,"max(out,%a1%)","out"],
1341                 ["minval",1.e99,"min(out,%a1%)","out"] ]:
1342      for case in case_set:
1343         for sh in shape_set:
1344           if not case=="float" or len(sh)==0:
1345             text=""
1346             text+="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1347             tname="def test_%s_%s_rank%s"%(oper[0],case,len(sh))
1348             text+="   %s(self):\n"%tname
1349             a=makeArray(sh,[-1.,1.])            
1350             a1=makeArray(sh,[-1.,1.])
1351             r1=testReduce(a1,oper[1],oper[2],oper[3])
1352             r=testReduce(a,oper[1],oper[2],oper[3])
1353            
1354             text+=mkText(case,"arg",a,a1)
1355             text+="      res=%s(arg)\n"%oper[0]
1356             if case=="Symbol":        
1357                 text+=mkText("array","s",a,a1)
1358                 text+="      sub=res.substitute({arg:s})\n"        
1359                 text+=mkText("array","ref",r,r1)
1360                 res="sub"
1361             else:
1362                 text+=mkText(case,"ref",r,r1)
1363                 res="res"
1364             if oper[0]=="length":              
1365                   text+=mkTypeAndShapeTest(case,(),"res")
1366             else:            
1367                if case=="float" or case=="array":        
1368                   text+=mkTypeAndShapeTest("float",(),"res")
1369                else:          
1370                   text+=mkTypeAndShapeTest(case,(),"res")
1371             text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1372             if case == "taggedData":
1373               t_prog_with_tags+=text
1374             else:
1375               t_prog+=text
1376    print test_header
1377    # print t_prog
1378    print t_prog_with_tags
1379    print test_tail          
1380    1/0
1381              
1382    #=======================================================================================================
1383    # tensor multiply
1384    #=======================================================================================================
1385    # oper=["generalTensorProduct",tensorProductTest]
1386    # oper=["matrixmult",testMatrixMult]
1387    oper=["tensormult",testTensorMult]
1388    
1389    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
1390      for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1391       for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
1392         for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1393           for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:
1394              if (len(sh0+sh_s)==0 or not case0=="float") and (len(sh1+sh_s)==0 or not case1=="float") \
1395                   and len(sh0+sh1)<5 and len(sh0+sh_s)<5 and len(sh1+sh_s)<5:
1396                # if len(sh_s)==1 and len(sh0+sh_s)==2 and (len(sh_s+sh1)==1 or len(sh_s+sh1)==2)): # test for matrixmult
1397                if ( len(sh_s)==1 and len(sh0+sh_s)==2 and ( len(sh1+sh_s)==2 or len(sh1+sh_s)==1 )) or (len(sh_s)==2 and len(sh0+sh_s)==4 and (len(sh1+sh_s)==2 or len(sh1+sh_s)==3 or len(sh1+sh_s)==4)):  # test for tensormult
1398                  case=getResultCaseForBin(case0,case1)  
1399                  use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
1400                  text="   #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1401                  # tname="test_generalTensorProduct_%s_rank%s_%s_rank%s_offset%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1),len(sh_s))
1402                  #tname="test_matrixmult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))
1403                  tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,len(sh0+sh_s),case1,len(sh_s+sh1))
1404                  # if tname=="test_generalTensorProduct_array_rank1_array_rank2_offset1":
1405                  # print tnametest_generalTensorProduct_Symbol_rank1_Symbol_rank3_offset1
1406                  text+="   def %s(self):\n"%tname
1407                  a_0=makeArray(sh0+sh_s,[-1.,1])
1408                  if case0 in ["taggedData", "expandedData"]:
1409                      a1_0=makeArray(sh0+sh_s,[-1.,1])
1410                  else:
1411                      a1_0=a_0
1412    
1413                  a_1=makeArray(sh_s+sh1,[-1.,1])
1414                  if case1 in ["taggedData", "expandedData"]:
1415                      a1_1=makeArray(sh_s+sh1,[-1.,1])
1416                  else:
1417                      a1_1=a_1
1418                  r=oper[1](a_0,a_1,sh_s)
1419                  r1=oper[1](a1_0,a1_1,sh_s)
1420                  text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
1421                  text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
1422                  #text+="      res=matrixmult(arg0,arg1)\n"
1423                  text+="      res=tensormult(arg0,arg1)\n"
1424                  #text+="      res=generalTensorProduct(arg0,arg1,offset=%s)\n"%(len(sh_s))
1425                  if case=="Symbol":
1426                     c0_res,c1_res=case0,case1
1427                     subs="{"
1428                     if case0=="Symbol":        
1429                        text+=mkText("array","s0",a_0,a1_0)
1430                        subs+="arg0:s0"
1431                        c0_res="array"
1432                     if case1=="Symbol":        
1433                        text+=mkText("array","s1",a_1,a1_1)
1434                        if not subs.endswith("{"): subs+=","
1435                        subs+="arg1:s1"
1436                        c1_res="array"
1437                     subs+="}"  
1438                     text+="      sub=res.substitute(%s)\n"%subs
1439                     res="sub"
1440                     text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
1441                  else:
1442                     res="res"
1443                     text+=mkText(case,"ref",r,r1)
1444                  text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
1445                  text+="      self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1446                  if case0 == "taggedData" or case1 == "taggedData":
1447                      t_prog_with_tags+=text
1448                  else:              
1449                      t_prog+=text
1450    print test_header
1451    # print t_prog
1452    print t_prog_with_tags
1453  print test_tail            print test_tail          
1454  1/0  1/0
1455  #=======================================================================================================  #=======================================================================================================

Legend:
Removed from v.291  
changed lines
  Added in v.438

  ViewVC Help
Powered by ViewVC 1.1.26