/[escript]/trunk/escriptcore/test/python/test_util_binary_new.py
ViewVC logotype

Annotation of /trunk/escriptcore/test/python/test_util_binary_new.py

Parent Directory | Revision Log


Revision 6544 - (hide annotations)
Thu Mar 23 23:59:21 2017 UTC (2 years ago) by jfenwick
File MIME type: text/x-python
File size: 8372 byte(s)
Add two tests
1 jfenwick 6470
2     ##############################################################################
3     #
4 jfenwick 6523 # Copyright (c) 2003-2017 by The University of Queensland
5 jfenwick 6470 # http://www.uq.edu.au
6     #
7     # Primary Business: Queensland, Australia
8     # Licensed under the Apache License, version 2.0
9     # http://www.apache.org/licenses/LICENSE-2.0
10     #
11     # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12     # Development 2012-2013 by School of Earth Sciences
13     # Development from 2014 by Centre for Geoscience Computing (GeoComp)
14     #
15     ##############################################################################
16    
"""
Tests for non-overloaded binary operations.

:remark: see also `test_util`
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
"""

# The docstring above must be the first statement for Python to bind it as
# the module's __doc__; a `from __future__` import may legally follow it.
from __future__ import print_function, division

__copyright__="""Copyright (c) 2003-2017 by The University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Apache License, version 2.0
http://www.apache.org/licenses/LICENSE-2.0"""
__url__="https://launchpad.net/escript-finley"

__author__="Joel Fenwick, joelfenwick@uq.edu.au"
39    
40     import esys.escriptcore.utestselect as unittest
41     import numpy
42     from esys.escript import *
43     from test_util_base import Test_util_values
44    
45 jfenwick 6480
46    
47    
48    
49    
50    
class Test_util_binary_new(Test_util_values):
    """
    Tests for non-overloaded binary operations: inner, outer, minimum,
    maximum and the matrix/tensor product family.

    Each test describes the escript expression under test as a string in
    ``a`` and ``b``, plus a numpy-based oracle expression in ``refa`` and
    ``refb``. Both are handed to a batch generator method that is not
    defined in this class (presumably provided by ``Test_util_values`` --
    confirm there).
    """

    def generate_indices(self, shape):
        """
        Yield every index tuple of an array with the given shape.

        Indices are produced with the FIRST axis varying fastest, e.g. for
        shape ``(2, 3)``: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2).

        A rank-0 shape ``()`` yields the single empty index ``()``. A shape
        containing a zero-length axis yields nothing.

        :param shape: sequence of non-negative ints
        :return: generator of index tuples
        """
        import itertools
        # itertools.product varies its LAST argument fastest, so iterate over
        # the reversed shape and flip each tuple back to get first-axis-fastest.
        for rev in itertools.product(*[range(n) for n in reversed(tuple(shape))]):
            yield rev[::-1]

    def subst_outer(self, a, b):
        """
        Reference (oracle) implementation of ``outer(a, b)``.

        Python ``float``/``complex`` scalars are wrapped in a 1-tuple before
        their shape is taken with ``getShape``, as the original oracle did.

        :param a: first operand (scalar, sequence or numpy array)
        :param b: second operand (scalar, sequence or numpy array)
        :return: ``a*b`` when both shapes are ``()``. Otherwise an array of
                 shape ``sa+sb``, cast to ``a``'s dtype when ``a`` is complex
                 and to ``b``'s dtype otherwise.
        """
        if isinstance(a, (float, complex)):
            a=(a,)
        if isinstance(b, (float, complex)):
            b=(b,)
        sa=getShape(a)
        sb=getShape(b)
        a=numpy.array(a)
        b=numpy.array(b)
        targettype=a.dtype if a.dtype.kind=='c' else b.dtype
        if sa==() and sb==():
            return a*b
        # numpy.multiply.outer produces shape a.shape+b.shape, which reduces to
        # sb (or sa) when the other operand is rank 0. This replaces the old
        # element-by-element ndarray.itemset loop (itemset was removed in
        # NumPy 2.0).
        return numpy.multiply.outer(a, b).astype(targettype)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_inner_combined(self):
        opstring='inner(a,b)'
        misccheck=None # How to work out what the type of the result should be
        # contract over all axes of refa
        oraclecheck="numpy.tensordot(refa, refb, axes=getRank(refa))"
        opname="inner"
        noshapemismatch=True
        permitscalarmismatch=False
        self.generate_binary_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, no_shape_mismatch=noshapemismatch, permit_scalar_mismatch=permitscalarmismatch)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_outer_combined(self):
        opstring='outer(a,b)'
        misccheck=None # How to work out what the type of the result should be
        oraclecheck="self.subst_outer(refa,refb)"
        opname="outer"
        noshapemismatch=True
        permitscalarmismatch=True
        # the outer product's rank is rank(a)+rank(b); cap_combined_rank is
        # presumably what keeps that within supported limits -- confirm in base
        capcombinedrank=True
        self.generate_binary_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, no_shape_mismatch=noshapemismatch, permit_scalar_mismatch=permitscalarmismatch, cap_combined_rank=capcombinedrank)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_minimum_combined(self):
        opstring='minimum(a,b)'
        misccheck=None # How to work out what the type of the result should be
        oraclecheck="numpy.minimum(refa,refb)"
        opname="minimum"
        noshapemismatch=True
        permitscalarmismatch=True
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, no_shape_mismatch=noshapemismatch, permit_scalar_mismatch=permitscalarmismatch)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_maximum_combined(self):
        opstring='maximum(a,b)'
        misccheck=None # How to work out what the type of the result should be
        oraclecheck="numpy.maximum(refa,refb)"
        opname="maximum"
        noshapemismatch=True
        permitscalarmismatch=True
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, no_shape_mismatch=noshapemismatch, permit_scalar_mismatch=permitscalarmismatch)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_mult_combined(self):
        opstring='matrix_mult(a,b)'
        misccheck=None # How to work out what the type of the result should be
        oraclecheck="numpy.dot(refa,refb)"
        opname="matrix_mult"
        aranks=(2,)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_transpose_matrix_mult_combined(self):
        opstring='transposed_matrix_mult(a,b)'
        misccheck=None # How to work out what the type of the result should be
        oraclecheck="numpy.dot(numpy.transpose(refa),refb)"
        opname="transposed_matrix_mult"
        aranks=(2,)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_transposed_mult_combined(self):
        opstring='matrix_transposed_mult(a,b)'
        misccheck=None # How to work out what the type of the result should be
        oraclecheck="numpy.dot(refa,numpy.transpose(refb))"
        opname="matrix_transposed_mult"
        aranks=(2,)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_tensor_mult_combined(self):
        opstring='tensor_mult(a,b)'
        misccheck=None # How to work out what the type of the result should be
        # rank-2 a: ordinary matrix product; rank-4 a: tensordot's default
        # axes=2 contracts the last two axes of a with the first two of b
        oraclecheck="numpy.dot(refa,refb) if getRank(refa)==2 else numpy.tensordot(refa,refb)"
        opname="tensor_mult"
        aranks=(2,4)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_transposed_tensor_mult_combined(self):
        opstring='transposed_tensor_mult(a,b)'
        misccheck=None # How to work out what the type of the result should be
        # NOTE: uses escript's `transpose` (from esys.escript import *), not
        # numpy.transpose; the two are not assumed to agree for rank-4 input
        oraclecheck="numpy.dot(transpose(refa),refb) if getRank(refa)==2 else numpy.tensordot(transpose(refa),refb)"
        opname="transposed_tensor_mult"
        aranks=(2,4)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # NOTE(review): the tensor_transposed_mult test below is disabled; the
    # reason is not recorded in this file.
    #def test_tensor_transposed_mult_combined(self):
    #    opstring='tensor_transposed_mult(a,b)'
    #    misccheck=None # How to work out what the type of the result should be
    #    oraclecheck="numpy.dot(refa,transpose(refb)) if getRank(refa)==2 else numpy.tensordot(refa,transpose(refb))"
    #    opname="tensor_transposed_mult"
    #    aranks=(2,4)
    #    self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

  ViewVC Help
Powered by ViewVC 1.1.26