/[escript]/trunk/escriptcore/test/python/test_util_binary_new.py
ViewVC logotype

Annotation of /trunk/escriptcore/test/python/test_util_binary_new.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 6485 - (hide annotations)
Wed Jan 25 04:45:38 2017 UTC (2 years, 2 months ago) by jfenwick
File MIME type: text/x-python
File size: 7265 byte(s)
Some more tests

1 jfenwick 6470
2     ##############################################################################
3     #
4     # Copyright (c) 2003-2016 by The University of Queensland
5     # http://www.uq.edu.au
6     #
7     # Primary Business: Queensland, Australia
8     # Licensed under the Apache License, version 2.0
9     # http://www.apache.org/licenses/LICENSE-2.0
10     #
11     # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12     # Development 2012-2013 by School of Earth Sciences
13     # Development from 2014 by Centre for Geoscience Computing (GeoComp)
14     #
15     ##############################################################################
16    
17     from __future__ import print_function, division
18    
19     __copyright__="""Copyright (c) 2003-2016 by The University of Queensland
20     http://www.uq.edu.au
21     Primary Business: Queensland, Australia"""
22     __license__="""Licensed under the Apache License, version 2.0
23     http://www.apache.org/licenses/LICENSE-2.0"""
24     __url__="https://launchpad.net/escript-finley"
25    
26     """
27     test for non-overloaded binary operations
28    
29     :remark: see `test_util`
30     :var __author__: name of author
31     :var __copyright__: copyrights
32     :var __license__: licence agreement
33     :var __url__: url entry point on documentation
34     :var __version__: version
35     :var __date__: date of the version
36     """
37    
38     __author__="Joel Fenwick, joelfenwick@uq.edu.au"
39    
40     import esys.escriptcore.utestselect as unittest
41     import numpy
42     from esys.escript import *
43     from test_util_base import Test_util_values
44    
45 jfenwick 6480
46    
47    
48    
49    
50    
class Test_util_binary_new(Test_util_values):
    """
    Tests for the non-overloaded binary operations ``inner``, ``outer``,
    ``matrix_mult``, ``transposed_matrix_mult``, ``matrix_transposed_mult``,
    ``tensor_mult`` and ``transposed_tensor_mult``.

    Each test passes an operation string and an "oracle" expression string
    (written in terms of ``refa`` and ``refb``) to a batch generator method
    that is not defined in this class (presumably inherited from
    `Test_util_values` - confirm there how the strings are evaluated).
    """

    def generate_indices(self, shape):
        """
        Yields every index tuple of an object with the given shape.

        The first index varies fastest, e.g. shape ``(2,2)`` produces
        ``(0,0), (1,0), (0,1), (1,1)``.

        :param shape: sequence of non-negative extents, one per dimension
        :return: generator of index tuples of length ``len(shape)``
        :note: an empty shape yields the single empty index ``()``; a shape
               containing a zero extent yields nothing.
        """
        rank = len(shape)
        if rank == 0:
            # a rank-0 object has exactly one entry, addressed by the empty tuple
            yield ()
            return
        if any(extent <= 0 for extent in shape):
            # no valid index exists, so do not emit an out-of-range (0,...,0)
            return
        res = [0]*rank
        while True:
            yield tuple(res)
            res[0] += 1
            # propagate the carry through all digits except the last one
            for i in range(rank-1):
                if res[i] >= shape[i]:
                    res[i] = 0
                    res[i+1] += 1
                else:
                    break
            # now we check the last digit: overflow there means we are done
            if res[rank-1] >= shape[rank-1]:
                return

    def subst_outer(self, a, b):
        """
        Reference implementation of the outer product used as test oracle.

        Python ``float``/``complex`` arguments are wrapped into a one-element
        tuple before their shape is taken. The result is built element by
        element: ``res[xa+xb] = a[xa]*b[xb]``.

        :param a: left argument (float, complex or something accepted by both
                  ``getShape`` and ``numpy.array``)
        :param b: right argument, same kinds as ``a``
        :return: ``a*b`` if both shapes are ``()``, otherwise a numpy array of
                 shape ``getShape(a)+getShape(b)``
        """
        if isinstance(a, (float, complex)):
            a = (a,)
        if isinstance(b, (float, complex)):
            b = (b,)
        sa = getShape(a)
        sb = getShape(b)
        a = numpy.array(a)
        b = numpy.array(b)
        # the result is complex as soon as one of the arguments is complex
        targettype = a.dtype if a.dtype.kind == 'c' else b.dtype
        if sa == ():
            if sb == ():
                return a*b
            res = numpy.zeros(sb, dtype=targettype)
            for xb in self.generate_indices(sb):
                # plain index assignment: ndarray.itemset is gone in NumPy 2.0
                res[xb] = a*b.item(xb)
            return res
        elif sb == ():
            res = numpy.zeros(sa, dtype=targettype)
            for xa in self.generate_indices(sa):
                res[xa] = a.item(xa)*b
            return res
        else:
            res = numpy.zeros(sa+sb, dtype=targettype)
            for xa in self.generate_indices(sa):
                for xb in self.generate_indices(sb):
                    res[xa+xb] = a.item(xa)*b.item(xb)
            return res

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # inner(a,b) checked against a full contraction via numpy.tensordot
    def test_inner_combined(self):
        opstring = 'inner(a,b)'
        misccheck = None   # How to work out what the result of type should be
        oraclecheck = "numpy.tensordot(refa, refb, axes=getRank(refa))"
        opname = "inner"
        noshapemismatch = True
        permitscalarmismatch = False
        self.generate_binary_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, no_shape_mismatch=noshapemismatch, permit_scalar_mismatch=permitscalarmismatch)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # outer(a,b) checked against the element-wise reference subst_outer
    def test_outer_combined(self):
        opstring = 'outer(a,b)'
        misccheck = None   # How to work out what the result of type should be
        oraclecheck = "self.subst_outer(refa,refb)"
        opname = "outer"
        noshapemismatch = True
        permitscalarmismatch = True
        capcombinedrank = True
        self.generate_binary_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, no_shape_mismatch=noshapemismatch, permit_scalar_mismatch=permitscalarmismatch, cap_combined_rank=capcombinedrank)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # matrix_mult(a,b) checked against numpy.dot; left argument of rank 2
    def test_matrix_mult_combined(self):
        opstring = 'matrix_mult(a,b)'
        misccheck = None   # How to work out what the result of type should be
        oraclecheck = "numpy.dot(refa,refb)"
        opname = "matrix_mult"
        aranks = (2,)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # transposed_matrix_mult(a,b) checked against numpy.dot(a^T, b)
    def test_transpose_matrix_mult_combined(self):
        opstring = 'transposed_matrix_mult(a,b)'
        misccheck = None   # How to work out what the result of type should be
        oraclecheck = "numpy.dot(numpy.transpose(refa),refb)"
        opname = "transposed_matrix_mult"
        aranks = (2,)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # matrix_transposed_mult(a,b) checked against numpy.dot(a, b^T)
    def test_matrix_transposed_mult_combined(self):
        opstring = 'matrix_transposed_mult(a,b)'
        misccheck = None   # How to work out what the result of type should be
        oraclecheck = "numpy.dot(refa,numpy.transpose(refb))"
        opname = "matrix_transposed_mult"
        aranks = (2,)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # tensor_mult(a,b): numpy.dot for a rank 2 left argument, otherwise
    # numpy.tensordot with its default axes
    def test_tensor_mult_combined(self):
        opstring = 'tensor_mult(a,b)'
        misccheck = None   # How to work out what the result of type should be
        oraclecheck = "numpy.dot(refa,refb) if getRank(refa)==2 else numpy.tensordot(refa,refb)"
        opname = "tensor_mult"
        aranks = (2,4)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # transposed_tensor_mult(a,b): same oracle as tensor_mult but applied to
    # transpose(refa); note that the unqualified transpose is used here, not
    # numpy.transpose
    def test_transposed_tensor_mult_combined(self):
        opstring = 'transposed_tensor_mult(a,b)'
        misccheck = None   # How to work out what the result of type should be
        oraclecheck = "numpy.dot(transpose(refa),refb) if getRank(refa)==2 else numpy.tensordot(transpose(refa),refb)"
        opname = "transposed_tensor_mult"
        aranks = (2,4)
        self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # The test for tensor_transposed_mult is currently disabled:
    #def test_tensor_transposed_mult_combined(self):
        #opstring='tensor_transposed_mult(a,b)'
        #misccheck=None # How to work out what the result of type should be
        #oraclecheck="numpy.dot(refa,transpose(refb)) if getRank(refa)==2 else numpy.tensordot(refa,transpose(refb))"
        #opname="tensor_transposed_mult"
        #aranks=(2,4)
        #self.generate_binary_matrixlike_operation_test_batch_large(opstring, misccheck, oraclecheck, opname, aranks=aranks)

  ViewVC Help
Powered by ViewVC 1.1.26