/[escript]/trunk/escriptcore/test/python/test_util_binary_new.py
ViewVC logotype

Contents of /trunk/escriptcore/test/python/test_util_binary_new.py

Parent Directory | Revision Log


Revision 6545 - (show annotations)
Fri Mar 24 00:21:35 2017 UTC (8 months, 3 weeks ago) by jfenwick
File MIME type: text/x-python
File size: 8421 byte(s)
Fixing the last commit.   (I did test it)
1
2 ##############################################################################
3 #
4 # Copyright (c) 2003-2017 by The University of Queensland
5 # http://www.uq.edu.au
6 #
7 # Primary Business: Queensland, Australia
8 # Licensed under the Apache License, version 2.0
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 # Development 2012-2013 by School of Earth Sciences
13 # Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 #
15 ##############################################################################
16
17 from __future__ import print_function, division
18
__copyright__ = """Copyright (c) 2003-2017 by The University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__ = """Licensed under the Apache License, version 2.0
http://www.apache.org/licenses/LICENSE-2.0"""
__url__ = "https://launchpad.net/escript-finley"

# NOTE: this string comes after other statements, so Python does not bind it
# as the module docstring (``__doc__``); it is in-source documentation only.
"""
Tests for non-overloaded binary operations (inner, outer, minimum, maximum
and the matrix/tensor product family).

:remark: see also `test_util`
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
"""

__author__ = "Joel Fenwick, joelfenwick@uq.edu.au"
39
40 import esys.escriptcore.utestselect as unittest
41 import numpy
42 from esys.escript import *
43 from test_util_base import Test_util_values
44
45
46
47
48
49
50
class Test_util_binary_new(Test_util_values):
    """Tests for non-overloaded binary operations of ``esys.escript``.

    Each ``test_*`` method describes one operation: the escript expression
    under test (``opstring``), a numpy expression written in terms of
    ``refa``/``refb`` that yields the expected (oracle) result
    (``oraclecheck``), and a few options.  The individual test cases are
    produced by the ``generate_binary_*_test_batch_large`` helpers, which are
    not defined in this class (presumably they come from ``Test_util_values``).
    """

    def generate_indices(self, shape):
        """Yield every index tuple of an array with the given ``shape``.

        The first index varies fastest.  A rank-0 shape ``()`` yields the
        single empty index ``()``; a shape with any zero extent yields
        nothing.

        :param shape: tuple of non-negative ints
        """
        rank = len(shape)
        if rank == 0:
            # a rank-0 array has exactly one element, addressed by ()
            yield ()
            return
        if any(extent <= 0 for extent in shape):
            # empty array: there is no valid index at all
            return
        idx = [0] * rank
        while True:
            yield tuple(idx)
            # odometer increment: bump the first digit and carry upwards
            for axis in range(rank):
                idx[axis] += 1
                if idx[axis] < shape[axis]:
                    break
                idx[axis] = 0
            else:
                # every digit wrapped around: all indices have been produced
                return

    def subst_outer(self, a, b):
        """Reference (oracle) outer product of ``a`` and ``b``.

        Used as the ``oraclecheck`` expression of ``test_outer_combined``.
        Python ``float``/``complex`` scalars are wrapped into 1-tuples before
        their shapes are taken with ``getShape``.  If both shapes are ``()``
        the plain product ``a*b`` is returned.  Otherwise the result has shape
        ``a.shape + b.shape`` (of the numpy-converted operands) with
        ``res[ia + ib] == a[ia] * b[ib]``.

        :param a: scalar, (nested) sequence or numpy array
        :param b: scalar, (nested) sequence or numpy array
        :return: numpy array (numpy scalar when both operands are rank 0)
        """
        if isinstance(a, (float, complex)):
            a = (a,)
        if isinstance(b, (float, complex)):
            b = (b,)
        sa = getShape(a)
        sb = getShape(b)
        a = numpy.array(a)
        b = numpy.array(b)
        if sa == () and sb == ():
            return a * b
        # ufunc.outer builds the shape sa+sb product directly (for a rank-0
        # operand it reduces to scalar*array).  It replaces an element-wise
        # ndarray.itemset loop (itemset was removed in NumPy 2.0).  The result
        # dtype follows numpy promotion; the former "a's dtype if complex,
        # else b's dtype" rule truncated float*int products to int.
        return numpy.multiply.outer(a, b)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_inner_combined(self):
        opstring = 'inner(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "numpy.tensordot(refa, refb, axes=getRank(refa))"
        opname = "inner"
        noshapemismatch = True
        permitscalarmismatch = False
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=noshapemismatch,
            permit_scalar_mismatch=permitscalarmismatch)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_outer_combined(self):
        opstring = 'outer(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "self.subst_outer(refa,refb)"
        opname = "outer"
        noshapemismatch = True
        permitscalarmismatch = True
        capcombinedrank = True
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=noshapemismatch,
            permit_scalar_mismatch=permitscalarmismatch,
            cap_combined_rank=capcombinedrank)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_minimum_combined(self):
        opstring = 'minimum(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "numpy.minimum(refa,refb)"
        opname = "minimum"
        noshapemismatch = True
        permitscalarmismatch = True
        cplx = False
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=noshapemismatch,
            permit_scalar_mismatch=permitscalarmismatch,
            support_cplx=cplx)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_maximum_combined(self):
        opstring = 'maximum(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "numpy.maximum(refa,refb)"
        opname = "maximum"
        noshapemismatch = True
        permitscalarmismatch = True
        cplx = False
        self.generate_binary_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname,
            no_shape_mismatch=noshapemismatch,
            permit_scalar_mismatch=permitscalarmismatch,
            support_cplx=cplx)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_mult_combined(self):
        opstring = 'matrix_mult(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "numpy.dot(refa,refb)"
        opname = "matrix_mult"
        aranks = (2,)
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_transpose_matrix_mult_combined(self):
        opstring = 'transposed_matrix_mult(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "numpy.dot(numpy.transpose(refa),refb)"
        opname = "transposed_matrix_mult"
        aranks = (2,)
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_matrix_transposed_mult_combined(self):
        opstring = 'matrix_transposed_mult(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "numpy.dot(refa,numpy.transpose(refb))"
        opname = "matrix_transposed_mult"
        aranks = (2,)
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_tensor_mult_combined(self):
        opstring = 'tensor_mult(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        oraclecheck = "numpy.dot(refa,refb) if getRank(refa)==2 else numpy.tensordot(refa,refb)"
        opname = "tensor_mult"
        aranks = (2, 4)
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    def test_transposed_tensor_mult_combined(self):
        opstring = 'transposed_tensor_mult(a,b)'
        misccheck = None  # TODO: how to work out what the result type should be
        # NOTE(review): the oracle uses the star-imported escript ``transpose``,
        # not ``numpy.transpose``; presumably intentional so its default axis
        # handling matches escript's - confirm before changing.
        oraclecheck = "numpy.dot(transpose(refa),refb) if getRank(refa)==2 else numpy.tensordot(transpose(refa),refb)"
        opname = "transposed_tensor_mult"
        aranks = (2, 4)
        self.generate_binary_matrixlike_operation_test_batch_large(
            opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Disabled in the original source; kept for reference.
    #def test_tensor_transposed_mult_combined(self):
    #    opstring = 'tensor_transposed_mult(a,b)'
    #    misccheck = None  # TODO: how to work out what the result type should be
    #    oraclecheck = "numpy.dot(refa,transpose(refb)) if getRank(refa)==2 else numpy.tensordot(refa,transpose(refb))"
    #    opname = "tensor_transposed_mult"
    #    aranks = (2, 4)
    #    self.generate_binary_matrixlike_operation_test_batch_large(
    #        opstring, misccheck, oraclecheck, opname, aranks=aranks)

  ViewVC Help
Powered by ViewVC 1.1.26