1 | //===- Merger.cpp - Implementation of iteration lattices ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/Complex/IR/Complex.h" |
12 | #include "mlir/Dialect/Math/IR/Math.h" |
13 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
14 | |
15 | #include "mlir/IR/Operation.h" |
16 | #include "llvm/Support/Debug.h" |
17 | #include <optional> |
18 | |
19 | namespace mlir { |
20 | namespace sparse_tensor { |
21 | |
22 | enum class ExpArity { |
23 | kNullary, |
24 | kUnary, |
25 | kBinary, |
26 | }; |
27 | |
28 | static ExpArity getExpArity(TensorExp::Kind k) { |
29 | switch (k) { |
30 | // Leaf. |
31 | case TensorExp::Kind::kTensor: |
32 | case TensorExp::Kind::kInvariant: |
33 | case TensorExp::Kind::kLoopVar: |
34 | case TensorExp::Kind::kSynZero: |
35 | return ExpArity::kNullary; |
36 | case TensorExp::Kind::kAbsF: |
37 | case TensorExp::Kind::kAbsC: |
38 | case TensorExp::Kind::kAbsI: |
39 | case TensorExp::Kind::kCeilF: |
40 | case TensorExp::Kind::kFloorF: |
41 | case TensorExp::Kind::kSqrtF: |
42 | case TensorExp::Kind::kSqrtC: |
43 | case TensorExp::Kind::kExpm1F: |
44 | case TensorExp::Kind::kExpm1C: |
45 | case TensorExp::Kind::kLog1pF: |
46 | case TensorExp::Kind::kLog1pC: |
47 | case TensorExp::Kind::kSinF: |
48 | case TensorExp::Kind::kSinC: |
49 | case TensorExp::Kind::kTanhF: |
50 | case TensorExp::Kind::kTanhC: |
51 | case TensorExp::Kind::kTruncF: |
52 | case TensorExp::Kind::kExtF: |
53 | case TensorExp::Kind::kCastFS: |
54 | case TensorExp::Kind::kCastFU: |
55 | case TensorExp::Kind::kCastSF: |
56 | case TensorExp::Kind::kCastUF: |
57 | case TensorExp::Kind::kCastS: |
58 | case TensorExp::Kind::kCastU: |
59 | case TensorExp::Kind::kCastIdx: |
60 | case TensorExp::Kind::kTruncI: |
61 | case TensorExp::Kind::kCIm: |
62 | case TensorExp::Kind::kCRe: |
63 | case TensorExp::Kind::kBitCast: |
64 | case TensorExp::Kind::kBinaryBranch: |
65 | case TensorExp::Kind::kUnary: |
66 | case TensorExp::Kind::kSelect: |
67 | case TensorExp::Kind::kNegF: |
68 | case TensorExp::Kind::kNegC: |
69 | case TensorExp::Kind::kNegI: |
70 | return ExpArity::kUnary; |
71 | // Binary operations. |
72 | case TensorExp::Kind::kDivF: |
73 | case TensorExp::Kind::kDivC: |
74 | case TensorExp::Kind::kDivS: |
75 | case TensorExp::Kind::kDivU: |
76 | case TensorExp::Kind::kShrS: |
77 | case TensorExp::Kind::kShrU: |
78 | case TensorExp::Kind::kShlI: |
79 | case TensorExp::Kind::kMulF: |
80 | case TensorExp::Kind::kMulC: |
81 | case TensorExp::Kind::kMulI: |
82 | case TensorExp::Kind::kAndI: |
83 | case TensorExp::Kind::kAddF: |
84 | case TensorExp::Kind::kAddC: |
85 | case TensorExp::Kind::kAddI: |
86 | case TensorExp::Kind::kOrI: |
87 | case TensorExp::Kind::kXorI: |
88 | case TensorExp::Kind::kBinary: |
89 | case TensorExp::Kind::kReduce: |
90 | case TensorExp::Kind::kSubF: |
91 | case TensorExp::Kind::kSubC: |
92 | case TensorExp::Kind::kSubI: |
93 | case TensorExp::Kind::kCmpF: |
94 | case TensorExp::Kind::kCmpI: |
95 | case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands |
96 | return ExpArity::kBinary; |
97 | } |
98 | llvm_unreachable("unexpected kind" ); |
99 | } |
100 | |
101 | //===----------------------------------------------------------------------===// |
102 | // Constructors. |
103 | //===----------------------------------------------------------------------===// |
104 | |
105 | TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, |
106 | Operation *o, Attribute a) |
107 | : kind(k), val(v), op(o) { |
108 | switch (kind) { |
109 | // Leaf. |
110 | case TensorExp::Kind::kTensor: |
111 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
112 | tensor = x; |
113 | return; |
114 | case TensorExp::Kind::kSynZero: |
115 | assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
116 | return; |
117 | case TensorExp::Kind::kInvariant: |
118 | assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o); |
119 | return; |
120 | case TensorExp::Kind::kLoopVar: |
121 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
122 | loop = x; |
123 | return; |
124 | // Unary operations. |
125 | case TensorExp::Kind::kAbsF: |
126 | case TensorExp::Kind::kAbsC: |
127 | case TensorExp::Kind::kAbsI: |
128 | case TensorExp::Kind::kCeilF: |
129 | case TensorExp::Kind::kFloorF: |
130 | case TensorExp::Kind::kSqrtF: |
131 | case TensorExp::Kind::kSqrtC: |
132 | case TensorExp::Kind::kExpm1F: |
133 | case TensorExp::Kind::kExpm1C: |
134 | case TensorExp::Kind::kLog1pF: |
135 | case TensorExp::Kind::kLog1pC: |
136 | case TensorExp::Kind::kSinF: |
137 | case TensorExp::Kind::kSinC: |
138 | case TensorExp::Kind::kTanhF: |
139 | case TensorExp::Kind::kTanhC: |
140 | case TensorExp::Kind::kNegF: |
141 | case TensorExp::Kind::kNegC: |
142 | case TensorExp::Kind::kNegI: |
143 | case TensorExp::Kind::kCIm: |
144 | case TensorExp::Kind::kCRe: |
145 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
146 | children.e0 = x; |
147 | children.e1 = y; |
148 | return; |
149 | case TensorExp::Kind::kTruncF: |
150 | case TensorExp::Kind::kExtF: |
151 | case TensorExp::Kind::kCastFS: |
152 | case TensorExp::Kind::kCastFU: |
153 | case TensorExp::Kind::kCastSF: |
154 | case TensorExp::Kind::kCastUF: |
155 | case TensorExp::Kind::kCastS: |
156 | case TensorExp::Kind::kCastU: |
157 | case TensorExp::Kind::kCastIdx: |
158 | case TensorExp::Kind::kTruncI: |
159 | case TensorExp::Kind::kBitCast: |
160 | assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o); |
161 | children.e0 = x; |
162 | children.e1 = y; |
163 | return; |
164 | case TensorExp::Kind::kBinaryBranch: |
165 | case TensorExp::Kind::kSelect: |
166 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o); |
167 | children.e0 = x; |
168 | children.e1 = y; |
169 | return; |
170 | case TensorExp::Kind::kUnary: |
171 | // No assertion on y can be made, as the branching paths involve both |
172 | // a unary (`mapSet`) and binary (`disjSet`) pathway. |
173 | assert(x != detail::kInvalidId && !v && o); |
174 | children.e0 = x; |
175 | children.e1 = y; |
176 | return; |
177 | // Binary operations. |
178 | case TensorExp::Kind::kMulF: |
179 | case TensorExp::Kind::kMulC: |
180 | case TensorExp::Kind::kMulI: |
181 | case TensorExp::Kind::kDivF: |
182 | case TensorExp::Kind::kDivC: |
183 | case TensorExp::Kind::kDivS: |
184 | case TensorExp::Kind::kDivU: |
185 | case TensorExp::Kind::kAddF: |
186 | case TensorExp::Kind::kAddC: |
187 | case TensorExp::Kind::kAddI: |
188 | case TensorExp::Kind::kSubF: |
189 | case TensorExp::Kind::kSubC: |
190 | case TensorExp::Kind::kSubI: |
191 | case TensorExp::Kind::kAndI: |
192 | case TensorExp::Kind::kOrI: |
193 | case TensorExp::Kind::kXorI: |
194 | case TensorExp::Kind::kShrS: |
195 | case TensorExp::Kind::kShrU: |
196 | case TensorExp::Kind::kShlI: |
197 | assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); |
198 | children.e0 = x; |
199 | children.e1 = y; |
200 | return; |
201 | case TensorExp::Kind::kCmpF: |
202 | case TensorExp::Kind::kCmpI: |
203 | assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); |
204 | attr = a; |
205 | children.e0 = x; |
206 | children.e1 = y; |
207 | return; |
208 | case TensorExp::Kind::kBinary: |
209 | case TensorExp::Kind::kReduce: |
210 | assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o); |
211 | children.e0 = x; |
212 | children.e1 = y; |
213 | return; |
214 | case TensorExp::Kind::kDenseOp: |
215 | assert(x != detail::kInvalidId && !v && o); |
216 | children.e0 = x; |
217 | children.e1 = y; |
218 | return; |
219 | } |
220 | llvm_unreachable("unexpected kind" ); |
221 | } |
222 | |
223 | Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops, |
224 | unsigned maxLvlRank) |
225 | : outTensor(numInputOutputTensors - 1), |
226 | syntheticTensor(numInputOutputTensors), |
227 | numTensors(numInputOutputTensors + 1), numLoops(numLoops), |
228 | hasSparseOut(false), |
229 | lvlTypes(numTensors, |
230 | std::vector<LevelType>(numLoops, LevelFormat::Undef)), |
231 | loopToLvl(numTensors, |
232 | std::vector<std::optional<Level>>(numLoops, std::nullopt)), |
233 | lvlToLoop(numTensors, |
234 | std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)), |
235 | loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>( |
236 | numTensors, std::nullopt)), |
237 | levelToDependentLoop(numTensors, |
238 | std::vector<std::vector<LoopCoeffPair>>( |
239 | maxLvlRank, std::vector<LoopCoeffPair>())), |
240 | loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} |
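| |
| // For example (illustrative sketch only): for a hypothetical kernel |
| // C(i,j) = A(i,j) * B(i,j) the sparsifier would construct |
| //   Merger merger(/*numInputOutputTensors=*/3, /*numLoops=*/2, /*maxLvlRank=*/2); |
| // so that ids 0 and 1 denote the inputs, id 2 the output tensor, and id 3 |
| // the synthetic tensor introduced above. |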
241 | |
242 | //===----------------------------------------------------------------------===// |
243 | // Lattice methods. |
244 | //===----------------------------------------------------------------------===// |
245 | |
246 | ExprId Merger::addTensorExp(TensorId t) { |
247 | assert(isValidTensorId(t)); |
248 | const ExprId eNew(tensorExps.size()); |
249 | tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId, |
250 | Value(), nullptr, nullptr); |
251 | return eNew; |
252 | } |
253 | |
254 | ExprId Merger::addLoopVarExp(LoopId i) { |
255 | assert(isValidLoopId(i)); |
256 | const ExprId eNew(tensorExps.size()); |
257 | tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId, |
258 | Value(), nullptr, nullptr); |
259 | return eNew; |
260 | } |
261 | |
262 | ExprId Merger::addInvariantExp(Value v) { |
263 | const ExprId eNew(tensorExps.size()); |
264 | tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId, |
265 | detail::kInvalidId, v, nullptr, nullptr); |
266 | return eNew; |
267 | } |
268 | |
269 | ExprId Merger::addSynZeroExp() { |
270 | const ExprId eNew(tensorExps.size()); |
271 | tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId, |
272 | detail::kInvalidId, Value(), nullptr, nullptr); |
273 | return eNew; |
274 | } |
275 | |
276 | ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op, |
277 | Attribute attr) { |
278 | assert(k > TensorExp::Kind::kLoopVar); |
279 | const ExprId eNew(tensorExps.size()); |
280 | tensorExps.emplace_back(Args&: k, Args&: e0, Args&: e1, Args: Value(), Args&: op, Args&: attr); |
281 | return eNew; |
282 | } |
283 | |
284 | ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op, |
285 | Attribute attr) { |
286 | assert(k > TensorExp::Kind::kLoopVar); |
287 | const ExprId eNew(tensorExps.size()); |
288 | tensorExps.emplace_back(Args&: k, Args&: e, Args: detail::kInvalidId, Args&: v, Args&: op, Args&: attr); |
289 | return eNew; |
290 | } |
291 | |
292 | LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { |
293 | const LatPointId pNew(latPoints.size()); |
294 | const unsigned size = numLoops * numTensors; |
295 | const TensorLoopId b = makeTensorLoopId(t, i); |
296 | latPoints.emplace_back(size, e); |
297 | latPoints[pNew].bits.set(b); |
298 | return pNew; |
299 | } |
300 | |
301 | LatPointId Merger::addLat(const BitVector &bits, ExprId e) { |
302 | assert(bits.size() == numLoops * numTensors); |
303 | const LatPointId pNew(latPoints.size()); |
304 | latPoints.emplace_back(bits, e); |
305 | return pNew; |
306 | } |
307 | |
308 | LatSetId Merger::addSet() { |
309 | const LatSetId sNew(latSets.size()); |
310 | latSets.emplace_back(); |
311 | return sNew; |
312 | } |
313 | |
314 | LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1, |
315 | Operation *op) { |
316 | TensorExp::Kind kind = exp(e).kind; |
317 | Attribute attr = exp(e).attr; |
318 | const LatPointId pNew(latPoints.size()); |
319 | const auto &point0 = lat(p: p0); |
320 | const auto &point1 = lat(p: p1); |
321 | BitVector bits(point0.bits); |
322 | bits |= point1.bits; |
323 | const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr); |
324 | latPoints.emplace_back(bits, ne); |
325 | return pNew; |
326 | } |
327 | |
328 | LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { |
329 | const LatSetId sNew = addSet(); |
330 | auto &setNew = latSets[sNew]; |
331 | for (const LatPointId p0 : set(s0)) |
332 | for (const LatPointId p1 : set(s1)) |
333 | setNew.push_back(conjLat(e, p0, p1, op)); |
334 | return sNew; |
335 | } |
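| |
| // For example (sketch): for a conjunction such as A(i) * B(i), the two |
| // singleton sets {A} and {B} combine into the single point {A /\ B}, i.e. |
| // the loop only needs to visit indices where both operands are present. |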
336 | |
337 | LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { |
338 | const LatSetId sNew = conjSet(e, s0, s1, op); |
339 | TensorExp::Kind kind = exp(e).kind; |
340 | |
341 | // Followed by all in s0. |
342 | latSets[sNew].append(latSets[s0]); |
343 | // Map binary 0-y to unary -y. |
344 | // TODO: move this if-else logic into buildLattices |
345 | if (kind == TensorExp::Kind::kSubF) |
346 | s1 = mapSet(TensorExp::Kind::kNegF, s1); |
347 | else if (kind == TensorExp::Kind::kSubC) |
348 | s1 = mapSet(TensorExp::Kind::kNegC, s1); |
349 | else if (kind == TensorExp::Kind::kSubI) |
350 | s1 = mapSet(TensorExp::Kind::kNegI, s1); |
351 | // Followed by all in s1. |
352 | latSets[sNew].append(latSets[s1]); |
353 | return sNew; |
354 | } |
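| |
| // For example (sketch): for a disjunction such as A(i) + B(i), the result is |
| // the set {A /\ B, A, B}: the overlap computes a+b, while the remaining |
| // points compute a or b where only one operand is present. For A(i) - B(i) |
| // the trailing B points are first mapped to the unary negation -B. |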
355 | |
356 | LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) { |
357 | assert(exp(e).kind == TensorExp::Kind::kCmpI || |
358 | exp(e).kind == TensorExp::Kind::kCmpF); |
359 | const LatSetId sNew = conjSet(e, s0, s1, nullptr); |
360 | |
361 | ExprId e0 = exp(e).children.e0; |
362 | ExprId e1 = exp(e).children.e1; |
363 | if (exp(e0).kind == TensorExp::Kind::kSynZero || |
364 | exp(e1).kind == TensorExp::Kind::kSynZero) { |
365 | // lhs and rhs can't be synthetic zero at the same time. |
366 | assert(exp(e0).kind != exp(e1).kind); |
367 | // If one of the operands has already been assigned to zero (the |
368 | // element is absent in the corresponding operand), then we do not |
369 | // need to build disjunctive set for it. |
370 | return sNew; |
371 | } |
372 | |
373 | auto lhsSet = mapBinWithSynZeroSet(e, s0, false); |
374 | auto rhsSet = mapBinWithSynZeroSet(e, s1, true); |
375 | latSets[sNew].append(latSets[lhsSet]); |
376 | latSets[sNew].append(latSets[rhsSet]); |
377 | return sNew; |
378 | } |
379 | |
380 | LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, |
381 | bool includeLeft, TensorExp::Kind ltrans, |
382 | Operation *opleft, bool includeRight, |
383 | TensorExp::Kind rtrans, Operation *opright) { |
384 | const LatSetId sNew = conjSet(e, s0, s1, orig); |
385 | // Left Region. |
386 | if (includeLeft) { |
387 | if (opleft) |
388 | s0 = mapSet(ltrans, s0, Value(), opleft); |
389 | latSets[sNew].append(latSets[s0]); |
390 | } |
391 | // Right Region. |
392 | if (includeRight) { |
393 | if (opright) |
394 | s1 = mapSet(rtrans, s1, Value(), opright); |
395 | latSets[sNew].append(latSets[s1]); |
396 | } |
397 | return sNew; |
398 | } |
399 | |
400 | LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v, |
401 | Operation *op) { |
402 | assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) || |
403 | TensorExp::Kind::kDenseOp == kind); |
404 | const LatSetId sNew = addSet(); |
405 | auto &setNew = latSets[sNew]; |
406 | for (const LatPointId p : set(s0)) { |
407 | const auto &point = latPoints[p]; |
408 | setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op))); |
409 | } |
410 | return sNew; |
411 | } |
412 | |
413 | LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) { |
414 | TensorExp::Kind kind = exp(e).kind; |
415 | Attribute a = exp(e).attr; |
416 | assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI); |
417 | // Must be a binary operation. |
418 | const LatSetId sNew = addSet(); |
419 | auto &setNew = latSets[sNew]; |
420 | const ExprId zeroExp = addSynZeroExp(); |
421 | for (const LatPointId p : set(s0)) { |
422 | const auto &point = latPoints[p]; |
423 | ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a) |
424 | : addExp(kind, point.exp, zeroExp, nullptr, a); |
425 | setNew.push_back(addLat(point.bits, newExp)); |
426 | } |
427 | return sNew; |
428 | } |
429 | |
430 | LatSetId Merger::optimizeSet(LatSetId s0) { |
431 | const LatSetId sNew = addSet(); |
432 | auto &setNew = latSets[sNew]; |
433 | const auto &set0 = set(s0); |
434 | assert(!set0.empty()); |
435 | const LatPointId p0 = set0[0]; |
436 | for (const LatPointId p1 : set0) { |
437 | bool add = true; |
438 | if (p0 != p1) { |
439 | // Check whether this is a straightforward copy. |
440 | if (expIsTensor(latPoints[p1].exp, outTensor)) |
441 | continue; |
442 | // Check whether this conjunction is already covered. |
443 | for (const LatPointId p2 : setNew) { |
444 | assert(!latGT(p1, p2)); // Lj => Li would be bad |
445 | if (onlyDenseDiff(p2, p1)) { |
446 | add = false; |
447 | break; |
448 | } |
449 | } |
450 | assert(!add || latGT(p0, p1)); |
451 | } |
452 | if (add) |
453 | setNew.push_back(p1); |
454 | } |
455 | for (const LatPointId p : setNew) |
456 | latPoints[p].simple = simplifyCond(sNew, p); |
457 | return sNew; |
458 | } |
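| |
| // For example (sketch): for A(i) + B(i) with sparse A but dense B, the set |
| // {A /\ B, A, B} is pruned to {A /\ B, B}, since the point {A} differs from |
| // {A /\ B} in dense bits only and is therefore already covered. |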
459 | |
460 | BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { |
461 | // First determine if this lattice point is a *singleton*, i.e., |
462 | // the last point in a lattice, no other is less than this one. |
463 | bool isSingleton = true; |
464 | for (const LatPointId p1 : set(s0)) { |
465 | if (p0 != p1 && latGT(p0, p1)) { |
466 | isSingleton = false; |
467 | break; |
468 | } |
469 | } |
470 | |
471 | BitVector simple(latPoints[p0].bits); |
472 | bool reset = isSingleton && hasAnySparse(simple); |
473 | const TensorLoopId be = simple.size(); |
474 | TensorLoopId offset = 0; // relative to the end |
475 | if (!reset) |
476 | // Start resetting from a dense level, so that the first bit (if kept) |
477 | // is not of undefined level-type. |
478 | for (unsigned b = 0; b < be; b++) { |
479 | if (simple[b] && getLvlType(b: TensorLoopId{b}).hasDenseSemantic()) { |
480 | offset = be - b - 1; // relative to the end |
481 | break; |
482 | } |
483 | } |
484 | |
485 | // Now apply the two basic rules. We also iterate over the bits in reverse to |
486 | // always keep the rightmost bit (which could possibly be a synthetic tensor). |
487 | for (unsigned b = be - 1 - offset, i = 0; i < be; |
488 | b = b == 0 ? be - 1 : b - 1, i++) { |
489 | // A slice on a dense level has the `locate` property as well and can be optimized. |
490 | if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) { |
491 | const auto lt = getLvlType(b); |
492 | if (!lt.hasSparseSemantic()) { |
493 | if (reset) |
494 | simple.reset(b); |
495 | reset = true; |
496 | } |
497 | } |
498 | } |
499 | return simple; |
500 | } |
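| |
| // For example (sketch): a conjunction that involves one sparse and one dense |
| // tensor only needs to co-iterate the sparse one; the dense condition is |
| // cleared from the simplified bits because dense levels can be probed |
| // directly (they have the `locate` property). |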
501 | |
502 | bool Merger::latGT(LatPointId i, LatPointId j) const { |
503 | const BitVector &bitsi = lat(p: i).bits; |
504 | const BitVector &bitsj = lat(p: j).bits; |
505 | assert(bitsi.size() == bitsj.size()); |
506 | if (bitsi.count() > bitsj.count()) { |
507 | for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) |
508 | if (bitsj[b] && !bitsi[b]) |
509 | return false; |
510 | return true; |
511 | } |
512 | return false; |
513 | } |
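| |
| // For example (sketch): {A /\ B} > {A} holds because the former covers a |
| // proper superset of the (tensor, loop) bits, and hence guards a strictly |
| // smaller portion of the iteration space. |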
514 | |
515 | bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const { |
516 | BitVector tmp(latPoints[j].bits); |
517 | tmp ^= latPoints[i].bits; |
518 | return !hasAnySparse(tmp); |
519 | } |
520 | |
521 | bool Merger::expContainsTensor(ExprId e, TensorId t) const { |
522 | const auto &expr = exp(e); |
523 | // First we check `expIsTensor`. |
524 | if (expr.kind == TensorExp::Kind::kTensor) |
525 | return expr.tensor == t; |
526 | |
527 | switch (getExpArity(expr.kind)) { |
528 | case ExpArity::kNullary: |
529 | return false; |
530 | case ExpArity::kUnary: { |
531 | const ExprId e0 = expr.children.e0; |
532 | return expContainsTensor(e0, t); |
533 | } |
534 | case ExpArity::kBinary: { |
535 | const ExprId e0 = expr.children.e0; |
536 | const ExprId e1 = expr.children.e1; |
537 | return expContainsTensor(e0, t) || expContainsTensor(e1, t); |
538 | } |
539 | } |
540 | llvm_unreachable("unexpected arity" ); |
541 | } |
542 | |
543 | bool Merger::hasNegateOnOut(ExprId e) const { |
544 | const auto &expr = exp(e); |
545 | switch (expr.kind) { |
546 | case TensorExp::Kind::kNegF: |
547 | case TensorExp::Kind::kNegC: |
548 | case TensorExp::Kind::kNegI: |
549 | return expContainsTensor(expr.children.e0, outTensor); |
550 | case TensorExp::Kind::kSubF: |
551 | case TensorExp::Kind::kSubC: |
552 | case TensorExp::Kind::kSubI: |
553 | return expContainsTensor(expr.children.e1, outTensor) || |
554 | hasNegateOnOut(expr.children.e0); |
555 | case TensorExp::Kind::kDenseOp: { |
556 | bool lhsNeg = hasNegateOnOut(expr.children.e0); |
557 | if (!lhsNeg && expr.children.e1 != detail::kInvalidId) |
558 | return hasNegateOnOut(expr.children.e1); |
559 | return lhsNeg; |
560 | } |
561 | default: { |
562 | switch (getExpArity(expr.kind)) { |
563 | case ExpArity::kNullary: |
564 | return false; |
565 | case ExpArity::kUnary: |
566 | return hasNegateOnOut(expr.children.e0); |
567 | case ExpArity::kBinary: |
568 | return hasNegateOnOut(expr.children.e0) || |
569 | hasNegateOnOut(expr.children.e1); |
570 | } |
571 | } |
572 | } |
573 | llvm_unreachable("unexpected kind" ); |
574 | } |
575 | |
576 | bool Merger::isSingleCondition(TensorId t, ExprId e) const { |
577 | assert(isValidTensorId(t)); |
578 | const auto &expr = exp(e); |
579 | switch (expr.kind) { |
580 | // Leaf. |
581 | case TensorExp::Kind::kTensor: |
582 | return expr.tensor == t; |
583 | case TensorExp::Kind::kInvariant: |
584 | case TensorExp::Kind::kLoopVar: |
585 | case TensorExp::Kind::kSynZero: |
586 | return false; |
587 | // Unary operations. |
588 | case TensorExp::Kind::kAbsF: |
589 | case TensorExp::Kind::kAbsC: |
590 | case TensorExp::Kind::kAbsI: |
591 | case TensorExp::Kind::kCeilF: |
592 | case TensorExp::Kind::kFloorF: |
593 | case TensorExp::Kind::kSqrtF: |
594 | case TensorExp::Kind::kSqrtC: |
595 | case TensorExp::Kind::kExpm1F: |
596 | case TensorExp::Kind::kExpm1C: |
597 | case TensorExp::Kind::kLog1pF: |
598 | case TensorExp::Kind::kLog1pC: |
599 | case TensorExp::Kind::kSinF: |
600 | case TensorExp::Kind::kSinC: |
601 | case TensorExp::Kind::kTanhF: |
602 | case TensorExp::Kind::kTanhC: |
603 | case TensorExp::Kind::kNegF: |
604 | case TensorExp::Kind::kNegC: |
605 | case TensorExp::Kind::kNegI: |
606 | case TensorExp::Kind::kTruncF: |
607 | case TensorExp::Kind::kExtF: |
608 | case TensorExp::Kind::kCastFS: |
609 | case TensorExp::Kind::kCastFU: |
610 | case TensorExp::Kind::kCastSF: |
611 | case TensorExp::Kind::kCastUF: |
612 | case TensorExp::Kind::kCastS: |
613 | case TensorExp::Kind::kCastU: |
614 | case TensorExp::Kind::kCastIdx: |
615 | case TensorExp::Kind::kTruncI: |
616 | case TensorExp::Kind::kCIm: |
617 | case TensorExp::Kind::kCRe: |
618 | case TensorExp::Kind::kBitCast: |
619 | case TensorExp::Kind::kUnary: |
620 | return isSingleCondition(t, expr.children.e0); |
621 | case TensorExp::Kind::kBinaryBranch: |
622 | case TensorExp::Kind::kSelect: |
623 | return false; |
624 | // Binary operations. |
625 | case TensorExp::Kind::kDivF: // note: x / c only |
626 | case TensorExp::Kind::kDivC: |
627 | case TensorExp::Kind::kDivS: |
628 | case TensorExp::Kind::kDivU: |
629 | assert(!maybeZero(expr.children.e1)); |
630 | return isSingleCondition(t, expr.children.e0); |
631 | case TensorExp::Kind::kShrS: // note: x >> inv only |
632 | case TensorExp::Kind::kShrU: |
633 | case TensorExp::Kind::kShlI: |
634 | assert(isInvariant(expr.children.e1)); |
635 | return isSingleCondition(t, expr.children.e0); |
636 | case TensorExp::Kind::kMulF: |
637 | case TensorExp::Kind::kMulC: |
638 | case TensorExp::Kind::kMulI: |
639 | case TensorExp::Kind::kAndI: |
640 | case TensorExp::Kind::kReduce: |
641 | if (isSingleCondition(t, expr.children.e0)) |
642 | return isSingleCondition(t, expr.children.e1) || |
643 | isInvariant(expr.children.e1); |
644 | if (isSingleCondition(t, expr.children.e1)) |
645 | return isInvariant(expr.children.e0); |
646 | return false; |
647 | case TensorExp::Kind::kAddF: |
648 | case TensorExp::Kind::kAddC: |
649 | case TensorExp::Kind::kAddI: |
650 | return isSingleCondition(t, expr.children.e0) && |
651 | isSingleCondition(t, expr.children.e1); |
652 | case TensorExp::Kind::kSubF: |
653 | case TensorExp::Kind::kSubC: |
654 | case TensorExp::Kind::kSubI: |
655 | case TensorExp::Kind::kOrI: |
656 | case TensorExp::Kind::kXorI: |
657 | case TensorExp::Kind::kCmpF: |
658 | case TensorExp::Kind::kCmpI: |
659 | case TensorExp::Kind::kBinary: |
660 | return false; |
661 | case TensorExp::Kind::kDenseOp: |
662 | // Since the Merger guarantees that all operands of a kDenseOp are dense, the |
663 | // operation must be single-condition. |
664 | return true; |
665 | } |
666 | llvm_unreachable("unexpected kind" ); |
667 | } |
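| |
| // For example (sketch): with t = A, the expression A(i) * b (b invariant) is |
| // a single condition, since it can only be nonzero where A is present, |
| // whereas A(i) + b is not, since the sum may be nonzero where A is absent. |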
668 | |
669 | bool Merger::hasAnySparse(const BitVector &bits) const { |
670 | for (TensorLoopId b : bits.set_bits()) { |
671 | const auto lt = getLvlType(b); |
672 | if (lt.hasSparseSemantic()) |
673 | return true; |
674 | } |
675 | return hasSparseIdxReduction(bits); |
676 | } |
677 | |
678 | bool Merger::hasSparseIdxReduction(const BitVector &bits) const { |
679 | for (TensorLoopId b : bits.set_bits()) |
680 | if (isSparseLvlWithNonTrivialIdxExp(b)) |
681 | return true; |
682 | return false; |
683 | } |
684 | |
685 | #ifndef NDEBUG |
686 | |
687 | //===----------------------------------------------------------------------===// |
688 | // Print methods (for debugging). |
689 | //===----------------------------------------------------------------------===// |
690 | |
691 | static const char *kindToOpSymbol(TensorExp::Kind kind) { |
692 | switch (kind) { |
693 | // Leaf. |
694 | case TensorExp::Kind::kTensor: |
695 | return "tensor" ; |
696 | case TensorExp::Kind::kInvariant: |
697 | return "invariant" ; |
698 | case TensorExp::Kind::kLoopVar: |
699 | return "index" ; |
700 | case TensorExp::Kind::kSynZero: |
701 | return "0" ; |
702 | // Unary operations. |
703 | case TensorExp::Kind::kAbsF: |
704 | case TensorExp::Kind::kAbsC: |
705 | case TensorExp::Kind::kAbsI: |
706 | return "abs" ; |
707 | case TensorExp::Kind::kCeilF: |
708 | return "ceil" ; |
709 | case TensorExp::Kind::kFloorF: |
710 | return "floor" ; |
711 | case TensorExp::Kind::kSqrtF: |
712 | case TensorExp::Kind::kSqrtC: |
713 | return "sqrt" ; |
714 | case TensorExp::Kind::kExpm1F: |
715 | case TensorExp::Kind::kExpm1C: |
716 | return "expm1" ; |
717 | case TensorExp::Kind::kLog1pF: |
718 | case TensorExp::Kind::kLog1pC: |
719 | return "log1p" ; |
720 | case TensorExp::Kind::kSinF: |
721 | case TensorExp::Kind::kSinC: |
722 | return "sin" ; |
723 | case TensorExp::Kind::kTanhF: |
724 | case TensorExp::Kind::kTanhC: |
725 | return "tanh" ; |
726 | case TensorExp::Kind::kNegF: |
727 | case TensorExp::Kind::kNegC: |
728 | case TensorExp::Kind::kNegI: |
729 | return "-" ; |
730 | case TensorExp::Kind::kTruncF: |
731 | case TensorExp::Kind::kExtF: |
732 | case TensorExp::Kind::kCastFS: |
733 | case TensorExp::Kind::kCastFU: |
734 | case TensorExp::Kind::kCastSF: |
735 | case TensorExp::Kind::kCastUF: |
736 | case TensorExp::Kind::kCastS: |
737 | case TensorExp::Kind::kCastU: |
738 | case TensorExp::Kind::kCastIdx: |
739 | case TensorExp::Kind::kTruncI: |
740 | case TensorExp::Kind::kCIm: |
741 | return "complex.im" ; |
742 | case TensorExp::Kind::kCRe: |
743 | return "complex.re" ; |
744 | case TensorExp::Kind::kBitCast: |
745 | return "cast" ; |
746 | case TensorExp::Kind::kBinaryBranch: |
747 | return "binary_branch" ; |
748 | case TensorExp::Kind::kUnary: |
749 | return "unary" ; |
750 | case TensorExp::Kind::kSelect: |
751 | return "select" ; |
752 | // Binary operations. |
753 | case TensorExp::Kind::kMulF: |
754 | case TensorExp::Kind::kMulC: |
755 | case TensorExp::Kind::kMulI: |
756 | return "*" ; |
757 | case TensorExp::Kind::kDivF: |
758 | case TensorExp::Kind::kDivC: |
759 | case TensorExp::Kind::kDivS: |
760 | case TensorExp::Kind::kDivU: |
761 | return "/" ; |
762 | case TensorExp::Kind::kAddF: |
763 | case TensorExp::Kind::kAddC: |
764 | case TensorExp::Kind::kAddI: |
765 | return "+" ; |
766 | case TensorExp::Kind::kSubF: |
767 | case TensorExp::Kind::kSubC: |
768 | case TensorExp::Kind::kSubI: |
769 | return "-" ; |
770 | case TensorExp::Kind::kAndI: |
771 | return "&" ; |
772 | case TensorExp::Kind::kOrI: |
773 | return "|" ; |
774 | case TensorExp::Kind::kXorI: |
775 | return "^" ; |
776 | case TensorExp::Kind::kShrS: |
777 | return "a>>" ; |
778 | case TensorExp::Kind::kShrU: |
779 | return ">>" ; |
780 | case TensorExp::Kind::kShlI: |
781 | return "<<" ; |
782 | case TensorExp::Kind::kCmpF: |
783 | case TensorExp::Kind::kCmpI: |
784 | return "cmp" ; |
785 | case TensorExp::Kind::kBinary: |
786 | return "binary" ; |
787 | case TensorExp::Kind::kReduce: |
788 | return "reduce" ; |
789 | case TensorExp::Kind::kDenseOp: |
790 | return "dense" ; |
791 | } |
792 | llvm_unreachable("unexpected kind for symbol" ); |
793 | } |
794 | |
795 | void Merger::dumpExp(ExprId e) const { |
796 | const auto &expr = exp(e); |
797 | switch (expr.kind) { |
798 | // Leaf. |
799 | case TensorExp::Kind::kTensor: |
800 | if (expr.tensor == syntheticTensor) |
801 | llvm::dbgs() << "synthetic_" ; |
802 | else if (expr.tensor == outTensor) |
803 | llvm::dbgs() << "output_" ; |
804 | llvm::dbgs() << "tensor_" << expr.tensor; |
805 | break; |
806 | case TensorExp::Kind::kInvariant: |
807 | llvm::dbgs() << "invariant" ; |
808 | break; |
809 | case TensorExp::Kind::kSynZero: |
810 | llvm::dbgs() << "0" ; |
811 | break; |
812 | case TensorExp::Kind::kLoopVar: |
813 | llvm::dbgs() << "loopvar_" << expr.loop; |
814 | break; |
815 | // Unary operations. |
816 | case TensorExp::Kind::kAbsF: |
817 | case TensorExp::Kind::kAbsC: |
818 | case TensorExp::Kind::kAbsI: |
819 | case TensorExp::Kind::kCeilF: |
820 | case TensorExp::Kind::kFloorF: |
821 | case TensorExp::Kind::kSqrtF: |
822 | case TensorExp::Kind::kSqrtC: |
823 | case TensorExp::Kind::kExpm1F: |
824 | case TensorExp::Kind::kExpm1C: |
825 | case TensorExp::Kind::kLog1pF: |
826 | case TensorExp::Kind::kLog1pC: |
827 | case TensorExp::Kind::kSinF: |
828 | case TensorExp::Kind::kSinC: |
829 | case TensorExp::Kind::kTanhF: |
830 | case TensorExp::Kind::kTanhC: |
831 | case TensorExp::Kind::kNegF: |
832 | case TensorExp::Kind::kNegC: |
833 | case TensorExp::Kind::kNegI: |
834 | case TensorExp::Kind::kTruncF: |
835 | case TensorExp::Kind::kExtF: |
836 | case TensorExp::Kind::kCastFS: |
837 | case TensorExp::Kind::kCastFU: |
838 | case TensorExp::Kind::kCastSF: |
839 | case TensorExp::Kind::kCastUF: |
840 | case TensorExp::Kind::kCastS: |
841 | case TensorExp::Kind::kCastU: |
842 | case TensorExp::Kind::kCastIdx: |
843 | case TensorExp::Kind::kTruncI: |
844 | case TensorExp::Kind::kCIm: |
845 | case TensorExp::Kind::kCRe: |
846 | case TensorExp::Kind::kBitCast: |
847 | case TensorExp::Kind::kBinaryBranch: |
848 | case TensorExp::Kind::kUnary: |
849 | case TensorExp::Kind::kSelect: |
850 | llvm::dbgs() << kindToOpSymbol(kind: expr.kind) << " " ; |
851 | dumpExp(e: expr.children.e0); |
852 | break; |
853 | // Binary operations. |
854 | case TensorExp::Kind::kMulF: |
855 | case TensorExp::Kind::kMulC: |
856 | case TensorExp::Kind::kMulI: |
857 | case TensorExp::Kind::kDivF: |
858 | case TensorExp::Kind::kDivC: |
859 | case TensorExp::Kind::kDivS: |
860 | case TensorExp::Kind::kDivU: |
861 | case TensorExp::Kind::kAddF: |
862 | case TensorExp::Kind::kAddC: |
863 | case TensorExp::Kind::kAddI: |
864 | case TensorExp::Kind::kSubF: |
865 | case TensorExp::Kind::kSubC: |
866 | case TensorExp::Kind::kSubI: |
867 | case TensorExp::Kind::kAndI: |
868 | case TensorExp::Kind::kOrI: |
869 | case TensorExp::Kind::kXorI: |
870 | case TensorExp::Kind::kShrS: |
871 | case TensorExp::Kind::kShrU: |
872 | case TensorExp::Kind::kShlI: |
873 | case TensorExp::Kind::kCmpF: |
874 | case TensorExp::Kind::kCmpI: |
875 | case TensorExp::Kind::kBinary: |
876 | case TensorExp::Kind::kReduce: |
877 | case TensorExp::Kind::kDenseOp: |
878 | llvm::dbgs() << "(" ; |
879 | dumpExp(e: expr.children.e0); |
880 | llvm::dbgs() << " " << kindToOpSymbol(kind: expr.kind); |
881 | if (expr.attr) |
882 | llvm::dbgs() << "{" << expr.attr << "}" ; |
883 | if (expr.children.e1 != detail::kInvalidId) { |
884 | llvm::dbgs() << " " ; |
885 | dumpExp(e: expr.children.e1); |
886 | llvm::dbgs() << ")" ; |
887 | } else { |
888 | assert(expr.kind == TensorExp::Kind::kDenseOp); |
889 | } |
890 | break; |
891 | } |
892 | } |
893 | |
894 | void Merger::dumpLat(LatPointId p) const { |
895 | const auto &point = lat(p); |
896 | llvm::dbgs() << "lat(" ; |
897 | dumpBits(bits: point.bits); |
898 | llvm::dbgs() << " :" ; |
899 | dumpBits(bits: point.simple); |
900 | llvm::dbgs() << " : " ; |
901 | dumpExp(e: point.exp); |
902 | llvm::dbgs() << " )\n" ; |
903 | } |
904 | |
905 | void Merger::dumpSet(LatSetId s) const { |
906 | const auto &ss = set(s); |
907 | llvm::dbgs() << "{ #" << ss.size() << "\n" ; |
908 | for (const LatPointId p : ss) { |
909 | llvm::dbgs() << " " ; |
910 | dumpLat(p); |
911 | } |
912 | llvm::dbgs() << "}\n" ; |
913 | } |
914 | |
915 | void Merger::dumpBits(const BitVector &bits) const { |
916 | for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { |
917 | if (bits[b]) { |
918 | const TensorId t = tensor(b); |
919 | const LoopId i = loop(b); |
920 | const auto lt = lvlTypes[t][i]; |
921 | if (isLvlWithNonTrivialIdxExp(b)) |
922 | llvm::dbgs() << " DEP_" << t << "_" << i; |
923 | else |
924 | llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt); |
925 | } |
926 | } |
927 | } |
928 | |
929 | #endif // NDEBUG |
930 | |
931 | //===----------------------------------------------------------------------===// |
932 | // Builder methods. |
933 | //===----------------------------------------------------------------------===// |
934 | |
935 | LatSetId Merger::buildLattices(ExprId e, LoopId i) { |
936 | // NOTE: The `expr` reference will be invalidated by recursive calls |
937 | // (and any other method that may add new expressions); therefore, the |
938 | // code below must make sure to copy fields of `expr` into local variables |
939 | // before making any recursive calls. |
940 | const auto &expr = exp(e); |
941 | const TensorExp::Kind kind = expr.kind; |
942 | switch (kind) { |
943 | // Leaf. |
944 | case TensorExp::Kind::kTensor: |
945 | case TensorExp::Kind::kInvariant: |
946 | case TensorExp::Kind::kSynZero: |
947 | case TensorExp::Kind::kLoopVar: { |
948 | // Either the loop-var is really used in the tensor expression, or it is |
949 | // set to the undefined loop-var in that level. An invariant expression, |
950 | // a proper index value, and a truly dynamic sparse output tensor are set |
951 | // to a synthetic tensor with undefined indices only to ensure the |
952 | // iteration space is not skipped as a result of their contents. |
953 | const LatSetId s = addSet(); |
954 | TensorId t = syntheticTensor; |
955 | if (kind == TensorExp::Kind::kTensor) { |
956 | t = expr.tensor; |
957 | if (hasSparseOut && t == outTensor) |
958 | t = syntheticTensor; |
959 | } |
960 | latSets[s].push_back(addLat(t, i, e)); |
961 | return s; |
962 | } |
963 | // Unary operations. |
964 | case TensorExp::Kind::kAbsF: |
965 | case TensorExp::Kind::kAbsC: |
966 | case TensorExp::Kind::kAbsI: |
967 | case TensorExp::Kind::kCeilF: |
968 | case TensorExp::Kind::kFloorF: |
969 | case TensorExp::Kind::kSqrtF: |
970 | case TensorExp::Kind::kSqrtC: |
971 | case TensorExp::Kind::kExpm1F: |
972 | case TensorExp::Kind::kExpm1C: |
973 | case TensorExp::Kind::kLog1pF: |
974 | case TensorExp::Kind::kLog1pC: |
975 | case TensorExp::Kind::kSinF: |
976 | case TensorExp::Kind::kSinC: |
977 | case TensorExp::Kind::kTanhF: |
978 | case TensorExp::Kind::kTanhC: |
979 | case TensorExp::Kind::kNegF: |
980 | case TensorExp::Kind::kNegC: |
981 | case TensorExp::Kind::kNegI: |
982 | case TensorExp::Kind::kTruncF: |
983 | case TensorExp::Kind::kExtF: |
984 | case TensorExp::Kind::kCastFS: |
985 | case TensorExp::Kind::kCastFU: |
986 | case TensorExp::Kind::kCastSF: |
987 | case TensorExp::Kind::kCastUF: |
988 | case TensorExp::Kind::kCastS: |
989 | case TensorExp::Kind::kCastU: |
990 | case TensorExp::Kind::kCastIdx: |
991 | case TensorExp::Kind::kTruncI: |
992 | case TensorExp::Kind::kCIm: |
993 | case TensorExp::Kind::kCRe: |
994 | case TensorExp::Kind::kBitCast: |
995 | // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the |
996 | // lattice set of the operand through the operator into a new set. |
997 | // |
998 | // -y|!y | y | |
999 | // --+---+---+ |
1000 | // | 0 |-y | |
1001 | { |
1002 | const ExprId e0 = expr.children.e0; |
1003 | const Value v = expr.val; |
1004 | return mapSet(kind, buildLattices(e0, i), v); |
1005 | } |
1006 | case TensorExp::Kind::kBinaryBranch: |
1007 | case TensorExp::Kind::kSelect: |
1008 | // The left or right half of a binary operation which has already |
1009 | // been split into separate operations for each region. |
1010 | { |
1011 | const ExprId e0 = expr.children.e0; |
1012 | Operation *const op = expr.op; |
1013 | return mapSet(kind, buildLattices(e0, i), Value(), op); |
1014 | } |
1015 | case TensorExp::Kind::kUnary: |
1016 | // A custom unary operation. |
1017 | // |
1018 | // op y| !y | y | |
1019 | // ----+----------+------------+ |
1020 | // | absent() | present(y) | |
1021 | { |
1022 | const ExprId e0 = expr.children.e0; |
1023 | UnaryOp unop = cast<UnaryOp>(expr.op); |
1024 | const LatSetId child0 = buildLattices(e: e0, i); |
1025 | Region &absentRegion = unop.getAbsentRegion(); |
1026 | if (absentRegion.empty()) { |
1027 | // Simple mapping over existing values. |
1028 | return mapSet(kind, child0, Value(), unop); |
1029 | } |
1030 | // Use a disjunction with `unop` on the left and the absent value as an |
1031 | // invariant on the right. |
1032 | Block &absentBlock = absentRegion.front(); |
1033 | YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator()); |
1034 | const Value absentVal = absentYield.getSingleResult(); |
1035 | const ExprId rhs = addInvariantExp(absentVal); |
1036 | return disjSet(e, child0, buildLattices(rhs, i), unop); |
1037 | } |
1038 | // Binary operations. |
1039 | case TensorExp::Kind::kMulF: |
1040 | case TensorExp::Kind::kMulC: |
1041 | case TensorExp::Kind::kMulI: |
1042 | case TensorExp::Kind::kAndI: |
1043 | // A multiplicative operation only needs to be performed |
1044 | // for the conjunction of sparse iteration spaces. |
1045 | // |
1046 | // x*y|!y | y | |
1047 | // ---+---+---+ |
1048 | // !x | 0 | 0 | |
1049 | // x | 0 |x*y| |
1050 | // |
1051 | // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored. |
1052 | { |
1053 | const ExprId e0 = expr.children.e0; |
1054 | const ExprId e1 = expr.children.e1; |
1055 | return conjSet(e, buildLattices(e0, i), buildLattices(e1, i)); |
1056 | } |
1057 | case TensorExp::Kind::kDivF: |
1058 | case TensorExp::Kind::kDivC: |
1059 | case TensorExp::Kind::kDivS: |
1060 | case TensorExp::Kind::kDivU: |
1061 | // A division is tricky, since 0/0, 0/c, c/0 all have |
1062 | // specific outcomes for floating-point and integers. |
1063 | // Thus, we need to traverse the full iteration space. |
1064 | // |
1065 | // x/y|!y | y | |
1066 | // ---+---+---+ |
1067 | // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero |
1068 | // x |x/0|x/y| INT: x/0=exception for any x |
1069 | // |
1070 | // TODO: for now we "fixed" this by only accepting x/c cases |
1071 | // during expression building, so that the conjunction |
1072 | // rules applies (viz. x/c = x*(1/c) as far as lattice |
1073 | // construction is concerned). |
1074 | { |
1075 | const ExprId e0 = expr.children.e0; |
1076 | const ExprId e1 = expr.children.e1; |
1077 | assert(!maybeZero(e1)); |
1078 | return conjSet(e, buildLattices(e0, i), buildLattices(e1, i)); |
1079 | } |
1080 | case TensorExp::Kind::kAddF: |
1081 | case TensorExp::Kind::kAddC: |
1082 | case TensorExp::Kind::kAddI: |
1083 | case TensorExp::Kind::kSubF: |
1084 | case TensorExp::Kind::kSubC: |
1085 | case TensorExp::Kind::kSubI: |
1086 | case TensorExp::Kind::kOrI: |
1087 | case TensorExp::Kind::kXorI: |
1088 | // An additive operation needs to be performed |
1089 | // for the disjunction of sparse iteration spaces. |
1090 | // |
1091 | // x+y|!y | y | x-y|!y | y | |
1092 | // ---+---+---+ ---+---+---+ |
1093 | // !x | 0 | y | !x | 0 |-y | |
1094 | // x | x |x+y| x | x |x-y| |
1095 | { |
1096 | const ExprId e0 = expr.children.e0; |
1097 | const ExprId e1 = expr.children.e1; |
1098 | return disjSet(e, buildLattices(e0, i), buildLattices(e1, i)); |
1099 | } |
1100 | case TensorExp::Kind::kCmpF: |
1101 | case TensorExp::Kind::kCmpI: |
1102 | // A comparison operation needs to be performed |
1103 | // for the disjunction of sparse iteration spaces. |
1104 | // |
1105 | // x < y | !y | y | |
1106 | // -------+-------+-------+ |
1107 | // !x | 0 | 0 < y | |
1108 | // x | x < 0 | x < y | |
1109 | { |
1110 | const ExprId e0 = expr.children.e0; |
1111 | const ExprId e1 = expr.children.e1; |
1112 | return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i)); |
1113 | } |
1114 | case TensorExp::Kind::kShrS: |
1115 | case TensorExp::Kind::kShrU: |
1116 | case TensorExp::Kind::kShlI: |
1117 | // A shift operation by an invariant amount (viz. tensor expressions |
1118 | // can only occur at the left-hand-side of the operator) can be handled |
1119 | // with the conjunction rule. |
1120 | { |
1121 | const ExprId e0 = expr.children.e0; |
1122 | const ExprId e1 = expr.children.e1; |
1123 | assert(isInvariant(e1)); |
1124 | return conjSet(e, buildLattices(e0, i), buildLattices(e1, i)); |
1125 | } |
1126 | case TensorExp::Kind::kBinary: |
1127 | // A custom binary operation. |
1128 | // |
1129 | // x op y| !y | y | |
1130 | // ------+---------+--------------+ |
1131 | // !x | empty | right(y) | |
1132 | // x | left(x) | overlap(x,y) | |
1133 | { |
1134 | const ExprId e0 = expr.children.e0; |
1135 | const ExprId e1 = expr.children.e1; |
1136 | BinaryOp binop = cast<BinaryOp>(expr.op); |
1137 | const LatSetId child0 = buildLattices(e0, i); |
1138 | const LatSetId child1 = buildLattices(e1, i); |
1139 | Region &leftRegion = binop.getLeftRegion(); |
1140 | Region &rightRegion = binop.getRightRegion(); |
1141 | // Left Region. |
1142 | Operation *leftYield = nullptr; |
1143 | if (!leftRegion.empty()) { |
1144 | Block &leftBlock = leftRegion.front(); |
1145 | leftYield = leftBlock.getTerminator(); |
1146 | } |
1147 | // Right Region. |
1148 | Operation *rightYield = nullptr; |
1149 | if (!rightRegion.empty()) { |
1150 | Block &rightBlock = rightRegion.front(); |
1151 | rightYield = rightBlock.getTerminator(); |
1152 | } |
1153 | bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty(); |
1154 | bool includeRight = binop.getRightIdentity() || !rightRegion.empty(); |
1155 | return combiSet(e, child0, child1, binop, includeLeft, |
1156 | TensorExp::Kind::kBinaryBranch, leftYield, includeRight, |
1157 | TensorExp::Kind::kBinaryBranch, rightYield); |
1158 | } |
1159 | case TensorExp::Kind::kReduce: |
1160 | // A custom reduce operation. |
1161 | { |
1162 | const ExprId e0 = expr.children.e0; |
1163 | const ExprId e1 = expr.children.e1; |
1164 | Operation *const op = expr.op; |
1165 | return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op); |
1166 | } |
1167 | case TensorExp::Kind::kDenseOp: { |
1168 | // It does not really matter whether we use a conjunctive or disjunctive set |
1169 | // here: since all the operands of kDenseOp must be dense, a disjunctive set |
1170 | // would be optimized into a conjunctive set eventually. |
1171 | if (expr.children.e1 == detail::kInvalidId) { |
1172 | const ExprId e0 = expr.children.e0; |
1173 | Operation *const op = expr.op; |
1174 | return mapSet(kind, buildLattices(e0, i), Value(), op); |
1175 | } |
1176 | |
1177 | const ExprId e0 = expr.children.e0; |
1178 | const ExprId e1 = expr.children.e1; |
1179 | Operation *const op = expr.op; |
1180 | return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op); |
1181 | } |
1182 | } |
1183 | llvm_unreachable("unexpected expression kind" ); |
1184 | } |
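| |
| // For example (sketch): building the lattices of A(i) + B(i) * C(i) for loop |
| // i first forms the conjunction {B /\ C} for the product and then the |
| // disjunction {A /\ B /\ C, A, B /\ C} for the sum, which optimizeSet() may |
| // later prune depending on the actual level types. |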
1185 | |
1186 | std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { |
1187 | // Build the linalg semantics backward from yield. |
1188 | Operation *yield = op.getRegion().front().getTerminator(); |
1189 | assert(isa<linalg::YieldOp>(yield)); |
1190 | return buildTensorExp(op, yield->getOperand(0)).first; |
1191 | } |
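| |
| // For example (sketch): for a linalg.generic whose body computes |
| //   %0 = arith.mulf %a, %b : f64 |
| //   linalg.yield %0 : f64 |
| // the traversal starts at the yielded %0 and recursively builds the tensor |
| // expression kMulF(kTensor(a), kTensor(b)). |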
1192 | |
1193 | /// Only returns false if we are certain this is a nonzero. |
1194 | bool Merger::maybeZero(ExprId e) const { |
1195 | const auto &expr = exp(e); |
1196 | if (expr.kind == TensorExp::Kind::kInvariant) { |
1197 | if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) { |
1198 | ArrayAttr arrayAttr = c.getValue(); |
1199 | return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
1200 | cast<FloatAttr>(arrayAttr[1]).getValue().isZero(); |
1201 | } |
1202 | if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>()) |
1203 | return c.value() == 0; |
1204 | if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>()) |
1205 | return c.value().isZero(); |
1206 | } |
1207 | return true; |
1208 | } |
1209 | |
1210 | Type Merger::inferType(ExprId e, Value src) const { |
1211 | // Obtain the destination type from the cast node. |
1212 | Type dtp = exp(e).val.getType(); |
1213 | // Inspect source type. For vector types, apply the same |
1214 | // vectorization to the destination type. |
1215 | if (auto vtp = dyn_cast<VectorType>(src.getType())) |
1216 | return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims()); |
1217 | return dtp; |
1218 | } |
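| |
| // For example (sketch): if the cast node produces an i64 value and the source |
| // operand is of type vector<4xi32>, the inferred destination type becomes |
| // vector<4xi64>, i.e. the vector shape is preserved and only the element type |
| // changes. |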
1219 | |
1220 | /// Ensures that the sparsifier can generate code for the expression. |
1221 | static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { |
1222 | // Arguments are always admissible. |
1223 | if (isa<BlockArgument>(Val: v)) |
1224 | return true; |
1225 | // Accept index anywhere. |
1226 | Operation *def = v.getDefiningOp(); |
1227 | if (isa<linalg::IndexOp>(def)) |
1228 | return true; |
1229 | // Operation defined outside branch. |
1230 | if (def->getBlock() != block) |
1231 | return def->getBlock() != op->getBlock(); // invariant? |
1232 | // Operation defined within branch. Anything is accepted, |
1233 | // as long as all subexpressions are admissible. |
1234 | for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) |
1235 | if (!isAdmissibleBranchExp(op, block, def->getOperand(i))) |
1236 | return false; |
1237 | return true; |
1238 | } |
1239 | |
1240 | /// Ensures that the sparsifier can generate code for branch. |
1241 | static bool isAdmissibleBranch(Operation *op, Region ®ion) { |
1242 | if (region.empty()) |
1243 | return true; |
1244 | // Build the semi-ring branch semantics backward from yield. |
1245 | Operation *yield = region.front().getTerminator(); |
1246 | assert(isa<YieldOp>(yield)); |
1247 | return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0)); |
1248 | } |
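| |
| // For example (sketch): a sparse_tensor.unary present-region that yields the |
| // sum of its region argument and a constant defined outside the linalg op is |
| // admissible, whereas a branch that uses a value computed inside the linalg |
| // body (but outside the branch itself) is rejected as non-invariant. |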
1249 | |
1250 | std::pair<std::optional<ExprId>, bool> |
1251 | Merger::buildTensorExp(linalg::GenericOp op, Value v) { |
1252 | // Recursion leaves. |
1253 | if (auto arg = dyn_cast<BlockArgument>(Val&: v)) { |
1254 | const TensorId tid = makeTensorId(t: arg.getArgNumber()); |
1255 | // Any argument of the generic op that is not marked as a scalar |
1256 | // argument is considered a tensor, indexed by the implicit loop |
1257 | // bounds. This includes rank-0 tensor arguments. |
1258 | if (arg.getOwner()->getParentOp() == op) { |
1259 | OpOperand &t = op->getOpOperand(tid); |
1260 | bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr; |
1261 | if (!op.isScalar(&t)) |
1262 | return {addTensorExp(tid), hasSpDep}; |
1263 | v = t.get(); // get scalar value |
1264 | } |
1265 | // Any other argument (marked as scalar argument for the generic op |
1266 | // or belonging to an enveloping op) is considered invariant. |
1267 | return {addInvariantExp(v), /*hasSpDep=*/false}; |
1268 | } |
1269 | // Something defined outside is invariant. |
1270 | Operation *def = v.getDefiningOp(); |
1271 | if (def->getBlock() != &op.getRegion().front()) |
1272 | return {addInvariantExp(v), /*hasSpDep=*/false}; |
1273 | // Construct index operations. |
1274 | if (def->getNumOperands() == 0) { |
1275 | if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) |
1276 | return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false}; |
1277 | } |
1278 | |
1279 | // Construct unary operations if subexpression can be built. |
1280 | if (def->getNumOperands() == 1) { |
1281 | const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0)); |
1282 | if (x.has_value()) { |
1283 | const ExprId e = *x; |
1284 | if (isa<math::AbsFOp>(def)) |
1285 | return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep}; |
1286 | if (isa<complex::AbsOp>(def)) |
1287 | return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep}; |
1288 | if (isa<math::AbsIOp>(def)) |
1289 | return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep}; |
1290 | if (isa<math::CeilOp>(def)) |
1291 | return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep}; |
1292 | if (isa<math::FloorOp>(def)) |
1293 | return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep}; |
1294 | if (isa<math::SqrtOp>(def)) |
1295 | return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep}; |
1296 | if (isa<complex::SqrtOp>(def)) |
1297 | return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep}; |
1298 | if (isa<math::ExpM1Op>(def)) |
1299 | return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep}; |
1300 | if (isa<complex::Expm1Op>(def)) |
1301 | return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep}; |
1302 | if (isa<math::Log1pOp>(def)) |
1303 | return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep}; |
1304 | if (isa<complex::Log1pOp>(def)) |
1305 | return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep}; |
1306 | if (isa<math::SinOp>(def)) |
1307 | return {addExp(TensorExp::Kind::kSinF, e), hasSpDep}; |
1308 | if (isa<complex::SinOp>(def)) |
1309 | return {addExp(TensorExp::Kind::kSinC, e), hasSpDep}; |
1310 | if (isa<math::TanhOp>(def)) |
1311 | return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep}; |
1312 | if (isa<complex::TanhOp>(def)) |
1313 | return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep}; |
1314 | if (isa<arith::NegFOp>(def)) |
1315 | return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std |
1316 | if (isa<complex::NegOp>(def)) |
1317 | return {addExp(TensorExp::Kind::kNegC, e), hasSpDep}; |
1318 | if (isa<arith::TruncFOp>(def)) |
1319 | return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep}; |
1320 | if (isa<arith::ExtFOp>(def)) |
1321 | return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep}; |
1322 | if (isa<arith::FPToSIOp>(def)) |
1323 | return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep}; |
1324 | if (isa<arith::FPToUIOp>(def)) |
1325 | return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep}; |
1326 | if (isa<arith::SIToFPOp>(def)) |
1327 | return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep}; |
1328 | if (isa<arith::UIToFPOp>(def)) |
1329 | return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep}; |
1330 | if (isa<arith::ExtSIOp>(def)) |
1331 | return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep}; |
1332 | if (isa<arith::ExtUIOp>(def)) |
1333 | return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep}; |
1334 | if (isa<arith::IndexCastOp>(def)) |
1335 | return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep}; |
1336 | if (isa<arith::TruncIOp>(def)) |
1337 | return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep}; |
1338 | if (isa<complex::ImOp>(def)) |
1339 | return {addExp(TensorExp::Kind::kCIm, e), hasSpDep}; |
1340 | if (isa<complex::ReOp>(def)) |
1341 | return {addExp(TensorExp::Kind::kCRe, e), hasSpDep}; |
1342 | if (isa<arith::BitcastOp>(def)) |
1343 | return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep}; |
1344 | if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) { |
1345 | if (isAdmissibleBranch(unop, unop.getPresentRegion()) && |
1346 | isAdmissibleBranch(unop, unop.getAbsentRegion())) |
1347 | return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep}; |
1348 | } |
1349 | if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) { |
1350 | if (isAdmissibleBranch(selop, selop.getRegion())) |
1351 | return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep}; |
1352 | } |
1353 | } |
1354 | } |
1355 | // Construct binary operations if subexpressions can be built. |
1356 | // See buildLattices() for an explanation of rejecting certain |
1357 | // division and shift operations. |
1358 | if (def->getNumOperands() == 2) { |
1359 | const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0)); |
1360 | const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1)); |
1361 | bool hasSpDep = xDepSp || yDepSp; |
1362 | if (x.has_value() && y.has_value()) { |
1363 | const ExprId e0 = *x; |
1364 | const ExprId e1 = *y; |
1365 | if (isa<arith::MulFOp>(def)) |
1366 | return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep}; |
1367 | if (isa<complex::MulOp>(def)) |
1368 | return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep}; |
1369 | if (isa<arith::MulIOp>(def)) |
1370 | return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep}; |
1371 | if (isa<arith::DivFOp>(def) && !maybeZero(e1)) |
1372 | return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep}; |
1373 | if (isa<complex::DivOp>(def) && !maybeZero(e1)) |
1374 | return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep}; |
1375 | if (isa<arith::DivSIOp>(def) && !maybeZero(e1)) |
1376 | return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep}; |
1377 | if (isa<arith::DivUIOp>(def) && !maybeZero(e1)) |
1378 | return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep}; |
1379 | if (isa<arith::AddFOp>(def)) |
1380 | return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep}; |
1381 | if (isa<complex::AddOp>(def)) |
1382 | return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep}; |
1383 | if (isa<arith::AddIOp>(def)) |
1384 | return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep}; |
1385 | if (isa<arith::SubFOp>(def)) |
1386 | return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep}; |
1387 | if (isa<complex::SubOp>(def)) |
1388 | return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep}; |
1389 | if (isa<arith::SubIOp>(def)) |
1390 | return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep}; |
1391 | if (isa<arith::AndIOp>(def)) |
1392 | return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep}; |
1393 | if (isa<arith::OrIOp>(def)) |
1394 | return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep}; |
1395 | if (isa<arith::XOrIOp>(def)) |
1396 | return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep}; |
1397 | if (isa<arith::ShRSIOp>(def) && isInvariant(e1)) |
1398 | return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep}; |
1399 | if (isa<arith::ShRUIOp>(def) && isInvariant(e1)) |
1400 | return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep}; |
1401 | if (isa<arith::ShLIOp>(def) && isInvariant(e1)) |
1402 | return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep}; |
1403 | if (auto ci = dyn_cast<arith::CmpIOp>(def)) { |
1404 | if (ci.getPredicate() == arith::CmpIPredicate::eq || |
1405 | ci.getPredicate() == arith::CmpIPredicate::sle || |
1406 | ci.getPredicate() == arith::CmpIPredicate::sge || |
1407 | ci.getPredicate() == arith::CmpIPredicate::ule || |
1408 | ci.getPredicate() == arith::CmpIPredicate::uge) { |
1409 | // We cannot sparsify comparisons that hold for equal operands, since |
1410 | // 0 <= 0 yields true and would densify the result. |
1411 | return {std::nullopt, false}; |
1412 | } |
1413 | |
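        // Any other predicate is accepted; record it as an attribute on the
        // comparison expression.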
1414 | auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr, |
1415 | ci.getPredicateAttr()); |
1416 | return {e, hasSpDep}; |
1417 | } |
1418 | if (auto cf = dyn_cast<arith::CmpFOp>(def)) { |
        if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
            cf.getPredicate() == arith::CmpFPredicate::OGE ||
            cf.getPredicate() == arith::CmpFPredicate::OLE ||
            cf.getPredicate() == arith::CmpFPredicate::ONE ||
            cf.getPredicate() == arith::CmpFPredicate::UEQ ||
            cf.getPredicate() == arith::CmpFPredicate::UGE ||
            cf.getPredicate() == arith::CmpFPredicate::ULE ||
            cf.getPredicate() == arith::CmpFPredicate::ORD ||
            cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify comparisons that include equality, since
          // 0 <= 0 yields true and would thus densify the result.
1430 | return {std::nullopt, false}; |
1431 | } |
1432 | auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr, |
1433 | cf.getPredicateAttr()); |
1434 | return {e, hasSpDep}; |
1435 | } |
1436 | if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) { |
1437 | if (isAdmissibleBranch(binop, binop.getOverlapRegion()) && |
1438 | (binop.getLeftIdentity() || |
1439 | isAdmissibleBranch(binop, binop.getLeftRegion())) && |
1440 | (binop.getRightIdentity() || |
1441 | isAdmissibleBranch(binop, binop.getRightRegion()))) |
          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
1443 | } |
1444 | } |
1445 | } |
1446 | // Construct ternary operations if subexpressions can be built. |
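  // (Currently only sparse_tensor::ReduceOp is recognized here.)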
1447 | if (def->getNumOperands() == 3) { |
1448 | const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0)); |
1449 | const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1)); |
1450 | const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2)); |
1451 | bool hasSpDep = xDepSp || yDepSp || zDepSp; |
1452 | if (x.has_value() && y.has_value() && z.has_value()) { |
1453 | const ExprId e0 = *x; |
1454 | const ExprId e1 = *y; |
1455 | if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) { |
1456 | if (isAdmissibleBranch(redop, redop.getRegion())) |
          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1458 | } |
1459 | } |
1460 | } |
1461 | |
1462 | // If we reach here, we are dealing with an operation that is not currently |
1463 | // sparsifiable. We can still generate code for it if all its operands only |
1464 | // have dense dependencies (i.e., all the values are loaded from dense |
1465 | // tensors). |
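  // For example, an arith.maxsi whose operands are both loaded from dense
  // tensors is not recognized above, but can still be admitted as a kDenseOp
  // below.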
  if (def->getNumResults() != 1) // Only handle single-result operations.
1467 | return {std::nullopt, false}; |
1468 | |
1469 | SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp; |
  // Build all the subexpressions.
1471 | for (Value operand : def->getOperands()) |
1472 | subExp.push_back(buildTensorExp(op, operand)); |
1473 | |
  if (llvm::all_of(subExp,
                   [](auto e) { return e.first.has_value() && !e.second; })) {
    // All the subexpressions can be built and have *no* sparse dependencies.
1477 | if (subExp.size() == 2) { |
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      *subExp[1].first, def);
1480 | return {e, false}; |
1481 | } |
1482 | if (subExp.size() == 1) { |
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      detail::kInvalidId, def);
1485 | return {e, false}; |
1486 | } |
1487 | } |
1488 | // Cannot build. |
1489 | return {std::nullopt, false}; |
1490 | } |
1491 | |
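/// Clones the given custom-operator region (which must terminate in a
/// sparse_tensor::YieldOp), inlines the cloned block at the current insertion
/// point with `vals` as its block arguments, and returns the yielded value.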
1492 | static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, |
1493 | ValueRange vals) { |
  // Make a clone of the given region.
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1498 | Block &clonedBlock = tmpRegion.front(); |
1499 | YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator()); |
1500 | // Merge cloned block and return yield value. |
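  // The constant op below merely provides an insertion point for inlining
  // the cloned block; it is erased again after the yielded value has been
  // extracted.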
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getSingleResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
1506 | return val; |
1507 | } |
1508 | |
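/// Builds the value for the "present" branch of a sparse_tensor::UnaryOp by
/// inlining its present region on `v0`. An empty input value or an empty
/// present region propagates as an empty Value(), i.e. missing data in the
/// output.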
1509 | static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, |
1510 | Operation *op, Value v0) { |
1511 | if (!v0) |
1512 | // Empty input value must be propagated. |
1513 | return Value(); |
1514 | UnaryOp unop = cast<UnaryOp>(op); |
1515 | Region &presentRegion = unop.getPresentRegion(); |
1516 | if (presentRegion.empty()) |
1517 | // Uninitialized Value() will be interpreted as missing data in the |
1518 | // output. |
1519 | return Value(); |
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
1521 | } |
1522 | |
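/// Builds the value for the "overlap" branch of a sparse_tensor::BinaryOp by
/// inlining its overlap region on `v0` and `v1`. Empty input values or an
/// empty overlap region propagate as an empty Value(), i.e. missing data in
/// the output.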
1523 | static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, |
1524 | Operation *op, Value v0, Value v1) { |
1525 | if (!v0 || !v1) |
1526 | // Empty input values must be propagated. |
1527 | return Value(); |
1528 | BinaryOp binop = cast<BinaryOp>(op); |
1529 | Region &overlapRegion = binop.getOverlapRegion(); |
1530 | if (overlapRegion.empty()) |
1531 | // Uninitialized Value() will be interpreted as missing data in the |
1532 | // output. |
1533 | return Value(); |
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1535 | } |
1536 | |
1537 | Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, |
1538 | Value v1) const { |
1539 | const auto &expr = exp(e); |
1540 | switch (expr.kind) { |
1541 | // Leaf. |
1542 | case TensorExp::Kind::kTensor: |
1543 | case TensorExp::Kind::kInvariant: |
1544 | case TensorExp::Kind::kLoopVar: |
1545 | case TensorExp::Kind::kSynZero: |
1546 | llvm_unreachable("unexpected non-op" ); |
1547 | // Unary operations. |
1548 | case TensorExp::Kind::kAbsF: |
1549 | return rewriter.create<math::AbsFOp>(loc, v0); |
1550 | case TensorExp::Kind::kAbsC: { |
1551 | auto type = cast<ComplexType>(v0.getType()); |
1552 | auto eltType = cast<FloatType>(type.getElementType()); |
1553 | return rewriter.create<complex::AbsOp>(loc, eltType, v0); |
1554 | } |
1555 | case TensorExp::Kind::kAbsI: |
1556 | return rewriter.create<math::AbsIOp>(loc, v0); |
1557 | case TensorExp::Kind::kCeilF: |
1558 | return rewriter.create<math::CeilOp>(loc, v0); |
1559 | case TensorExp::Kind::kFloorF: |
1560 | return rewriter.create<math::FloorOp>(loc, v0); |
1561 | case TensorExp::Kind::kSqrtF: |
1562 | return rewriter.create<math::SqrtOp>(loc, v0); |
1563 | case TensorExp::Kind::kSqrtC: |
1564 | return rewriter.create<complex::SqrtOp>(loc, v0); |
1565 | case TensorExp::Kind::kExpm1F: |
1566 | return rewriter.create<math::ExpM1Op>(loc, v0); |
1567 | case TensorExp::Kind::kExpm1C: |
1568 | return rewriter.create<complex::Expm1Op>(loc, v0); |
1569 | case TensorExp::Kind::kLog1pF: |
1570 | return rewriter.create<math::Log1pOp>(loc, v0); |
1571 | case TensorExp::Kind::kLog1pC: |
1572 | return rewriter.create<complex::Log1pOp>(loc, v0); |
1573 | case TensorExp::Kind::kSinF: |
1574 | return rewriter.create<math::SinOp>(loc, v0); |
1575 | case TensorExp::Kind::kSinC: |
1576 | return rewriter.create<complex::SinOp>(loc, v0); |
1577 | case TensorExp::Kind::kTanhF: |
1578 | return rewriter.create<math::TanhOp>(loc, v0); |
1579 | case TensorExp::Kind::kTanhC: |
1580 | return rewriter.create<complex::TanhOp>(loc, v0); |
1581 | case TensorExp::Kind::kNegF: |
1582 | return rewriter.create<arith::NegFOp>(loc, v0); |
1583 | case TensorExp::Kind::kNegC: |
1584 | return rewriter.create<complex::NegOp>(loc, v0); |
1585 | case TensorExp::Kind::kNegI: // no negi in std |
1586 | return rewriter.create<arith::SubIOp>( |
1587 | loc, |
1588 | rewriter.create<arith::ConstantOp>(loc, v0.getType(), |
1589 | rewriter.getZeroAttr(v0.getType())), |
1590 | v0); |
1591 | case TensorExp::Kind::kTruncF: |
1592 | return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0); |
1593 | case TensorExp::Kind::kExtF: |
1594 | return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0); |
1595 | case TensorExp::Kind::kCastFS: |
1596 | return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0); |
1597 | case TensorExp::Kind::kCastFU: |
1598 | return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0); |
1599 | case TensorExp::Kind::kCastSF: |
1600 | return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0); |
1601 | case TensorExp::Kind::kCastUF: |
1602 | return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0); |
1603 | case TensorExp::Kind::kCastS: |
1604 | return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0); |
1605 | case TensorExp::Kind::kCastU: |
1606 | return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0); |
1607 | case TensorExp::Kind::kCastIdx: |
1608 | return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0); |
1609 | case TensorExp::Kind::kTruncI: |
1610 | return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0); |
1611 | case TensorExp::Kind::kCIm: { |
1612 | auto type = cast<ComplexType>(v0.getType()); |
1613 | auto eltType = cast<FloatType>(type.getElementType()); |
1614 | return rewriter.create<complex::ImOp>(loc, eltType, v0); |
1615 | } |
1616 | case TensorExp::Kind::kCRe: { |
1617 | auto type = cast<ComplexType>(v0.getType()); |
1618 | auto eltType = cast<FloatType>(type.getElementType()); |
1619 | return rewriter.create<complex::ReOp>(loc, eltType, v0); |
1620 | } |
1621 | case TensorExp::Kind::kBitCast: |
1622 | return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0); |
1623 | // Binary operations. |
1624 | case TensorExp::Kind::kMulF: |
1625 | return rewriter.create<arith::MulFOp>(loc, v0, v1); |
1626 | case TensorExp::Kind::kMulC: |
1627 | return rewriter.create<complex::MulOp>(loc, v0, v1); |
1628 | case TensorExp::Kind::kMulI: |
1629 | return rewriter.create<arith::MulIOp>(loc, v0, v1); |
1630 | case TensorExp::Kind::kDivF: |
1631 | return rewriter.create<arith::DivFOp>(loc, v0, v1); |
1632 | case TensorExp::Kind::kDivC: |
1633 | return rewriter.create<complex::DivOp>(loc, v0, v1); |
1634 | case TensorExp::Kind::kDivS: |
1635 | return rewriter.create<arith::DivSIOp>(loc, v0, v1); |
1636 | case TensorExp::Kind::kDivU: |
1637 | return rewriter.create<arith::DivUIOp>(loc, v0, v1); |
1638 | case TensorExp::Kind::kAddF: |
1639 | return rewriter.create<arith::AddFOp>(loc, v0, v1); |
1640 | case TensorExp::Kind::kAddC: |
1641 | return rewriter.create<complex::AddOp>(loc, v0, v1); |
1642 | case TensorExp::Kind::kAddI: |
1643 | return rewriter.create<arith::AddIOp>(loc, v0, v1); |
1644 | case TensorExp::Kind::kSubF: |
1645 | return rewriter.create<arith::SubFOp>(loc, v0, v1); |
1646 | case TensorExp::Kind::kSubC: |
1647 | return rewriter.create<complex::SubOp>(loc, v0, v1); |
1648 | case TensorExp::Kind::kSubI: |
1649 | return rewriter.create<arith::SubIOp>(loc, v0, v1); |
1650 | case TensorExp::Kind::kAndI: |
1651 | return rewriter.create<arith::AndIOp>(loc, v0, v1); |
1652 | case TensorExp::Kind::kOrI: |
1653 | return rewriter.create<arith::OrIOp>(loc, v0, v1); |
1654 | case TensorExp::Kind::kXorI: |
1655 | return rewriter.create<arith::XOrIOp>(loc, v0, v1); |
1656 | case TensorExp::Kind::kShrS: |
1657 | return rewriter.create<arith::ShRSIOp>(loc, v0, v1); |
1658 | case TensorExp::Kind::kShrU: |
1659 | return rewriter.create<arith::ShRUIOp>(loc, v0, v1); |
1660 | case TensorExp::Kind::kShlI: |
1661 | return rewriter.create<arith::ShLIOp>(loc, v0, v1); |
1662 | case TensorExp::Kind::kCmpI: { |
1663 | auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr); |
1664 | return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1); |
1665 | } |
1666 | case TensorExp::Kind::kCmpF: { |
1667 | auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr); |
1668 | return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1); |
1669 | } |
1670 | case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic. |
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
1673 | case TensorExp::Kind::kUnary: |
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
1675 | case TensorExp::Kind::kSelect: |
1676 | return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(), |
1677 | {v0}); |
1678 | case TensorExp::Kind::kBinary: |
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1680 | case TensorExp::Kind::kReduce: { |
1681 | ReduceOp redOp = cast<ReduceOp>(expr.op); |
1682 | return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1}); |
1683 | } |
1684 | case TensorExp::Kind::kDenseOp: { |
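    // Clone the original dense operation at the current insertion point,
    // remapping its operands to the values computed for the subexpressions.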
1685 | Operation *actualOp = expr.op; |
1686 | IRMapping mapping; |
    mapping.map(actualOp->getOperand(0), v0);
    if (actualOp->getNumOperands() == 2)
      mapping.map(actualOp->getOperand(1), v1);
    return rewriter.clone(*actualOp, mapping)->getResult(0);
1691 | } |
1692 | } |
1693 | llvm_unreachable("unexpected expression kind in build" ); |
1694 | } |
1695 | |
1696 | } // namespace sparse_tensor |
1697 | } // namespace mlir |
1698 | |