/[escript]/trunk/escriptcore/test/python/test_util_binary_new.py
ViewVC logotype

Contents of /trunk/escriptcore/test/python/test_util_binary_new.py

Parent Directory | Revision Log


Revision 6546 - (show annotations)
Mon Mar 27 00:00:49 2017 UTC (8 months, 3 weeks ago) by jfenwick
File MIME type: text/x-python
File size: 9037 byte(s)
Add test for overloaded +

Also add a parameter to the test generator that prevents testing
  numpy_array op Data
because numpy applies its own interpretation of the operator in that case.

1
2 ##############################################################################
3 #
4 # Copyright (c) 2003-2017 by The University of Queensland
5 # http://www.uq.edu.au
6 #
7 # Primary Business: Queensland, Australia
8 # Licensed under the Apache License, version 2.0
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 # Development 2012-2013 by School of Earth Sciences
13 # Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 #
15 ##############################################################################
16
"""
Tests for binary operations on escript objects: the overloaded ``+``
operator and the utility functions inner, outer, minimum, maximum,
matrix_mult, transposed_matrix_mult, matrix_transposed_mult, tensor_mult
and transposed_tensor_mult.

:remark: see also `test_util`
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
"""

# The docstring above must be the first statement so that it becomes the
# module's __doc__; only a docstring, comments and blank lines may precede
# a __future__ import.
from __future__ import print_function, division

__copyright__="""Copyright (c) 2003-2017 by The University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Apache License, version 2.0
http://www.apache.org/licenses/LICENSE-2.0"""
__url__="https://launchpad.net/escript-finley"

__author__="Joel Fenwick, joelfenwick@uq.edu.au"
39
40 import esys.escriptcore.utestselect as unittest
41 import numpy
42 from esys.escript import *
43 from test_util_base import Test_util_values
44
45
46
47
48
49
50
class Test_util_binary_new(Test_util_values):
    """Tests for binary operations on escript objects.

    Covers the overloaded ``+`` operator and the utility functions inner,
    outer, minimum, maximum and the matrix/tensor product family.  Each test
    describes its operation as Python expression strings and passes them to a
    batch generator inherited from ``Test_util_values`` (imported from
    ``test_util_base``; not shown here).  In those strings ``a``/``b`` appear
    to be the operands handed to escript, ``refa``/``refb`` their reference
    counterparts and ``res`` the escript result -- NOTE(review): inferred from
    the expressions themselves; confirm against the generator.
    """

    def generate_indices(self, shape):
        """Yield every index tuple of an array with the given ``shape``.

        The first axis varies fastest (e.g. shape ``(2, 3)`` yields
        ``(0,0), (1,0), (0,1), (1,1), (0,2), (1,2)``).  A rank-0 shape ``()``
        yields the single empty tuple ``()``; a shape with a zero-length axis
        yields nothing.

        :param shape: sequence of non-negative ints
        """
        # numpy.ndindex varies the *last* axis fastest; iterating the reversed
        # shape and reversing each index gives first-axis-fastest order.
        for idx in numpy.ndindex(*tuple(shape)[::-1]):
            yield idx[::-1]

    def subst_outer(self, a, b):
        """Reference (oracle) implementation of the outer product of a and b.

        Used as the oracle of ``test_outer_combined``.  The result has shape
        ``getShape(a) + getShape(b)`` with
        ``res[ia + ib] == a[ia] * b[ib]``.  If both shapes are empty the plain
        product ``a*b`` is returned.

        :param a: left operand (scalar or array-like)
        :param b: right operand (scalar or array-like)
        :return: numpy array (or numpy scalar when both shapes are empty)
        """
        # Python float/complex scalars (including numpy scalar subclasses of
        # float/complex) are wrapped in a length-1 tuple before their shape
        # is taken -- kept exactly as in the original oracle.
        if isinstance(a, (float, complex)):
            a = (a,)
        if isinstance(b, (float, complex)):
            b = (b,)
        sa = getShape(a)
        sb = getShape(b)
        a = numpy.array(a)
        b = numpy.array(b)
        if sa == () and sb == ():
            return a*b
        # Standard numpy promotion: complex if either side is complex, and
        # no truncation when one side is integer and the other float.
        res = numpy.zeros(sa+sb, dtype=numpy.result_type(a, b))
        # Plain indexing instead of ndarray.itemset, which NumPy 2.0 removed.
        for xa in self.generate_indices(sa):
            for xb in self.generate_indices(sb):
                res[xa+xb] = a[xa]*b[xb]
        return res

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_add_combined(self):
        """Check the overloaded ``a+b`` against numpy's ``refa+refb``.

        ``numpy_array + Data`` combinations are excluded
        (``permit_array_op_data=False``) because numpy applies its own
        interpretation of the operator in that case.
        """
        opstring = 'a+b'
        # Only checks that a Data operand produces a Data result; this does
        # not cover every operand combination.
        misccheck = 'isinstance(res, Data) if isinstance(a, Data) or isinstance(b, Data) else True'
        oraclecheck = "refa+refb"
        opname = "add"
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=True,
            permit_scalar_mismatch=True,
            permit_array_op_data=False)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_inner_combined(self):
        """Check ``inner(a,b)`` against a full contraction over refa's rank."""
        opstring = 'inner(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.tensordot(refa, refb, axes=getRank(refa))"
        opname = "inner"
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=True,
            permit_scalar_mismatch=False)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_outer_combined(self):
        """Check ``outer(a,b)`` against the ``subst_outer`` oracle.

        ``cap_combined_rank=True`` is passed -- presumably to bound the rank
        of the result; TODO confirm against the generator.
        """
        opstring = 'outer(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "self.subst_outer(refa,refb)"
        opname = "outer"
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=True,
            permit_scalar_mismatch=True,
            cap_combined_rank=True)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_minimum_combined(self):
        """Check ``minimum(a,b)`` against ``numpy.minimum`` (real values only;
        complex numbers have no natural ordering)."""
        opstring = 'minimum(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.minimum(refa,refb)"
        opname = "minimum"
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=True,
            permit_scalar_mismatch=True,
            support_cplx=False)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_maximum_combined(self):
        """Check ``maximum(a,b)`` against ``numpy.maximum`` (real values only;
        complex numbers have no natural ordering)."""
        opstring = 'maximum(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.maximum(refa,refb)"
        opname = "maximum"
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=True,
            permit_scalar_mismatch=True,
            support_cplx=False)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_mult_combined(self):
        """Check ``matrix_mult(a,b)`` against ``numpy.dot(refa,refb)``.

        ``aranks=(2,)`` presumably restricts the left operand to rank 2;
        TODO confirm against the generator.
        """
        opstring = 'matrix_mult(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.dot(refa,refb)"
        opname = "matrix_mult"
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=(2,))

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_transpose_matrix_mult_combined(self):
        """Check ``transposed_matrix_mult(a,b)`` against
        ``numpy.dot(numpy.transpose(refa),refb)``."""
        opstring = 'transposed_matrix_mult(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.dot(numpy.transpose(refa),refb)"
        opname = "transposed_matrix_mult"
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=(2,))

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_transposed_mult_combined(self):
        """Check ``matrix_transposed_mult(a,b)`` against
        ``numpy.dot(refa,numpy.transpose(refb))``."""
        opstring = 'matrix_transposed_mult(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.dot(refa,numpy.transpose(refb))"
        opname = "matrix_transposed_mult"
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=(2,))

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_tensor_mult_combined(self):
        """Check ``tensor_mult(a,b)``: the oracle is ``numpy.dot`` when refa
        has rank 2, otherwise ``numpy.tensordot`` (default axes=2)."""
        opstring = 'tensor_mult(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.dot(refa,refb) if getRank(refa)==2 else numpy.tensordot(refa,refb)"
        opname = "tensor_mult"
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=(2, 4))

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_transposed_tensor_mult_combined(self):
        """Check ``transposed_tensor_mult(a,b)``.

        Note: the oracle uses the star-imported escript ``transpose``, not
        ``numpy.transpose``; for rank-4 operands the two may differ
        (NOTE(review): presumably escript's default axis offset is intended
        here -- confirm).
        """
        opstring = 'transposed_tensor_mult(a,b)'
        misccheck = None  # unclear what the result type should be
        oraclecheck = "numpy.dot(transpose(refa),refb) if getRank(refa)==2 else numpy.tensordot(transpose(refa),refb)"
        opname = "transposed_tensor_mult"
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=(2, 4))

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # NOTE: this test is disabled (commented out) in the original source;
    # the reason is not recorded here.
    #def test_tensor_transposed_mult_combined(self):
    #    opstring='tensor_transposed_mult(a,b)'
    #    misccheck=None # unclear what the result type should be
    #    oraclecheck="numpy.dot(refa,transpose(refb)) if getRank(refa)==2 else numpy.tensordot(refa,transpose(refb))"
    #    opname="tensor_transposed_mult"
    #    aranks=(2,4)
    #    self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

  ViewVC Help
Powered by ViewVC 1.1.26