/[escript]/trunk/ripley/generators/lamebuilder.py
ViewVC logotype

Contents of /trunk/ripley/generators/lamebuilder.py

Parent Directory | Revision Log


Revision 4799 - (show annotations)
Wed Mar 26 00:37:59 2014 UTC (5 years, 10 months ago) by sshaw
File MIME type: text/x-python
File size: 12694 byte(s)
storing the code generators for the relevant chunks of the lame assembler
1 from __future__ import print_function
2
3 import lamesource
4 import sys
5
def buildTempAndSummation(dim, ids, temps, summations, forced_substitutions=()):
    """Build per-(k, m) temporary declarations and summation statements.

    For every component pair ``(k, m)`` with ``0 <= k, m < dim``, each
    template in *temps* and *summations* is formatted with ``(k, m)``, and then:

    * each temporary (the first token starting with ``tmp``) is classed as
      possibly non-zero if its text mentions one of the identifiers in *ids*,
      or if its right-hand side mentions an already non-zero temporary;
      otherwise it is recorded as always zero;
    * always-zero temporaries are stripped from the summations via
      ``replaceZeroes``, and only summations that still mention a non-zero
      temporary are kept;
    * temporaries used two or more times are declared, single-use ones have
      their expression substituted directly into the summations, and unused
      ones are dropped.

    Args:
        dim: number of dimensions; (k, m) ranges over ``range(dim)`` squared.
        ids: container of identifier strings (e.g. ``"0011"``); only
            membership/iteration over its keys is used.
        temps: declaration templates. Each must contain ``tmp`` and ``"= "``;
            its last character (presumably ``;``) is dropped from the stored
            expression.
        summations: ``+=`` statement templates whose operands carry a
            trailing ``|`` marker, as expected by ``replaceZeroes``.
        forced_substitutions: temporaries that, when non-zero, are always
            inlined (use count forced to 1) instead of declared.

    Returns:
        ``(declarations, sumStatements)``: two dicts keyed by ``(k, m)``, each
        holding a list of generated source lines.
    """
    declarations = {}
    sumStatements = {}
    for k in range(dim):
        for m in range(dim):
            declarations[(k, m)] = []
            sumStatements[(k, m)] = []
            zeroes = []  # tmp vars that will always be zero
            nonzeroes = []  # tmp vars that may be non-zero
            use_counts = {}  # tmp var usage counts, used later for culling
            var_expressions = {}  # defining expression of every non-zero tmp

            # temp declarations
            for line in temps:
                newline = line.format(k, m)
                insert = any(index in newline for index in ids)
                if not insert:
                    # a temp built from other non-zero temps is non-zero too;
                    # every such reference counts as a use of that temp
                    value = newline.split("=")[1]
                    for tmp in nonzeroes:
                        if tmp in value:
                            use_counts[tmp] = use_counts.get(tmp, 0) + 1
                            insert = True
                i = newline.index("tmp")
                identifier = newline[i:].split()[0]
                expression = newline.split("= ")[1][:-1]
                if insert:
                    nonzeroes.append(identifier)
                    var_expressions[identifier] = expression
                else:
                    zeroes.append(identifier)

            # summations: keep only those with some non-zero temp var
            with_nonzero = []
            for line in summations:
                newline = replaceZeroes(line.format(k, m), zeroes)
                if not any(nz in newline for nz in nonzeroes):
                    continue
                with_nonzero.append(newline)
                # operands of the right-hand side (trailing ';' dropped)
                components = newline[:-1].split("+=")[1].lstrip()
                components = components.replace(" + ", "|").replace(" - ", "|")
                if components.startswith("-"):
                    components = components[1:]
                for var in components.split("|"):
                    use_counts[var] = use_counts.get(var, 0) + 1

            for fs in forced_substitutions:
                if fs in nonzeroes:
                    use_counts[fs] = 1
            for z in zeroes:
                use_counts[z] = 0

            # only declare temps with 2 or more uses; a single use is inlined
            # below and zero uses are dropped.
            # NOTE(review): a temp referenced by exactly one other temp (and no
            # summation) gets count 1 and is neither declared nor substituted
            # into declarations - confirm lamesource never produces that case.
            for var in nonzeroes:
                if use_counts.get(var, 0) > 1:
                    declarations[(k, m)].append(
                        " const double %s = %s;" % (var, var_expressions[var]))

            # remove zeroes and replace single-use tmp vars with their expression
            for line in with_nonzero:
                for key, count in list(use_counts.items()):
                    if count == 0:
                        line = line.replace("%s;" % key, ";")
                        line = line.replace("%s " % key, " ")
                    elif count < 2:
                        s = var_expressions[key]
                        line = line.replace("%s;" % key, s + ";")
                        line = line.replace("%s " % key, s + " ")
                # emit only if a right-hand side is left; spaces are ignored
                # so both "+=;" and "+= ;" count as empty
                if "+=;" not in line.replace(" ", ""):
                    sumStatements[(k, m)].append(" " + line)
    return declarations, sumStatements
80
def replaceZeroes(line, zeroes):
    """Remove always-zero operands from a ``|``-marked statement.

    Operands in *line* carry a trailing ``|`` marker (e.g. ``tmp3|``). For each
    name in *zeroes* that occurs in marked form, the operand is dropped along
    with its operator. A zero in leading position is first rewritten as a
    literal ``0`` so that ``a - b`` with ``a`` zero yields ``-b`` rather than
    ``b``; that ``0`` is cleaned up at the end. Finally all ``|`` markers are
    removed.

    Args:
        line: formatted statement text (may be empty).
        zeroes: operand names known to be zero.

    Returns:
        The cleaned statement. It may be left with an empty right-hand side
        (e.g. ``"x+=;"``) when its only operand was zero.
    """
    per_zero_rules = (
        (" + %s|", ""),
        ("%s| + ", "0 + "),
        (" - %s|", ""),
        ("%s| - ", "0 - "),  # keeps the sign: zero a in a - b gives -b
        ("=%s|;", "=;"),
    )
    for zero in zeroes:
        if zero + "|" not in line:
            continue
        for pattern, replacement in per_zero_rules:
            line = line.replace(pattern % zero, replacement)

    # NOTE(review): " - 0" / " + 0" also match numeric literals starting with
    # 0 (e.g. " + 0.5"); confirm the templates never contain such text.
    final_cleanup = (
        ("|", ""),
        (" - 0", ""),
        (" + 0", ""),
        ("+= 0 +", "+="),
        ("+= 0 - ", "+=-"),
    )
    for old, new in final_cleanup:
        line = line.replace(old, new)
    return line
92
def print2DAExpanded():
    """Print generated C++ code for the 2D case with expanded (per-point) data.

    Emits:
      * one zero-initialised array ``A_<ijkl>[2**dim]`` per identifier;
      * accumulation of ``mu_p`` into the ``iiii``, ``ijji`` and ``ijij``
        entries (``2*mu_p`` on the diagonal ``iiii``, which matches both
        patterns) and of ``lambda_p`` into the ``iijj`` entries;
      * for every (k, m) pair, a braced block with the temporaries and
        summations produced by ``buildTempAndSummation`` from
        ``lamesource.expanded2Dtemps`` / ``expanded2Dsummations``.
    """
    dim = 2
    quads = 2**dim  # number of values per coefficient array
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    # the second argument is printed verbatim: C initialiser "{0};"
    for name in sorted(ids):
        print("double A_{0}[{1}] =".format(name, quads), "{0};")
    # ijji += mu
    # ijij += mu
    # iijj += lambda
    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                if i == j:
                    print(" A_{0}{0}{0}{0}[{1}] += 2*mu_p[{1}];".format(i, q))
                else:
                    print(" A_{0}{1}{1}{0}[{2}] += mu_p[{2}];".format(i, j, q))
                    print(" A_{0}{1}{0}{1}[{2}] += mu_p[{2}];".format(i, j, q))
    print("}\nif (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                print(" A_{0}{0}{1}{1}[{2}] += lambda_p[{2}];".format(i, j, q))
    print("}")

    decl, sums = buildTempAndSummation(dim, ids, lamesource.expanded2Dtemps,
                                       lamesource.expanded2Dsummations)
    for k in range(dim):
        for m in range(dim):
            print("{")
            print("\n".join(decl[(k, m)]))
            print("\n".join(sums[(k, m)]))
            print("}")
132
def print2DAReduced():
    """Print generated C++ code for the 2D case with reduced (single-value) data.

    Emits:
      * one zero-initialised scalar ``A_<ijkl>`` per identifier;
      * accumulation of ``mu_p[0]`` (``2*mu_p[0]`` on the diagonal ``iiii``)
        and ``lambda_p[0]``;
      * the off-diagonal ``tmp_N`` declarations that reference an existing
        ``A_`` coefficient;
      * the ``lamesource.reduced2Dsummations`` lines with zero temporaries
        removed. For k != m, references to ``A_k0m0`` / ``A_k1m1`` (which are
        never declared, since they are not in ``ids``) are cut out at fixed
        character offsets.
    """
    dim = 2
    quads = 2**dim
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    for name in sorted(ids):
        print("double A_{0} =".format(name), "0;")
    # ijji += mu
    # ijij += mu
    # iijj += lambda

    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")

    for i in range(dim):
        for j in range(dim):
            if i == j:
                print(" A_{0}{0}{0}{0} += 2*mu_p[0];".format(i))
            else:
                print(" A_{0}{1}{1}{0} += mu_p[0];".format(i, j))
                print(" A_{0}{1}{0}{1} += mu_p[0];".format(i, j))
    print("}")
    print("if (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            print(" A_{0}{0}{1}{1} += lambda_p[0];".format(i, j))
    print("}")

    lines = [
        "const double tmp_0 = 6*w1*(A_{0}0{1}1 - A_{0}1{1}0);",
        "const double tmp_1 = 6*w1*(A_{0}0{1}1 + A_{0}1{1}0);",
        "const double tmp_2 = 6*w1*(-A_{0}0{1}1 - A_{0}1{1}0);",
        "const double tmp_3 = 6*w1*(-A_{0}0{1}1 + A_{0}1{1}0);"
    ]

    # NOTE(review): the declared names tmp_0..tmp_3 do not depend on (k, m),
    # so for both off-diagonal pairs (0,1) and (1,0) the same declarations are
    # printed twice; they also differ from the "tmp{m}{m}_{q}" names recorded
    # as zeroes. Confirm against lamesource.reduced2Dsummations.
    zeroes = []
    for k in range(dim):
        for m in range(dim):
            if k == m:
                for q in range(quads):
                    zeroes.append("tmp{0}{0}_{1}".format(m, q))
                continue
            for line in lines:
                newline = line.format(k, m)
                i = newline.index("tmp")
                identifier = newline[i:].split()[0].rstrip()
                if any(index in newline for index in ids):
                    print(newline)
                else:
                    zeroes.append(identifier)

    for k in range(dim):
        for m in range(dim):
            for line in lamesource.reduced2Dsummations:
                newline = replaceZeroes(line.format(k, m), zeroes)
                if not newline:
                    continue
                if k != m:
                    # cut the undeclared term plus the surrounding operator
                    # text; offsets depend on the template layout
                    st = "A_{0}0{1}0".format(k, m)
                    if st in newline:
                        i = newline.index(st)
                        newline = newline[:i - 3] + newline[i + 12:]
                    st = "A_{0}1{1}1".format(k, m)
                    if st in newline:
                        i = newline.index(st)
                        newline = newline[:i - 2] + newline[i + 11:]
                print(newline)
210
def print3DAExpanded():
    """Print generated C++ code for the 3D case with expanded (per-point) data.

    Emits:
      * one zero-initialised array ``A_<ijkl>[2**dim]`` per identifier;
      * accumulation of ``mu_p`` into the ``iiii``, ``ijji`` and ``ijij``
        entries (``2*mu_p`` on the diagonal ``iiii``, which matches both
        patterns) and of ``lambda_p`` into the ``iijj`` entries;
      * for every (k, m) pair, a braced block with the temporaries and
        summations produced by ``buildTempAndSummation`` from
        ``lamesource.expanded3Dtemps`` / ``expanded3Dsummations``.
    """
    dim = 3
    quads = 2**dim  # number of values per coefficient array
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    # the second argument is printed verbatim: C initialiser "{0};"
    for name in sorted(ids):
        print("double A_{0}[{1}] =".format(name, quads), "{0};")
    # ijji += mu
    # ijij += mu
    # iijj += lambda
    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                if i == j:
                    print(" A_{0}{0}{0}{0}[{1}] += 2*mu_p[{1}];".format(i, q))
                else:
                    print(" A_{0}{1}{1}{0}[{2}] += mu_p[{2}];".format(i, j, q))
                    print(" A_{0}{1}{0}{1}[{2}] += mu_p[{2}];".format(i, j, q))
    print("}")
    print("if (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            for q in range(quads):
                print(" A_{0}{0}{1}{1}[{2}] += lambda_p[{2}];".format(i, j, q))
    print("}")

    decl, sums = buildTempAndSummation(dim, ids, lamesource.expanded3Dtemps,
                                       lamesource.expanded3Dsummations)
    for k in range(dim):
        for m in range(dim):
            print("{")
            print("\n".join(decl[(k, m)]))
            print("\n".join(sums[(k, m)]))
            print("}")
250
def print3DAReduced():
    """Print generated C++ code for the 3D case with reduced (single-value) data.

    Emits:
      * one zero-initialised scalar ``Aw<ijkl>`` per identifier;
      * accumulation of ``mu_p[0]`` (``2*mu_p[0]`` on the diagonal ``iiii``)
        and ``lambda_p[0]``;
      * an in-place ``*=`` scaling of every ``Aw`` coefficient that exists;
      * for every (k, m) pair, a braced block with the temporaries and
        summations produced by ``buildTempAndSummation`` from
        ``lamesource.reduced3Dtemps`` / ``reduced3Dsummations``. A fixed set
        of temporaries is always inlined rather than declared.
    """
    dim = 3
    ids = {}
    for i in range(dim):
        for j in range(dim):
            ids["{0}{0}{1}{1}".format(i, j)] = None
            ids["{0}{1}{1}{0}".format(i, j)] = None
            ids["{0}{1}{0}{1}".format(i, j)] = None

    for name in sorted(ids):
        print("double Aw%s = 0;" % name)

    print("if (!mu.isEmpty()) {\n const double *mu_p = mu.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            if i == j:
                print(" Aw{0}{0}{0}{0} += 2*mu_p[0];".format(i))
            else:
                print(" Aw{0}{1}{1}{0} += mu_p[0];".format(i, j))
                print(" Aw{0}{1}{0}{1} += mu_p[0];".format(i, j))

    print("}\nif (!lambda.isEmpty()) {\n const double *lambda_p = lambda.getSampleDataRO(e);")
    for i in range(dim):
        for j in range(dim):
            print(" Aw{0}{0}{1}{1} += lambda_p[0];".format(i, j))
    print("}")

    # scale only the coefficients that were actually declared (present in ids)
    scalings = ("Aw{0}0{1}0 *= 8*w27;", "Aw{0}0{1}1 *= 12*w8;", "Aw{0}0{1}2 *= 12*w11;",
                "Aw{0}1{1}0 *= 12*w8;", "Aw{0}1{1}1 *= 8*w22;", "Aw{0}1{1}2 *= 12*w10;",
                "Aw{0}2{1}0 *= 12*w11;", "Aw{0}2{1}1 *= 12*w10;", "Aw{0}2{1}2 *= 8*w13;")
    for k in range(dim):
        for m in range(dim):
            for line in scalings:
                newline = line.format(k, m)
                if any(ident in newline for ident in ids):
                    print(newline)

    decs, sums = buildTempAndSummation(dim, ids, lamesource.reduced3Dtemps,
                                       lamesource.reduced3Dsummations,
                                       forced_substitutions=["tmp12", "tmp13", "tmp14",
                                                             "tmp21", "tmp22", "tmp23"])
    for k in range(dim):
        for m in range(dim):
            print("{")
            print("\n".join(decs[(k, m)]))
            print("\n".join(sums[(k, m)]))
            print("}")
301
302
303
def printAReduced(dim):
    """Print the reduced-data assembly code for *dim* dimensions.

    Args:
        dim: number of dimensions; only 2 and 3 are supported.

    Raises:
        ValueError: if *dim* is not 2 or 3. (A bare ``raise`` here would
            instead fail with an unrelated "no active exception" error.)
    """
    if dim == 2:
        return print2DAReduced()
    elif dim == 3:
        return print3DAReduced()
    raise ValueError("unsupported dimension %r (expected 2 or 3)" % (dim,))
311
def printAExpanded(dim):
    """Print the expanded-data assembly code for *dim* dimensions.

    Args:
        dim: number of dimensions; only 2 and 3 are supported.

    Raises:
        ValueError: if *dim* is not 2 or 3. (A bare ``raise`` here would
            instead fail with an unrelated "no active exception" error.)
    """
    if dim == 2:
        return print2DAExpanded()
    elif dim == 3:
        return print3DAExpanded()
    raise ValueError("unsupported dimension %r (expected 2 or 3)" % (dim,))
319
if __name__ == "__main__":
    # Usage: <script> R|E 2|3 -- generated code goes to stdout, so the usage
    # text is sent to stderr to keep it out of any redirected output.
    if len(sys.argv) < 3 or sys.argv[1] not in ["R", "E"] or sys.argv[2] not in ["2", "3"]:
        print("Usage: {0} [R]educed/[E]xpanded dimensions\nE.g. {0} R 3".format(sys.argv[0]),
              file=sys.stderr)
        sys.exit(1)
    dim = int(sys.argv[2])
    if sys.argv[1] == "R":
        printAReduced(dim)
    else:
        printAExpanded(dim)
329

  ViewVC Help
Powered by ViewVC 1.1.26