/[escript]/trunk/ripley/generators/lamebuilder.py
ViewVC logotype

Contents of /trunk/ripley/generators/lamebuilder.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5707 - (show annotations)
Mon Jun 29 03:59:06 2015 UTC (3 years, 8 months ago) by sshaw
File MIME type: text/x-python
File size: 13299 byte(s)
adding copyright headers to files without copyright info, moved header to top of file in some cases where it wasn't
1
2 ##############################################################################
3 #
4 # Copyright (c) 2003-2015 by The University of Queensland
5 # http://www.uq.edu.au
6 #
7 # Primary Business: Queensland, Australia
8 # Licensed under the Open Software License version 3.0
9 # http://www.opensource.org/licenses/osl-3.0.php
10 #
11 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 # Development 2012-2013 by School of Earth Sciences
13 # Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 #
15 ##############################################################################
16
17 from __future__ import print_function, division
18
19 import lamesource
20 import sys
21
def buildTempAndSummation(dim, ids, temps, summations, forced_substitutions=None):
    """Build temporary declarations and summation statements per (k, m).

    For every index pair ``(k, m)`` with ``0 <= k, m < dim`` each template in
    *temps* and *summations* is instantiated with ``template.format(k, m)``.
    A temporary counts as possibly non-zero when its declaration text
    contains one of the entries of *ids*, or when its right-hand side
    contains an earlier non-zero temporary; otherwise it counts as always
    zero.  Zero temporaries are stripped from the summation lines,
    temporaries used exactly once are replaced by their defining expression,
    and only temporaries used two or more times keep a declaration.

    Args:
        dim: number of values taken by each of ``k`` and ``m``.
        ids: iterable of identifier strings (a dict is iterated by key);
            each is tested with ``in`` against the declaration text.
        temps: re-iterable sequence of declaration templates.  Each must
            contain ``"tmp"`` and ``"= "``; the final character of the
            instantiated line is dropped from the stored expression.
        summations: re-iterable sequence of summation templates containing
            ``"+="``; each is passed through ``replaceZeroes``.
        forced_substitutions: optional names whose use count is forced to 1
            (so they are substituted inline) when they are non-zero.
            ``None`` means no forced substitutions.

    Returns:
        A pair ``(declarations, sumStatements)`` of dicts keyed by ``(k, m)``,
        each mapping to a list of generated source lines.
    """
    # A mutable default list would be shared between calls; use None instead.
    if forced_substitutions is None:
        forced_substitutions = []
    declarations = {}
    sumStatements = {}
    for k in range(dim):
        for m in range(dim):
            declarations[(k, m)] = []
            sumStatements[(k, m)] = []
            zeroes = []  # tmpvars that will always be zero
            nonzeroes = []  # tmpvars that will possibly be non-zero
            use_counts = {}  # tmp var usage counts, used later for culling
            var_expressions = {}  # defining expression of every non-zero tmpvar

            # --- temp declarations ---
            # NOTE(review): all "in" tests below are plain substring checks,
            # so a name such as "tmp1" also matches inside "tmp11"; this
            # presumably relies on the naming used by the templates.
            for line in temps:
                newline = line.format(k, m)
                # Iterating the container directly works on Python 2 and 3
                # (dict.iterkeys() no longer exists on Python 3).
                insert = any(index in newline for index in ids)
                if not insert:
                    # Not built from a known id: it is still non-zero if its
                    # right-hand side uses an earlier non-zero temporary.
                    value = newline.split("=")[1]
                    for tmp in nonzeroes:
                        if tmp in value:
                            use_counts[tmp] = use_counts.get(tmp, 0) + 1
                            insert = True
                i = newline.index("tmp")
                identifier = newline[i:].split()[0]
                expression = newline.split("= ")[1][:-1]
                if insert:
                    nonzeroes.append(identifier)
                    var_expressions[identifier] = expression
                else:
                    zeroes.append(identifier)

            # --- summations ---
            with_nonzero = []  # only those with some non-zero temp var
            for line in summations:
                newline = replaceZeroes(line.format(k, m), zeroes)
                for nz in nonzeroes:
                    if nz in newline:
                        with_nonzero.append(newline)
                        # Split the right-hand side into its terms and count
                        # one use for each of them.
                        components = newline[:-1].split("+=")[1].lstrip().replace(" + ", "|").replace(" - ", "|")
                        if components[0] == "-":
                            components = components[1:]
                        for var in components.split("|"):
                            use_counts[var] = use_counts.get(var, 0) + 1
                        # the line is recorded once, whichever tmpvar matched
                        break

            for fs in forced_substitutions:
                if fs in nonzeroes:
                    use_counts[fs] = 1
            for z in zeroes:
                use_counts[z] = 0

            # only interested in keeping variable declarations with 2 or more
            # uses: 0 we don't care about and 1 we substitute the expression
            for var in nonzeroes:
                if use_counts.get(var, 0) > 1:
                    declarations[(k, m)].append(" const double %s = %s;" % (var, var_expressions[var]))

            # remove zeroes and replace single use tmpvars with their expression
            for line in with_nonzero:
                for key, count in use_counts.items():
                    if count == 0:
                        line = line.replace("%s;" % key, ";")
                        line = line.replace("%s " % key, " ")
                    elif count < 2:
                        s = var_expressions[key]
                        line = line.replace("%s;" % key, s + ";")
                        line = line.replace("%s " % key, s + " ")
                # keep only if there's a right-hand side of the expression
                # (the original tested this same literal twice)
                if "+=;" not in line:
                    sumStatements[(k, m)].append(" " + line)
    return declarations, sumStatements
96
def replaceZeroes(line, zeroes):
    """Remove always-zero temporaries from a summation line.

    Temporaries in *line* are expected to be followed by a ``|`` marker.
    Each name in *zeroes* that occurs as ``name|`` is either dropped together
    with its leading operator or, when it is the first operand, replaced by
    a literal ``0`` so that the sign of the following term is not lost.
    Afterwards all ``|`` markers and the left-over zero terms are removed.

    Args:
        line: the summation line to clean up.
        zeroes: names of temporaries known to be zero.

    Returns:
        The cleaned line; an empty *line* is returned unchanged.
    """
    # (pattern template, replacement) pairs applied per zero temporary, in order
    per_zero_rules = (
        (" + %s|", ""),
        ("%s| + ", "0 + "),  # keep a placeholder 0 in front of the next term
        (" - %s|", ""),
        ("%s| - ", "0 - "),  # stops  a - b  turning into  b  instead of  -b
        ("=%s|;", "=;"),
    )
    # final clean-up of markers and placeholder zeroes, in order
    cleanup_rules = (
        ("|", ""),
        (" - 0", ""),
        (" + 0", ""),
        ("+= 0 +", "+="),
        ("+= 0 - ", "+=-"),
    )

    for name in zeroes:
        if name + "|" not in line:
            continue
        for template, substitute in per_zero_rules:
            line = line.replace(template % name, substitute)

    if not line:
        return line
    for pattern, substitute in cleanup_rules:
        line = line.replace(pattern, substitute)
    return line
108
def print2DAExpanded():
    """Print the generated code for the 2D expanded case to standard output.

    The output consists of:
      * one ``double A_<id>[quads]`` array declaration per identifier,
      * a block adding ``mu_p[q]`` contributions and a block adding
        ``lambda_p[q]`` contributions for each of the ``quads`` entries,
      * one braced block per ``(k, m)`` pair holding the declarations and
        summation statements returned by ``buildTempAndSummation`` for the
        ``lamesource.expanded2Dtemps`` / ``expanded2Dsummations`` templates.

    Returns:
        None.
    """
    dim = 2
    quads = 2**dim
    # Identifiers iijj, ijji and ijij for all i, j; only the keys matter.
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    # sorted(ids) iterates the keys on both Python 2 and Python 3
    for name in sorted(ids):
        print("double A_{0}[{1}] =".format(name, quads), "{0};")
    # ijji += mu
    # ijij += mu
    # iijj += lambda
    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                if i == j:
                    # ijji and ijij coincide here, hence the factor 2
                    print(" A_{0}{0}{0}{0}[{1}] += 2*mu_p[{1}];".format(i, q))
                else:
                    print(" A_{0}{1}{1}{0}[{2}] += mu_p[{2}];".format(i, j, q))
                    print(" A_{0}{1}{0}{1}[{2}] += mu_p[{2}];".format(i, j, q))
    print("}\nif (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                print(" A_{0}{0}{1}{1}[{2}] += lambda_p[{2}];".format(i, j, q))
    print("}")

    decl, sums = buildTempAndSummation(dim, ids, lamesource.expanded2Dtemps, lamesource.expanded2Dsummations)
    for k in range(dim):
        for m in range(dim):
            print("{")
            print("\n".join(decl[(k, m)]))
            print("\n".join(sums[(k, m)]))
            print("}")
148
def print2DAReduced():
    """Print the generated code for the 2D reduced case to standard output.

    The output consists of:
      * one scalar ``double A_<id>`` declaration per identifier,
      * a block adding ``mu_p[0]`` contributions and a block adding
        ``lambda_p[0]`` contributions,
      * the temporary declarations (built from a fixed list of templates)
        that reference a known identifier,
      * the ``lamesource.reduced2Dsummations`` lines with zero temporaries
        removed by ``replaceZeroes``.

    Returns:
        None.
    """
    dim = 2
    quads = 2**dim
    # Identifiers iijj, ijji and ijij for all i, j; only the keys matter.
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    # sorted(ids) iterates the keys on both Python 2 and Python 3
    for name in sorted(ids):
        print("double A_{0} =".format(name), "0;")
    # ijji += mu
    # ijij += mu
    # iijj += lambda

    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")

    for i in range(dim):
        for j in range(dim):
            if i == j:
                # ijji and ijij coincide here, hence the factor 2
                print(" A_{0}{0}{0}{0} += 2*mu_p[0];".format(i))
            else:
                print(" A_{0}{1}{1}{0} += mu_p[0];".format(i, j))
                print(" A_{0}{1}{0}{1} += mu_p[0];".format(i, j))
    print("}")
    print("if (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            print(" A_{0}{0}{1}{1} += lambda_p[0];".format(i, j))
    print("}")

    lines = [
        "const double tmp_0 = 6*w1*(A_{0}0{1}1 - A_{0}1{1}0);",
        "const double tmp_1 = 6*w1*(A_{0}0{1}1 + A_{0}1{1}0);",
        "const double tmp_2 = 6*w1*(-A_{0}0{1}1 - A_{0}1{1}0);",
        "const double tmp_3 = 6*w1*(-A_{0}0{1}1 + A_{0}1{1}0);"
    ]

    # Names of temporaries that never receive a non-zero value.  (The
    # original also collected the non-zero names but never read them.)
    zeroes = []
    for k in range(dim):
        for m in range(dim):
            if k == m:
                # no temporaries are emitted on the diagonal
                for q in range(quads):
                    zeroes.append("tmp{0}{0}_{1}".format(m, q))
                continue
            for line in lines:
                newline = line.format(k, m)
                i = newline.index("tmp")
                identifier = newline[i:].split()[0].rstrip()
                if any(index in newline for index in ids):
                    print(newline)
                else:
                    zeroes.append(identifier)

    for k in range(dim):
        for m in range(dim):
            for line in lamesource.reduced2Dsummations:
                newline = replaceZeroes(line.format(k, m), zeroes)
                if not newline:
                    continue
                if k != m:
                    # Off the diagonal A_k0m0 and A_k1m1 are not among the
                    # declared identifiers, so the text around them is cut
                    # out.  NOTE(review): the fixed offsets (-3/+12 and
                    # -2/+11) presumably match the exact layout of the
                    # lamesource templates; confirm before changing them.
                    st = "A_{0}0{1}0".format(k, m)
                    if st in newline:
                        i = newline.index(st)
                        newline = newline[:i-3] + newline[i+12:]
                    st = "A_{0}1{1}1".format(k, m)
                    if st in newline:
                        i = newline.index(st)
                        newline = newline[:i-2] + newline[i+11:]
                print(newline)
226
def print3DAExpanded():
    """Print the generated code for the 3D expanded case to standard output.

    The output consists of:
      * one ``double A_<id>[quads]`` array declaration per identifier,
      * a block adding ``mu_p[q]`` contributions and a block adding
        ``lambda_p[q]`` contributions for each of the ``quads`` entries,
      * one braced block per ``(k, m)`` pair holding the declarations and
        summation statements returned by ``buildTempAndSummation`` for the
        ``lamesource.expanded3Dtemps`` / ``expanded3Dsummations`` templates.

    Returns:
        None.
    """
    dim = 3
    quads = 2**dim
    # Identifiers iijj, ijji and ijij for all i, j; only the keys matter.
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    # sorted(ids) iterates the keys on both Python 2 and Python 3
    for name in sorted(ids):
        print("double A_{0}[{1}] =".format(name, quads), "{0};")
    # ijji += mu
    # ijij += mu
    # iijj += lambda
    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                if i == j:
                    # ijji and ijij coincide here, hence the factor 2
                    print(" A_{0}{0}{0}{0}[{1}] += 2*mu_p[{1}];".format(i, q))
                else:
                    print(" A_{0}{1}{1}{0}[{2}] += mu_p[{2}];".format(i, j, q))
                    print(" A_{0}{1}{0}{1}[{2}] += mu_p[{2}];".format(i, j, q))
    print("}")
    print("if (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                print(" A_{0}{0}{1}{1}[{2}] += lambda_p[{2}];".format(i, j, q))
    print("}")

    decl, sums = buildTempAndSummation(dim, ids, lamesource.expanded3Dtemps, lamesource.expanded3Dsummations)
    for k in range(dim):
        for m in range(dim):
            print("{")
            print("\n".join(decl[(k, m)]))
            print("\n".join(sums[(k, m)]))
            print("}")
266
def print3DAReduced():
    """Print the generated code for the 3D reduced case to standard output.

    The output consists of:
      * one scalar ``double Aw<id>`` declaration per identifier,
      * a block adding ``mu_p[0]`` contributions and a block adding
        ``lambda_p[0]`` contributions,
      * the ``*=`` scaling statements that reference a known identifier,
      * one braced block per ``(k, m)`` pair holding the declarations and
        summation statements returned by ``buildTempAndSummation`` for the
        ``lamesource.reduced3Dtemps`` / ``reduced3Dsummations`` templates.

    Returns:
        None.
    """
    dim = 3
    # Identifiers iijj, ijji and ijij for all i, j; only the keys matter.
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    # sorted(ids) iterates the keys on both Python 2 and Python 3
    for name in sorted(ids):
        print("double Aw%s = 0;" % name)

    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            if i == j:
                # ijji and ijij coincide here, hence the factor 2
                print(" Aw{0}{0}{0}{0} += 2*mu_p[0];".format(i))
            else:
                print(" Aw{0}{1}{1}{0} += mu_p[0];".format(i, j))
                print(" Aw{0}{1}{0}{1} += mu_p[0];".format(i, j))

    print("}\nif (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            print(" Aw{0}{0}{1}{1} += lambda_p[0];".format(i, j))
    print("}")

    scalings = ["Aw{0}0{1}0 *= 8*w27;", "Aw{0}0{1}1 *= 12*w8;", "Aw{0}0{1}2 *= 12*w11;",
                "Aw{0}1{1}0 *= 12*w8;", "Aw{0}1{1}1 *= 8*w22;", "Aw{0}1{1}2 *= 12*w10;",
                "Aw{0}2{1}0 *= 12*w11;", "Aw{0}2{1}1 *= 12*w10;", "Aw{0}2{1}2 *= 8*w13;"]
    for k in range(dim):
        for m in range(dim):
            for line in scalings:
                newline = line.format(k, m)
                # only scale variables that were actually declared above
                if any(ident in newline for ident in ids):
                    print(newline)

    decs, sums = buildTempAndSummation(dim, ids, lamesource.reduced3Dtemps,
            lamesource.reduced3Dsummations,
            forced_substitutions = ["tmp12","tmp13","tmp14","tmp21","tmp22","tmp23"])
    for k in range(dim):
        for m in range(dim):
            print("{")
            print("\n".join(decs[(k, m)]))
            print("\n".join(sums[(k, m)]))
            print("}")
317
318
319
def printAReduced(dim):
    """Print the reduced-case generated code for the given dimension.

    Args:
        dim: spatial dimension; must be 2 or 3.

    Returns:
        The result of ``print2DAReduced()`` or ``print3DAReduced()``.

    Raises:
        ValueError: if *dim* is neither 2 nor 3.  (The original used a bare
            ``raise`` with no active exception, which only produced an
            unrelated interpreter error.)
    """
    if dim == 2:
        return print2DAReduced()
    if dim == 3:
        return print3DAReduced()
    raise ValueError("unsupported dimension {0!r}: expected 2 or 3".format(dim))
327
def printAExpanded(dim):
    """Print the expanded-case generated code for the given dimension.

    Args:
        dim: spatial dimension; must be 2 or 3.

    Returns:
        The result of ``print2DAExpanded()`` or ``print3DAExpanded()``.

    Raises:
        ValueError: if *dim* is neither 2 nor 3.  (The original used a bare
            ``raise`` with no active exception, which only produced an
            unrelated interpreter error.)
    """
    if dim == 2:
        return print2DAExpanded()
    if dim == 3:
        return print3DAExpanded()
    raise ValueError("unsupported dimension {0!r}: expected 2 or 3".format(dim))
335
if __name__ == "__main__":
    # Command line: <script> R|E 2|3
    #   R/E selects the reduced or expanded generator, 2/3 the dimension.
    if len(sys.argv) < 3 or sys.argv[1] not in ["R", "E"] or sys.argv[2] not in ["2","3"]:
        # The usage text has {0} placeholders for the script name; the
        # original printed them unformatted.
        print("Usage: {0} [R]educed/[E]xpanded dimensions\nE.g. {0} R 3".format(sys.argv[0]))
        # sys.exit is always available; the bare exit() helper is only
        # installed by the site module.
        sys.exit(1)
    dim = int(sys.argv[2])
    if sys.argv[1] == "R":
        printAReduced(dim)
    else:
        printAExpanded(dim)
345

  ViewVC Help
Powered by ViewVC 1.1.26