/[escript]/trunk/escriptcore/test/python/test_util_base.py
ViewVC logotype

Diff of /trunk/escriptcore/test/python/test_util_base.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 6482 by jfenwick, Tue Jan 24 07:45:33 2017 UTC revision 6483 by jfenwick, Wed Jan 25 02:56:14 2017 UTC
# Line 1112  class Test_util_values(unittest.TestCase Line 1112  class Test_util_values(unittest.TestCase
1112          d.tag()          d.tag()
1113          return (d, ref)          return (d, ref)
1114            
1115        def get_array_by_shape(self, s, cplx):
1116            dt=numpy.float64 if not cplx else numpy.complex128
1117            n=numpy.prod(s)
1118            a=numpy.arange(n, dtype=dt).reshape(s)
1119            if cplx:
1120                a-=((1+numpy.arange(n, dtype=dt))*1j).reshape(s)
1121            return a
1122        
1123        def make_constant_from_array(self, a, fs):
1124            d=Data(a, fs)
1125            return d
1126            
1127        def make_tagged_from_array(self, a, fs):
1128            d=Data(1.5, getShape(a), fs)
1129            if a.dtype.kind=='c':
1130                d.promote()
1131            for n in fs.getListOfTags():
1132                d.setTaggedValue(n, a)
1133            return d
1134            
1135        def make_expanded_from_array(self, a, fs):
1136            d=Data(a, fs)
1137            d.expand()
1138            return d
1139        
1140      def execute_ce_params(self, pars):      def execute_ce_params(self, pars):
1141          for v in pars:          for v in pars:
1142              a=v[0]              a=v[0]
# Line 1408  class Test_util_values(unittest.TestCase Line 1433  class Test_util_values(unittest.TestCase
1433                  print(oraclevalue)                                  print(oraclevalue)                
1434              self.assertTrue(oracleres,"wrong result for "+description)              self.assertTrue(oracleres,"wrong result for "+description)
1435    
1436        def generate_binary_matrixlike_operation_test_batch_large(self, opstring, misccheck, oraclecheck, opname, input_trans=None, minrank=0, maxrank=4, no_shape_mismatch=False, permit_scalar_mismatch=True, cap_combined_rank=False, fix_rank_a=None, fix_rank_b=None):
1437            """
1438            Generates a set of tests for binary operations.
1439            It is similar to the unary versions but with some unneeded options removed.
1440            For example, all operations in this type should accept complex arguments.
1441            opstring is a string of the operation to be performed (in terms of arguments a and b) eg "inner(a,b)"
1442            misccheck is a string giving a check to be run after the operation eg "isinstance(res,float)"
1443            opname is a string used to describe the operation being tested eg "inner"
1444            update1 is a string giving code used to update a variable rmerge to  
1445                account for tag additions for tagged data.
1446                eg:             update1="r2.min()"
1447                would result in     rmerge=eval(update1) running after the first tag is calculated
1448            """
1449            if input_trans is None:
1450                input_trans=lambda x: x
1451            pars=[]
1452            for ac in (False, True):    # complex or real arguments
1453                for bc in (False, True):
1454                    astr="real" if ac else "complex"
1455                    bstr="real" if bc else "complex"
1456                    aargset=[]
1457                    bargset=[]
1458                    if fix_rank_a is not None:
1459                        arange=fix_rank_a
1460                    else:
1461                        arange=range(minrank, maxrank+1)
1462                    if fix_rank_b is not None:
1463                        brange=fix_rank_b
1464                    else:
1465                        brange=range(minrank, maxrank+1)                    
1466                    for atype in "ACTE":   # Array/Constant/Tagged/Expanded
1467                        if atype=='A':
1468                            for r in arange:
1469                                aargset.append((self.get_array_input1(r,ac),astr+' array rank '+str(r), r))
1470                        elif atype=='C':
1471                            for r in arange:
1472                                aargset.append((self.get_const_input1(r, self.functionspace, ac), astr+' Constant rank '+str(r), r))
1473                        elif atype=='T':
1474                            for r in arange:
1475                                aargset.append((self.get_tagged_with_tagL1(r, self.functionspace, ac, set_tags=False),astr+' Tagged rank '+str(r), r))
1476                        elif atype=='E':
1477                            for r in arange:
1478                                aargset.append((self.get_expanded_inputL(r,self.functionspace, ac),astr+' Expanded rank '+str(r), r))
1479                    # Now we have a set of a args, match them with possible b's
1480                    for v in aargset:
1481                        arg=v[0][0]
1482                        argref=v[0][1]
1483                        adescr=v[1]
1484                        rank=v[2]
1485                        for br in brange:
1486                            tshape=(r,)*br
1487                            bargref=self.get_array_by_shape(tshape, bc)
1488                            
1489                            # now convert it to each possbile input type
1490                            barg=self.make_constant_from_array(bargref, self.functionspace)
1491                            bdescr=bstr+' Constant rank '+str(br)
1492                            p=(arg, barg, opstring, misccheck,
1493                               numpy.array(argref), numpy.array(bargref),
1494                               oraclecheck, opname+' '+adescr+'/'+bdescr)
1495                            pars.append(p)      
1496                            barg=self.make_tagged_from_array(bargref, self.functionspace)
1497                            bdescr=bstr+' Tagged rank '+str(br)
1498                            p=(arg, barg, opstring, misccheck,
1499                               numpy.array(argref), numpy.array(bargref),
1500                               oraclecheck, opname+' '+adescr+'/'+bdescr)
1501                            pars.append(p)          
1502                            barg=self.make_expanded_from_array(bargref, self.functionspace)
1503                            bdescr=bstr+' Expanded rank '+str(br)
1504                            p=(arg, barg, opstring, misccheck,
1505                               numpy.array(argref), numpy.array(bargref),
1506                               oraclecheck, opname+' '+adescr+'/'+bdescr)
1507                            pars.append(p)          
1508            self.execute_binary_params(pars)
1509    
1510      def generate_binary_operation_test_batch_large(self, opstring, misccheck, oraclecheck, opname, input_trans=None, minrank=0, maxrank=4, no_shape_mismatch=False, permit_scalar_mismatch=True, cap_combined_rank=False, fix_rank_a=None, fix_rank_b=None):      def generate_binary_operation_test_batch_large(self, opstring, misccheck, oraclecheck, opname, input_trans=None, minrank=0, maxrank=4, no_shape_mismatch=False, permit_scalar_mismatch=True, cap_combined_rank=False, fix_rank_a=None, fix_rank_b=None):
1511          """          """
# Line 1426  class Test_util_values(unittest.TestCase Line 1523  class Test_util_values(unittest.TestCase
1523          if input_trans is None:          if input_trans is None:
1524              input_trans=lambda x: x              input_trans=lambda x: x
1525          pars=[]          pars=[]
1526          for ac in (False, True):          for ac in (False, True):    # complex or real arguments
1527              for bc in (False, True):              for bc in (False, True):
1528                  astr="real" if ac else "complex"                  astr="real" if ac else "complex"
1529                  bstr="real" if bc else "complex"                  bstr="real" if bc else "complex"
1530                  aargset=[]                  aargset=[]
1531                  bargset=[]                  bargset=[]
1532                    if fix_rank_a is not None:
1533                        arange=fix_rank_a
1534                    else:
1535                        arange=range(minrank, maxrank+1)
1536                    if fix_rank_b is not None:
1537                        brange=fix_rank_b
1538                    else:
1539                        brange=range(minrank, maxrank+1)                    
1540                  for atype in "SACTE":   # Scalar/Array/Constant/Tagged/Expanded                  for atype in "SACTE":   # Scalar/Array/Constant/Tagged/Expanded
1541                      if atype=='S':                      if atype=='S':
1542                          aargset.append((self.get_scalar_input1(ac),astr+' scalar'))                          aargset.append((self.get_scalar_input1(ac),astr+' scalar'))
1543                      elif atype=='A':                      elif atype=='A':
1544                          for r in range(minrank, maxrank+1):                          for r in arange:
1545                              aargset.append((self.get_array_input1(r,ac),astr+' array rank '+str(r)))                              aargset.append((self.get_array_input1(r,ac),astr+' array rank '+str(r)))
1546                      elif atype=='C':                      elif atype=='C':
1547                          for r in range(minrank, maxrank+1):                          for r in arange:
1548                              aargset.append((self.get_const_input1(r, self.functionspace, ac), astr+' Constant rank '+str(r)))                              aargset.append((self.get_const_input1(r, self.functionspace, ac), astr+' Constant rank '+str(r)))
1549                      elif atype=='T':                      elif atype=='T':
1550                          for r in range(minrank, maxrank+1):                          for r in arange:
1551                              aargset.append((self.get_tagged_with_tagL1(r, self.functionspace, ac, set_tags=False),astr+' Tagged rank '+str(r)))                              aargset.append((self.get_tagged_with_tagL1(r, self.functionspace, ac, set_tags=False),astr+' Tagged rank '+str(r)))
1552                      elif atype=='E':                      elif atype=='E':
1553                          for r in range(minrank, maxrank+1):                          for r in arange:
1554                              aargset.append((self.get_expanded_inputL(r,self.functionspace, ac),astr+' Expanded rank '+str(r)))                              aargset.append((self.get_expanded_inputL(r,self.functionspace, ac),astr+' Expanded rank '+str(r)))    
1555                  for atype in "SACTE":   # Scalar/Array/Constant/Tagged/Expanded                  for atype in "SACTE":   # Scalar/Array/Constant/Tagged/Expanded
1556                      if atype=='S':                      if atype=='S':
1557                          bargset.append((self.get_scalar_input2(ac),bstr+' scalar'))                          bargset.append((self.get_scalar_input2(ac),bstr+' scalar'))
1558                      elif atype=='A':                      elif atype=='A':
1559                          for r in range(minrank, maxrank+1):                          for r in brange:
1560                              bargset.append((self.get_array_input2(r,ac),bstr+' array rank '+str(r)))                              bargset.append((self.get_array_input2(r,ac),bstr+' array rank '+str(r)))
1561                      elif atype=='C':                      elif atype=='C':
1562                          for r in range(minrank, maxrank+1):                          for r in brange:
1563                              bargset.append((self.get_const_input2(r, self.functionspace, ac),bstr+' Constant rank '+str(r)))                              bargset.append((self.get_const_input2(r, self.functionspace, ac),bstr+' Constant rank '+str(r)))
1564                      elif atype=='T':                      elif atype=='T':
1565                          for r in range(minrank, maxrank+1):                          for r in brange:
1566                              bargset.append((self.get_tagged_with_tagL2(r, self.functionspace, ac, set_tags=True), bstr+' Tagged rank '+str(r)))                              bargset.append((self.get_tagged_with_tagL2(r, self.functionspace, ac, set_tags=True), bstr+' Tagged rank '+str(r)))
1567                      elif atype=='E':                      elif atype=='E':
1568                          for r in range(minrank, maxrank+1):                          for r in brange:
1569                              bargset.append((self.get_expanded_inputL2(r, self.functionspace, ac),bstr+' Expanded rank '+str(r)))                              bargset.append((self.get_expanded_inputL2(r, self.functionspace, ac),bstr+' Expanded rank '+str(r)))
1570                  # now we have a complete set of possible args                      # now we have a complete set of possible args    
1571                  for aarg in aargset:                  for aarg in aargset:
1572                      for barg in bargset:                      for barg in bargset:
                         if fix_rank_a is not None and getRank(aarg[0][0]) not in fix_rank_a:  
                             continue  
                         if fix_rank_b is not None and getRank(barg[0][0]) not in fix_rank_b:  
                             continue  
1573                          if cap_combined_rank and getRank(aarg[0][0])+getRank(barg[0][0])>4:                          if cap_combined_rank and getRank(aarg[0][0])+getRank(barg[0][0])>4:
1574                              continue  #resulting object too big                              continue  #resulting object too big
1575                          if no_shape_mismatch:                          if no_shape_mismatch:

Legend:
Removed from v.6482  
changed lines
  Added in v.6483

  ViewVC Help
Powered by ViewVC 1.1.26