1 | //===- Merger.cpp - Implementation of iteration lattices ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/Complex/IR/Complex.h" |
12 | #include "mlir/Dialect/Math/IR/Math.h" |
13 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
14 | |
15 | #include "mlir/IR/Operation.h" |
16 | #include "llvm/Support/Debug.h" |
17 | #include <optional> |
18 | |
19 | namespace mlir { |
20 | namespace sparse_tensor { |
21 | |
22 | enum class ExpArity { |
23 | kNullary, |
24 | kUnary, |
25 | kBinary, |
26 | }; |
27 | |
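/// Returns the arity (number of operands) of the given tensor expression
/// kind.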
28 | static ExpArity getExpArity(TensorExp::Kind k) { |
29 | switch (k) { |
30 | // Leaf. |
31 | case TensorExp::Kind::kTensor: |
32 | case TensorExp::Kind::kInvariant: |
33 | case TensorExp::Kind::kLoopVar: |
34 | case TensorExp::Kind::kSynZero: |
35 | return ExpArity::kNullary; |
36 | case TensorExp::Kind::kAbsF: |
37 | case TensorExp::Kind::kAbsC: |
38 | case TensorExp::Kind::kAbsI: |
39 | case TensorExp::Kind::kCeilF: |
40 | case TensorExp::Kind::kFloorF: |
41 | case TensorExp::Kind::kSqrtF: |
42 | case TensorExp::Kind::kSqrtC: |
43 | case TensorExp::Kind::kExpm1F: |
44 | case TensorExp::Kind::kExpm1C: |
45 | case TensorExp::Kind::kLog1pF: |
46 | case TensorExp::Kind::kLog1pC: |
47 | case TensorExp::Kind::kRelu: |
48 | case TensorExp::Kind::kSinF: |
49 | case TensorExp::Kind::kSinC: |
50 | case TensorExp::Kind::kTanhF: |
51 | case TensorExp::Kind::kTanhC: |
52 | case TensorExp::Kind::kTruncF: |
53 | case TensorExp::Kind::kExtF: |
54 | case TensorExp::Kind::kCastFS: |
55 | case TensorExp::Kind::kCastFU: |
56 | case TensorExp::Kind::kCastSF: |
57 | case TensorExp::Kind::kCastUF: |
58 | case TensorExp::Kind::kCastS: |
59 | case TensorExp::Kind::kCastU: |
60 | case TensorExp::Kind::kCastIdx: |
61 | case TensorExp::Kind::kTruncI: |
62 | case TensorExp::Kind::kCIm: |
63 | case TensorExp::Kind::kCRe: |
64 | case TensorExp::Kind::kBitCast: |
65 | case TensorExp::Kind::kBinaryBranch: |
66 | case TensorExp::Kind::kUnary: |
67 | case TensorExp::Kind::kSelect: |
68 | case TensorExp::Kind::kNegF: |
69 | case TensorExp::Kind::kNegC: |
70 | case TensorExp::Kind::kNegI: |
71 | return ExpArity::kUnary; |
72 | // Binary operations. |
73 | case TensorExp::Kind::kDivF: |
74 | case TensorExp::Kind::kDivC: |
75 | case TensorExp::Kind::kDivS: |
76 | case TensorExp::Kind::kDivU: |
77 | case TensorExp::Kind::kShrS: |
78 | case TensorExp::Kind::kShrU: |
79 | case TensorExp::Kind::kShlI: |
80 | case TensorExp::Kind::kMulF: |
81 | case TensorExp::Kind::kMulC: |
82 | case TensorExp::Kind::kMulI: |
83 | case TensorExp::Kind::kAndI: |
84 | case TensorExp::Kind::kAddF: |
85 | case TensorExp::Kind::kAddC: |
86 | case TensorExp::Kind::kAddI: |
87 | case TensorExp::Kind::kOrI: |
88 | case TensorExp::Kind::kXorI: |
89 | case TensorExp::Kind::kBinary: |
90 | case TensorExp::Kind::kReduce: |
91 | case TensorExp::Kind::kSubF: |
92 | case TensorExp::Kind::kSubC: |
93 | case TensorExp::Kind::kSubI: |
94 | case TensorExp::Kind::kCmpF: |
95 | case TensorExp::Kind::kCmpI: |
96 | case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands |
97 | return ExpArity::kBinary; |
98 | } |
99 | llvm_unreachable("unexpected kind" ); |
100 | } |
101 | |
102 | //===----------------------------------------------------------------------===// |
103 | // Constructors. |
104 | //===----------------------------------------------------------------------===// |
105 | |
106 | TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, |
107 | Operation *o, Attribute a) |
108 | : kind(k), val(v), op(o), attr(a) { |
109 | switch (kind) { |
110 | // Leaf. |
111 | case TensorExp::Kind::kTensor: |
112 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
113 | tensor = x; |
114 | return; |
115 | case TensorExp::Kind::kSynZero: |
116 | assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
117 | return; |
118 | case TensorExp::Kind::kInvariant: |
119 | assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o); |
120 | return; |
121 | case TensorExp::Kind::kLoopVar: |
122 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
123 | loop = x; |
124 | return; |
125 | // Unary operations. |
126 | case TensorExp::Kind::kAbsF: |
127 | case TensorExp::Kind::kAbsC: |
128 | case TensorExp::Kind::kAbsI: |
129 | case TensorExp::Kind::kCeilF: |
130 | case TensorExp::Kind::kFloorF: |
131 | case TensorExp::Kind::kSqrtF: |
132 | case TensorExp::Kind::kSqrtC: |
133 | case TensorExp::Kind::kExpm1F: |
134 | case TensorExp::Kind::kExpm1C: |
135 | case TensorExp::Kind::kLog1pF: |
136 | case TensorExp::Kind::kLog1pC: |
137 | case TensorExp::Kind::kRelu: |
138 | case TensorExp::Kind::kSinF: |
139 | case TensorExp::Kind::kSinC: |
140 | case TensorExp::Kind::kTanhF: |
141 | case TensorExp::Kind::kTanhC: |
142 | case TensorExp::Kind::kNegF: |
143 | case TensorExp::Kind::kNegC: |
144 | case TensorExp::Kind::kNegI: |
145 | case TensorExp::Kind::kCIm: |
146 | case TensorExp::Kind::kCRe: |
147 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o); |
148 | children.e0 = x; |
149 | children.e1 = y; |
150 | return; |
151 | case TensorExp::Kind::kTruncF: |
152 | case TensorExp::Kind::kExtF: |
153 | case TensorExp::Kind::kCastFS: |
154 | case TensorExp::Kind::kCastFU: |
155 | case TensorExp::Kind::kCastSF: |
156 | case TensorExp::Kind::kCastUF: |
157 | case TensorExp::Kind::kCastS: |
158 | case TensorExp::Kind::kCastU: |
159 | case TensorExp::Kind::kCastIdx: |
160 | case TensorExp::Kind::kTruncI: |
161 | case TensorExp::Kind::kBitCast: |
162 | assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o); |
163 | children.e0 = x; |
164 | children.e1 = y; |
165 | return; |
166 | case TensorExp::Kind::kBinaryBranch: |
167 | case TensorExp::Kind::kSelect: |
168 | assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o); |
169 | children.e0 = x; |
170 | children.e1 = y; |
171 | return; |
172 | case TensorExp::Kind::kUnary: |
173 | // No assertion on y can be made, as the branching paths involve both |
174 | // a unary (`mapSet`) and binary (`disjSet`) pathway. |
175 | assert(x != detail::kInvalidId && !v && o); |
176 | children.e0 = x; |
177 | children.e1 = y; |
178 | return; |
179 | // Binary operations. |
180 | case TensorExp::Kind::kMulF: |
181 | case TensorExp::Kind::kMulC: |
182 | case TensorExp::Kind::kMulI: |
183 | case TensorExp::Kind::kDivF: |
184 | case TensorExp::Kind::kDivC: |
185 | case TensorExp::Kind::kDivS: |
186 | case TensorExp::Kind::kDivU: |
187 | case TensorExp::Kind::kAddF: |
188 | case TensorExp::Kind::kAddC: |
189 | case TensorExp::Kind::kAddI: |
190 | case TensorExp::Kind::kSubF: |
191 | case TensorExp::Kind::kSubC: |
192 | case TensorExp::Kind::kSubI: |
193 | case TensorExp::Kind::kAndI: |
194 | case TensorExp::Kind::kOrI: |
195 | case TensorExp::Kind::kXorI: |
196 | case TensorExp::Kind::kShrS: |
197 | case TensorExp::Kind::kShrU: |
198 | case TensorExp::Kind::kShlI: |
199 | assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); |
200 | children.e0 = x; |
201 | children.e1 = y; |
202 | return; |
203 | case TensorExp::Kind::kCmpF: |
204 | case TensorExp::Kind::kCmpI: |
205 | assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o); |
206 | children.e0 = x; |
207 | children.e1 = y; |
208 | return; |
209 | case TensorExp::Kind::kBinary: |
210 | case TensorExp::Kind::kReduce: |
211 | assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o); |
212 | children.e0 = x; |
213 | children.e1 = y; |
214 | return; |
215 | case TensorExp::Kind::kDenseOp: |
216 | assert(x != detail::kInvalidId && !v && o); |
217 | children.e0 = x; |
218 | children.e1 = y; |
219 | return; |
220 | } |
221 | llvm_unreachable("unexpected kind" ); |
222 | } |
223 | |
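// Constructs the merger state. The last input/output tensor is the output
// tensor, and one extra synthetic tensor is appended to represent expressions
// (such as invariants and loop indices) that are not backed by a real tensor.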
224 | Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops, |
225 | unsigned maxLvlRank) |
226 | : outTensor(numInputOutputTensors - 1), |
227 | syntheticTensor(numInputOutputTensors), |
228 | numTensors(numInputOutputTensors + 1), numLoops(numLoops), |
229 | hasSparseOut(false), |
230 | lvlTypes(numTensors, |
231 | std::vector<LevelType>(numLoops, LevelFormat::Undef)), |
232 | loopToLvl(numTensors, |
233 | std::vector<std::optional<Level>>(numLoops, std::nullopt)), |
234 | lvlToLoop(numTensors, |
235 | std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)), |
236 | loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>( |
237 | numTensors, std::nullopt)), |
238 | levelToDependentLoop(numTensors, |
239 | std::vector<std::vector<LoopCoeffPair>>( |
240 | maxLvlRank, std::vector<LoopCoeffPair>())), |
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
242 | |
243 | //===----------------------------------------------------------------------===// |
244 | // Lattice methods. |
245 | //===----------------------------------------------------------------------===// |
246 | |
247 | ExprId Merger::addTensorExp(TensorId t) { |
248 | assert(isValidTensorId(t)); |
249 | const ExprId eNew(tensorExps.size()); |
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
252 | return eNew; |
253 | } |
254 | |
255 | ExprId Merger::addLoopVarExp(LoopId i) { |
256 | assert(isValidLoopId(i)); |
257 | const ExprId eNew(tensorExps.size()); |
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
260 | return eNew; |
261 | } |
262 | |
263 | ExprId Merger::addInvariantExp(Value v) { |
264 | const ExprId eNew(tensorExps.size()); |
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
267 | return eNew; |
268 | } |
269 | |
270 | ExprId Merger::addSynZeroExp() { |
271 | const ExprId eNew(tensorExps.size()); |
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
274 | return eNew; |
275 | } |
276 | |
277 | ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op, |
278 | Attribute attr) { |
279 | assert(k > TensorExp::Kind::kLoopVar); |
280 | const ExprId eNew(tensorExps.size()); |
281 | tensorExps.emplace_back(Args&: k, Args&: e0, Args&: e1, Args: Value(), Args&: op, Args&: attr); |
282 | return eNew; |
283 | } |
284 | |
285 | ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op, |
286 | Attribute attr) { |
287 | assert(k > TensorExp::Kind::kLoopVar); |
288 | const ExprId eNew(tensorExps.size()); |
289 | tensorExps.emplace_back(Args&: k, Args&: e, Args: detail::kInvalidId, Args&: v, Args&: op, Args&: attr); |
290 | return eNew; |
291 | } |
292 | |
293 | LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { |
294 | const LatPointId pNew(latPoints.size()); |
295 | const unsigned size = numLoops * numTensors; |
296 | const TensorLoopId b = makeTensorLoopId(t, i); |
  latPoints.emplace_back(size, e);
298 | latPoints[pNew].bits.set(b); |
299 | return pNew; |
300 | } |
301 | |
302 | LatPointId Merger::addLat(const BitVector &bits, ExprId e) { |
303 | assert(bits.size() == numLoops * numTensors); |
304 | const LatPointId pNew(latPoints.size()); |
  latPoints.emplace_back(bits, e);
306 | return pNew; |
307 | } |
308 | |
309 | LatSetId Merger::addSet() { |
310 | const LatSetId sNew(latSets.size()); |
311 | latSets.emplace_back(); |
312 | return sNew; |
313 | } |
314 | |
315 | LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1, |
316 | Operation *op) { |
317 | TensorExp::Kind kind = exp(e).kind; |
318 | Attribute attr = exp(e).attr; |
319 | const LatPointId pNew(latPoints.size()); |
320 | const auto &point0 = lat(p: p0); |
321 | const auto &point1 = lat(p: p1); |
322 | BitVector bits(point0.bits); |
323 | bits |= point1.bits; |
324 | const ExprId ne = addExp(k: kind, e0: point0.exp, e1: point1.exp, op, attr); |
325 | latPoints.emplace_back(Args&: bits, Args: ne); |
326 | return pNew; |
327 | } |
328 | |
329 | LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { |
330 | const LatSetId sNew = addSet(); |
331 | auto &setNew = latSets[sNew]; |
332 | for (const LatPointId p0 : set(s0)) |
333 | for (const LatPointId p1 : set(s1)) |
      setNew.push_back(conjLat(e, p0, p1, op));
335 | return sNew; |
336 | } |
337 | |
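// Note: for singleton input sets {x} and {y}, a disjunction such as x + y
// yields the lattice points { x+y, x, y }: the conjunction first, followed
// by the points where only one of the operands is present.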
338 | LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { |
339 | const LatSetId sNew = conjSet(e, s0, s1, op); |
340 | TensorExp::Kind kind = exp(e).kind; |
341 | // Followed by all in s0. |
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
353 | return sNew; |
354 | } |
355 | |
356 | LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) { |
357 | assert(exp(e).kind == TensorExp::Kind::kCmpI || |
358 | exp(e).kind == TensorExp::Kind::kCmpF); |
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360 | |
361 | ExprId e0 = exp(e).children.e0; |
362 | ExprId e1 = exp(e).children.e1; |
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
365 | // lhs and rhs can't be synthetic zero at the same time. |
366 | assert(exp(e0).kind != exp(e1).kind); |
367 | // If one of the operands has already been assigned to zero (the |
368 | // element is absent in the corresponding operand), then we do not |
369 | // need to build disjunctive set for it. |
370 | return sNew; |
371 | } |
372 | |
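  // Otherwise, complete the disjunction with the mixed cases: "x cmp 0" for
  // the points of s0 (rhs absent) and "0 cmp y" for the points of s1 (lhs
  // absent).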
  auto lhsSet = mapBinWithSynZeroSet(e, s0, /*lhsZero=*/false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, /*lhsZero=*/true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
377 | return sNew; |
378 | } |
379 | |
380 | LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig, |
381 | bool includeLeft, TensorExp::Kind ltrans, |
382 | Operation *opleft, bool includeRight, |
383 | TensorExp::Kind rtrans, Operation *opright) { |
384 | Attribute a = exp(e).attr; |
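  // Start with the overlap (conjunction) of both input lattice sets.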
  const LatSetId sNew = conjSet(e, s0, s1, orig);
386 | // Left Region. |
387 | if (includeLeft) { |
388 | if (opleft) |
      s0 = mapSet(ltrans, s0, Value(), opleft, a);
    latSets[sNew].append(latSets[s0]);
391 | } |
392 | // Right Region. |
393 | if (includeRight) { |
394 | if (opright) |
      s1 = mapSet(rtrans, s1, Value(), opright, a);
    latSets[sNew].append(latSets[s1]);
397 | } |
398 | return sNew; |
399 | } |
400 | |
401 | LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v, |
402 | Operation *op, Attribute a) { |
403 | assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) || |
404 | TensorExp::Kind::kDenseOp == kind); |
405 | const LatSetId sNew = addSet(); |
406 | auto &setNew = latSets[sNew]; |
407 | for (const LatPointId p : set(s0)) { |
408 | const auto &point = latPoints[p]; |
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
410 | } |
411 | return sNew; |
412 | } |
413 | |
414 | LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) { |
415 | TensorExp::Kind kind = exp(e).kind; |
416 | Attribute a = exp(e).attr; |
417 | assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI); |
418 | // Must be a binary operation. |
419 | const LatSetId sNew = addSet(); |
420 | auto &setNew = latSets[sNew]; |
421 | const ExprId zeroExp = addSynZeroExp(); |
422 | for (const LatPointId p : set(s0)) { |
423 | const auto &point = latPoints[p]; |
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
427 | } |
428 | return sNew; |
429 | } |
430 | |
431 | LatSetId Merger::optimizeSet(LatSetId s0) { |
432 | const LatSetId sNew = addSet(); |
433 | auto &setNew = latSets[sNew]; |
434 | const auto &set0 = set(s0); |
435 | assert(!set0.empty()); |
436 | const LatPointId p0 = set0[0]; |
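  // The first lattice point p0 is always kept; the filtering below only
  // applies to subsequent points.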
437 | for (const LatPointId p1 : set0) { |
438 | bool add = true; |
439 | if (p0 != p1) { |
440 | // Check whether this is a straightforward copy. |
      if (expIsTensor(latPoints[p1].exp, outTensor))
442 | continue; |
443 | // Check whether this conjunction is already covered. |
444 | for (const LatPointId p2 : setNew) { |
445 | assert(!latGT(p1, p2)); // Lj => Li would be bad |
        if (onlyDenseDiff(p2, p1)) {
447 | add = false; |
448 | break; |
449 | } |
450 | } |
451 | assert(!add || latGT(p0, p1)); |
452 | } |
453 | if (add) |
      setNew.push_back(p1);
455 | } |
456 | for (const LatPointId p : setNew) |
    latPoints[p].simple = simplifyCond(sNew, p);
458 | return sNew; |
459 | } |
460 | |
461 | BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { |
  // First determine if this lattice point is a *singleton*, i.e., the
  // last point in a lattice; no other point is less than this one.
464 | bool isSingleton = true; |
465 | for (const LatPointId p1 : set(s0)) { |
466 | if (p0 != p1 && latGT(p0, p1)) { |
467 | isSingleton = false; |
468 | break; |
469 | } |
470 | } |
471 | |
472 | BitVector simple(latPoints[p0].bits); |
473 | bool reset = isSingleton && hasAnySparse(bits: simple); |
474 | const TensorLoopId be = simple.size(); |
475 | TensorLoopId offset = 0; // relative to the end |
476 | if (!reset) |
    // Start resetting from a dense level, so that the first bit (if kept)
    // does not have an undefined level type.
479 | for (unsigned b = 0; b < be; b++) { |
480 | if (simple[b] && getLvlType(b: TensorLoopId{b}).hasDenseSemantic()) { |
481 | offset = be - b - 1; // relative to the end |
482 | break; |
483 | } |
484 | } |
485 | |
  // Now apply the two basic rules. Iterate the bits in reverse so that the
  // rightmost bit (which could possibly be a synthetic tensor) is always kept.
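  // Rule (1): multiple dense conditions are reduced to a single dense
  // condition. Rule (2): for a singleton with a sparse condition, the dense
  // conditions are dropped, since dense levels can use random access instead.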
488 | for (unsigned b = be - 1 - offset, i = 0; i < be; |
489 | b = b == 0 ? be - 1 : b - 1, i++) { |
490 | // Slice on dense level has `locate` property as well, and can be optimized. |
491 | if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) { |
492 | const auto lt = getLvlType(b); |
493 | if (!lt.hasSparseSemantic()) { |
494 | if (reset) |
          simple.reset(b);
496 | reset = true; |
497 | } |
498 | } |
499 | } |
500 | return simple; |
501 | } |
502 | |
503 | bool Merger::latGT(LatPointId i, LatPointId j) const { |
504 | const BitVector &bitsi = lat(p: i).bits; |
505 | const BitVector &bitsj = lat(p: j).bits; |
506 | assert(bitsi.size() == bitsj.size()); |
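  // Li > Lj holds when Li has strictly more bits set and covers every bit
  // of Lj.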
507 | if (bitsi.count() > bitsj.count()) { |
508 | for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) |
509 | if (bitsj[b] && !bitsi[b]) |
510 | return false; |
511 | return true; |
512 | } |
513 | return false; |
514 | } |
515 | |
516 | bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const { |
517 | BitVector tmp(latPoints[j].bits); |
518 | tmp ^= latPoints[i].bits; |
  return !hasAnySparse(tmp);
520 | } |
521 | |
522 | bool Merger::expContainsTensor(ExprId e, TensorId t) const { |
523 | const auto &expr = exp(e); |
524 | // First we check `expIsTensor`. |
525 | if (expr.kind == TensorExp::Kind::kTensor) |
526 | return expr.tensor == t; |
527 | |
  switch (getExpArity(expr.kind)) {
529 | case ExpArity::kNullary: |
530 | return false; |
531 | case ExpArity::kUnary: { |
532 | const ExprId e0 = expr.children.e0; |
    return expContainsTensor(e0, t);
534 | } |
535 | case ExpArity::kBinary: { |
536 | const ExprId e0 = expr.children.e0; |
537 | const ExprId e1 = expr.children.e1; |
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
539 | } |
540 | } |
541 | llvm_unreachable("unexpected arity" ); |
542 | } |
543 | |
544 | bool Merger::hasNegateOnOut(ExprId e) const { |
545 | const auto &expr = exp(e); |
546 | switch (expr.kind) { |
547 | case TensorExp::Kind::kNegF: |
548 | case TensorExp::Kind::kNegC: |
549 | case TensorExp::Kind::kNegI: |
    return expContainsTensor(expr.children.e0, outTensor);
551 | case TensorExp::Kind::kSubF: |
552 | case TensorExp::Kind::kSubC: |
553 | case TensorExp::Kind::kSubI: |
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
556 | case TensorExp::Kind::kDenseOp: { |
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
558 | if (!lhsNeg && expr.children.e1 != detail::kInvalidId) |
      return hasNegateOnOut(expr.children.e1);
560 | return lhsNeg; |
561 | } |
562 | default: { |
    switch (getExpArity(expr.kind)) {
564 | case ExpArity::kNullary: |
565 | return false; |
566 | case ExpArity::kUnary: |
      return hasNegateOnOut(expr.children.e0);
568 | case ExpArity::kBinary: |
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
571 | } |
572 | } |
573 | } |
574 | llvm_unreachable("unexpected kind" ); |
575 | } |
576 | |
577 | bool Merger::isSingleCondition(TensorId t, ExprId e) const { |
578 | assert(isValidTensorId(t)); |
579 | const auto &expr = exp(e); |
580 | switch (expr.kind) { |
581 | // Leaf. |
582 | case TensorExp::Kind::kTensor: |
583 | return expr.tensor == t; |
584 | case TensorExp::Kind::kInvariant: |
585 | case TensorExp::Kind::kLoopVar: |
586 | case TensorExp::Kind::kSynZero: |
587 | return false; |
588 | // Unary operations. |
589 | case TensorExp::Kind::kAbsF: |
590 | case TensorExp::Kind::kAbsC: |
591 | case TensorExp::Kind::kAbsI: |
592 | case TensorExp::Kind::kCeilF: |
593 | case TensorExp::Kind::kFloorF: |
594 | case TensorExp::Kind::kSqrtF: |
595 | case TensorExp::Kind::kSqrtC: |
596 | case TensorExp::Kind::kExpm1F: |
597 | case TensorExp::Kind::kExpm1C: |
598 | case TensorExp::Kind::kLog1pF: |
599 | case TensorExp::Kind::kLog1pC: |
600 | case TensorExp::Kind::kRelu: |
601 | case TensorExp::Kind::kSinF: |
602 | case TensorExp::Kind::kSinC: |
603 | case TensorExp::Kind::kTanhF: |
604 | case TensorExp::Kind::kTanhC: |
605 | case TensorExp::Kind::kNegF: |
606 | case TensorExp::Kind::kNegC: |
607 | case TensorExp::Kind::kNegI: |
608 | case TensorExp::Kind::kTruncF: |
609 | case TensorExp::Kind::kExtF: |
610 | case TensorExp::Kind::kCastFS: |
611 | case TensorExp::Kind::kCastFU: |
612 | case TensorExp::Kind::kCastSF: |
613 | case TensorExp::Kind::kCastUF: |
614 | case TensorExp::Kind::kCastS: |
615 | case TensorExp::Kind::kCastU: |
616 | case TensorExp::Kind::kCastIdx: |
617 | case TensorExp::Kind::kTruncI: |
618 | case TensorExp::Kind::kCIm: |
619 | case TensorExp::Kind::kCRe: |
620 | case TensorExp::Kind::kBitCast: |
621 | case TensorExp::Kind::kUnary: |
    return isSingleCondition(t, expr.children.e0);
623 | case TensorExp::Kind::kBinaryBranch: |
624 | case TensorExp::Kind::kSelect: |
625 | return false; |
626 | // Binary operations. |
627 | case TensorExp::Kind::kDivF: // note: x / c only |
628 | case TensorExp::Kind::kDivC: |
629 | case TensorExp::Kind::kDivS: |
630 | case TensorExp::Kind::kDivU: |
631 | assert(!maybeZero(expr.children.e1)); |
    return isSingleCondition(t, expr.children.e0);
633 | case TensorExp::Kind::kShrS: // note: x >> inv only |
634 | case TensorExp::Kind::kShrU: |
635 | case TensorExp::Kind::kShlI: |
636 | assert(isInvariant(expr.children.e1)); |
    return isSingleCondition(t, expr.children.e0);
638 | case TensorExp::Kind::kMulF: |
639 | case TensorExp::Kind::kMulC: |
640 | case TensorExp::Kind::kMulI: |
641 | case TensorExp::Kind::kAndI: |
642 | case TensorExp::Kind::kReduce: |
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
648 | return false; |
649 | case TensorExp::Kind::kAddF: |
650 | case TensorExp::Kind::kAddC: |
651 | case TensorExp::Kind::kAddI: |
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
654 | case TensorExp::Kind::kSubF: |
655 | case TensorExp::Kind::kSubC: |
656 | case TensorExp::Kind::kSubI: |
657 | case TensorExp::Kind::kOrI: |
658 | case TensorExp::Kind::kXorI: |
659 | case TensorExp::Kind::kCmpF: |
660 | case TensorExp::Kind::kCmpI: |
661 | case TensorExp::Kind::kBinary: |
662 | return false; |
663 | case TensorExp::Kind::kDenseOp: |
664 | // Since Merger guarantees all the operands of the kDenseOp to be dense, the |
665 | // operation must be single-condition. |
666 | return true; |
667 | } |
668 | llvm_unreachable("unexpected kind" ); |
669 | } |
670 | |
671 | bool Merger::hasAnySparse(const BitVector &bits) const { |
672 | for (TensorLoopId b : bits.set_bits()) { |
673 | const auto lt = getLvlType(b); |
674 | if (lt.hasSparseSemantic()) |
675 | return true; |
676 | } |
677 | return hasSparseIdxReduction(bits); |
678 | } |
679 | |
680 | bool Merger::hasSparseIdxReduction(const BitVector &bits) const { |
681 | for (TensorLoopId b : bits.set_bits()) |
682 | if (isSparseLvlWithNonTrivialIdxExp(b)) |
683 | return true; |
684 | return false; |
685 | } |
686 | |
687 | #ifndef NDEBUG |
688 | |
689 | //===----------------------------------------------------------------------===// |
690 | // Print methods (for debugging). |
691 | //===----------------------------------------------------------------------===// |
692 | |
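/// Returns a printable symbol for the given tensor expression kind, used by
/// the debug dump methods below.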
693 | static const char *kindToOpSymbol(TensorExp::Kind kind) { |
694 | switch (kind) { |
695 | // Leaf. |
696 | case TensorExp::Kind::kTensor: |
697 | return "tensor" ; |
698 | case TensorExp::Kind::kInvariant: |
699 | return "invariant" ; |
700 | case TensorExp::Kind::kLoopVar: |
701 | return "index" ; |
702 | case TensorExp::Kind::kSynZero: |
703 | return "0" ; |
704 | // Unary operations. |
705 | case TensorExp::Kind::kAbsF: |
706 | case TensorExp::Kind::kAbsC: |
707 | case TensorExp::Kind::kAbsI: |
708 | return "abs" ; |
709 | case TensorExp::Kind::kCeilF: |
710 | return "ceil" ; |
711 | case TensorExp::Kind::kFloorF: |
712 | return "floor" ; |
713 | case TensorExp::Kind::kSqrtF: |
714 | case TensorExp::Kind::kSqrtC: |
715 | return "sqrt" ; |
716 | case TensorExp::Kind::kExpm1F: |
717 | case TensorExp::Kind::kExpm1C: |
718 | return "expm1" ; |
719 | case TensorExp::Kind::kLog1pF: |
720 | case TensorExp::Kind::kLog1pC: |
721 | return "log1p" ; |
722 | case TensorExp::Kind::kRelu: |
723 | return "relu" ; |
724 | case TensorExp::Kind::kSinF: |
725 | case TensorExp::Kind::kSinC: |
726 | return "sin" ; |
727 | case TensorExp::Kind::kTanhF: |
728 | case TensorExp::Kind::kTanhC: |
729 | return "tanh" ; |
730 | case TensorExp::Kind::kNegF: |
731 | case TensorExp::Kind::kNegC: |
732 | case TensorExp::Kind::kNegI: |
733 | return "-" ; |
734 | case TensorExp::Kind::kTruncF: |
735 | case TensorExp::Kind::kExtF: |
736 | case TensorExp::Kind::kCastFS: |
737 | case TensorExp::Kind::kCastFU: |
738 | case TensorExp::Kind::kCastSF: |
739 | case TensorExp::Kind::kCastUF: |
740 | case TensorExp::Kind::kCastS: |
741 | case TensorExp::Kind::kCastU: |
742 | case TensorExp::Kind::kCastIdx: |
743 | case TensorExp::Kind::kTruncI: |
744 | case TensorExp::Kind::kCIm: |
745 | return "complex.im" ; |
746 | case TensorExp::Kind::kCRe: |
747 | return "complex.re" ; |
748 | case TensorExp::Kind::kBitCast: |
749 | return "cast" ; |
750 | case TensorExp::Kind::kBinaryBranch: |
751 | return "binary_branch" ; |
752 | case TensorExp::Kind::kUnary: |
753 | return "unary" ; |
754 | case TensorExp::Kind::kSelect: |
755 | return "select" ; |
756 | // Binary operations. |
757 | case TensorExp::Kind::kMulF: |
758 | case TensorExp::Kind::kMulC: |
759 | case TensorExp::Kind::kMulI: |
760 | return "*" ; |
761 | case TensorExp::Kind::kDivF: |
762 | case TensorExp::Kind::kDivC: |
763 | case TensorExp::Kind::kDivS: |
764 | case TensorExp::Kind::kDivU: |
765 | return "/" ; |
766 | case TensorExp::Kind::kAddF: |
767 | case TensorExp::Kind::kAddC: |
768 | case TensorExp::Kind::kAddI: |
769 | return "+" ; |
770 | case TensorExp::Kind::kSubF: |
771 | case TensorExp::Kind::kSubC: |
772 | case TensorExp::Kind::kSubI: |
773 | return "-" ; |
774 | case TensorExp::Kind::kAndI: |
775 | return "&" ; |
776 | case TensorExp::Kind::kOrI: |
777 | return "|" ; |
778 | case TensorExp::Kind::kXorI: |
779 | return "^" ; |
780 | case TensorExp::Kind::kShrS: |
781 | return "a>>" ; |
782 | case TensorExp::Kind::kShrU: |
783 | return ">>" ; |
784 | case TensorExp::Kind::kShlI: |
785 | return "<<" ; |
786 | case TensorExp::Kind::kCmpF: |
787 | case TensorExp::Kind::kCmpI: |
788 | return "cmp" ; |
789 | case TensorExp::Kind::kBinary: |
790 | return "binary" ; |
791 | case TensorExp::Kind::kReduce: |
792 | return "reduce" ; |
793 | case TensorExp::Kind::kDenseOp: |
794 | return "dense" ; |
795 | } |
796 | llvm_unreachable("unexpected kind for symbol" ); |
797 | } |
798 | |
799 | void Merger::dumpExp(ExprId e) const { |
800 | const auto &expr = exp(e); |
801 | switch (expr.kind) { |
802 | // Leaf. |
803 | case TensorExp::Kind::kTensor: |
804 | if (expr.tensor == syntheticTensor) |
805 | llvm::dbgs() << "synthetic_" ; |
806 | else if (expr.tensor == outTensor) |
807 | llvm::dbgs() << "output_" ; |
808 | llvm::dbgs() << "tensor_" << expr.tensor; |
809 | break; |
810 | case TensorExp::Kind::kInvariant: |
811 | llvm::dbgs() << "invariant" ; |
812 | break; |
813 | case TensorExp::Kind::kSynZero: |
814 | llvm::dbgs() << "0" ; |
815 | break; |
816 | case TensorExp::Kind::kLoopVar: |
817 | llvm::dbgs() << "loopvar_" << expr.loop; |
818 | break; |
819 | // Unary operations. |
820 | case TensorExp::Kind::kAbsF: |
821 | case TensorExp::Kind::kAbsC: |
822 | case TensorExp::Kind::kAbsI: |
823 | case TensorExp::Kind::kCeilF: |
824 | case TensorExp::Kind::kFloorF: |
825 | case TensorExp::Kind::kSqrtF: |
826 | case TensorExp::Kind::kSqrtC: |
827 | case TensorExp::Kind::kExpm1F: |
828 | case TensorExp::Kind::kExpm1C: |
829 | case TensorExp::Kind::kLog1pF: |
830 | case TensorExp::Kind::kLog1pC: |
831 | case TensorExp::Kind::kRelu: |
832 | case TensorExp::Kind::kSinF: |
833 | case TensorExp::Kind::kSinC: |
834 | case TensorExp::Kind::kTanhF: |
835 | case TensorExp::Kind::kTanhC: |
836 | case TensorExp::Kind::kNegF: |
837 | case TensorExp::Kind::kNegC: |
838 | case TensorExp::Kind::kNegI: |
839 | case TensorExp::Kind::kTruncF: |
840 | case TensorExp::Kind::kExtF: |
841 | case TensorExp::Kind::kCastFS: |
842 | case TensorExp::Kind::kCastFU: |
843 | case TensorExp::Kind::kCastSF: |
844 | case TensorExp::Kind::kCastUF: |
845 | case TensorExp::Kind::kCastS: |
846 | case TensorExp::Kind::kCastU: |
847 | case TensorExp::Kind::kCastIdx: |
848 | case TensorExp::Kind::kTruncI: |
849 | case TensorExp::Kind::kCIm: |
850 | case TensorExp::Kind::kCRe: |
851 | case TensorExp::Kind::kBitCast: |
852 | case TensorExp::Kind::kBinaryBranch: |
853 | case TensorExp::Kind::kUnary: |
854 | case TensorExp::Kind::kSelect: |
855 | llvm::dbgs() << kindToOpSymbol(kind: expr.kind) << " " ; |
856 | dumpExp(e: expr.children.e0); |
857 | break; |
858 | // Binary operations. |
859 | case TensorExp::Kind::kMulF: |
860 | case TensorExp::Kind::kMulC: |
861 | case TensorExp::Kind::kMulI: |
862 | case TensorExp::Kind::kDivF: |
863 | case TensorExp::Kind::kDivC: |
864 | case TensorExp::Kind::kDivS: |
865 | case TensorExp::Kind::kDivU: |
866 | case TensorExp::Kind::kAddF: |
867 | case TensorExp::Kind::kAddC: |
868 | case TensorExp::Kind::kAddI: |
869 | case TensorExp::Kind::kSubF: |
870 | case TensorExp::Kind::kSubC: |
871 | case TensorExp::Kind::kSubI: |
872 | case TensorExp::Kind::kAndI: |
873 | case TensorExp::Kind::kOrI: |
874 | case TensorExp::Kind::kXorI: |
875 | case TensorExp::Kind::kShrS: |
876 | case TensorExp::Kind::kShrU: |
877 | case TensorExp::Kind::kShlI: |
878 | case TensorExp::Kind::kCmpF: |
879 | case TensorExp::Kind::kCmpI: |
880 | case TensorExp::Kind::kBinary: |
881 | case TensorExp::Kind::kReduce: |
882 | case TensorExp::Kind::kDenseOp: |
883 | llvm::dbgs() << "(" ; |
884 | dumpExp(e: expr.children.e0); |
885 | llvm::dbgs() << " " << kindToOpSymbol(kind: expr.kind); |
886 | if (expr.attr) |
887 | llvm::dbgs() << "{" << expr.attr << "}" ; |
888 | if (expr.children.e1 != detail::kInvalidId) { |
889 | llvm::dbgs() << " " ; |
890 | dumpExp(e: expr.children.e1); |
891 | llvm::dbgs() << ")" ; |
892 | } else { |
893 | assert(expr.kind == TensorExp::Kind::kDenseOp); |
894 | } |
895 | break; |
896 | } |
897 | } |
898 | |
899 | void Merger::dumpLat(LatPointId p) const { |
900 | const auto &point = lat(p); |
901 | llvm::dbgs() << "lat(" ; |
902 | dumpBits(bits: point.bits); |
903 | llvm::dbgs() << " :" ; |
904 | dumpBits(bits: point.simple); |
905 | llvm::dbgs() << " : " ; |
906 | dumpExp(e: point.exp); |
907 | llvm::dbgs() << " )\n" ; |
908 | } |
909 | |
910 | void Merger::dumpSet(LatSetId s) const { |
911 | const auto &ss = set(s); |
912 | llvm::dbgs() << "{ #" << ss.size() << "\n" ; |
913 | for (const LatPointId p : ss) { |
914 | llvm::dbgs() << " " ; |
915 | dumpLat(p); |
916 | } |
917 | llvm::dbgs() << "}\n" ; |
918 | } |
919 | |
920 | void Merger::dumpBits(const BitVector &bits) const { |
921 | for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { |
922 | if (bits[b]) { |
923 | const TensorId t = tensor(b); |
924 | const LoopId i = loop(b); |
925 | const auto lt = lvlTypes[t][i]; |
926 | if (isLvlWithNonTrivialIdxExp(b)) |
927 | llvm::dbgs() << " DEP_" << t << "_" << i; |
928 | else |
929 | llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt); |
930 | } |
931 | } |
932 | } |
933 | |
934 | #endif // NDEBUG |
935 | |
936 | //===----------------------------------------------------------------------===// |
937 | // Builder methods. |
938 | //===----------------------------------------------------------------------===// |
939 | |
940 | LatSetId Merger::buildLattices(ExprId e, LoopId i) { |
941 | // NOTE: The `expr` reference will be invalidated by recursive calls |
942 | // (and any other method that may add new expressions); therefore, the |
943 | // code below must make sure to copy fields of `expr` into local variables |
944 | // before making any recursive calls. |
945 | const auto &expr = exp(e); |
946 | const TensorExp::Kind kind = expr.kind; |
947 | switch (kind) { |
948 | // Leaf. |
949 | case TensorExp::Kind::kTensor: |
950 | case TensorExp::Kind::kInvariant: |
951 | case TensorExp::Kind::kSynZero: |
952 | case TensorExp::Kind::kLoopVar: { |
953 | // Either the loop-var is really used in the tensor expression, or it is |
954 | // set to the undefined loop-var in that level. An invariant expression, |
955 | // a proper index value, and a truly dynamic sparse output tensor are set |
956 | // to a synthetic tensor with undefined indices only to ensure the |
957 | // iteration space is not skipped as a result of their contents. |
958 | const LatSetId s = addSet(); |
959 | TensorId t = syntheticTensor; |
960 | if (kind == TensorExp::Kind::kTensor) { |
961 | t = expr.tensor; |
962 | if (hasSparseOut && t == outTensor) |
963 | t = syntheticTensor; |
964 | } |
    latSets[s].push_back(addLat(t, i, e));
966 | return s; |
967 | } |
968 | // Unary operations. |
969 | case TensorExp::Kind::kAbsF: |
970 | case TensorExp::Kind::kAbsC: |
971 | case TensorExp::Kind::kAbsI: |
972 | case TensorExp::Kind::kCeilF: |
973 | case TensorExp::Kind::kFloorF: |
974 | case TensorExp::Kind::kSqrtF: |
975 | case TensorExp::Kind::kSqrtC: |
976 | case TensorExp::Kind::kExpm1F: |
977 | case TensorExp::Kind::kExpm1C: |
978 | case TensorExp::Kind::kLog1pF: |
979 | case TensorExp::Kind::kLog1pC: |
980 | case TensorExp::Kind::kRelu: |
981 | case TensorExp::Kind::kSinF: |
982 | case TensorExp::Kind::kSinC: |
983 | case TensorExp::Kind::kTanhF: |
984 | case TensorExp::Kind::kTanhC: |
985 | case TensorExp::Kind::kNegF: |
986 | case TensorExp::Kind::kNegC: |
987 | case TensorExp::Kind::kNegI: |
988 | case TensorExp::Kind::kTruncF: |
989 | case TensorExp::Kind::kExtF: |
990 | case TensorExp::Kind::kCastFS: |
991 | case TensorExp::Kind::kCastFU: |
992 | case TensorExp::Kind::kCastSF: |
993 | case TensorExp::Kind::kCastUF: |
994 | case TensorExp::Kind::kCastS: |
995 | case TensorExp::Kind::kCastU: |
996 | case TensorExp::Kind::kCastIdx: |
997 | case TensorExp::Kind::kTruncI: |
998 | case TensorExp::Kind::kCIm: |
999 | case TensorExp::Kind::kCRe: |
1000 | case TensorExp::Kind::kBitCast: |
1001 | // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the |
1002 | // lattice set of the operand through the operator into a new set. |
1003 | // |
1004 | // -y|!y | y | |
1005 | // --+---+---+ |
1006 | // | 0 |-y | |
1007 | { |
1008 | const ExprId e0 = expr.children.e0; |
1009 | const Value v = expr.val; |
1010 | Attribute a = expr.attr; |
      return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
1012 | } |
1013 | case TensorExp::Kind::kBinaryBranch: |
1014 | case TensorExp::Kind::kSelect: |
1015 | // The left or right half of a binary operation which has already |
1016 | // been split into separate operations for each region. |
1017 | { |
1018 | const ExprId e0 = expr.children.e0; |
1019 | Operation *const op = expr.op; |
      return mapSet(kind, buildLattices(e0, i), Value(), op);
1021 | } |
1022 | case TensorExp::Kind::kUnary: |
1023 | // A custom unary operation. |
1024 | // |
1025 | // op y| !y | y | |
1026 | // ----+----------+------------+ |
1027 | // | absent() | present(y) | |
1028 | { |
1029 | const ExprId e0 = expr.children.e0; |
1030 | UnaryOp unop = cast<UnaryOp>(expr.op); |
      const LatSetId child0 = buildLattices(e0, i);
1032 | Region &absentRegion = unop.getAbsentRegion(); |
1033 | if (absentRegion.empty()) { |
1034 | // Simple mapping over existing values. |
        return mapSet(kind, child0, Value(), unop);
1036 | } |
1037 | // Use a disjunction with `unop` on the left and the absent value as an |
1038 | // invariant on the right. |
1039 | Block &absentBlock = absentRegion.front(); |
1040 | YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator()); |
1041 | const Value absentVal = absentYield.getSingleResult(); |
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
1044 | } |
1045 | // Binary operations. |
1046 | case TensorExp::Kind::kMulF: |
1047 | case TensorExp::Kind::kMulC: |
1048 | case TensorExp::Kind::kMulI: |
1049 | case TensorExp::Kind::kAndI: |
1050 | // A multiplicative operation only needs to be performed |
1051 | // for the conjunction of sparse iteration spaces. |
1052 | // |
1053 | // x*y|!y | y | |
1054 | // ---+---+---+ |
1055 | // !x | 0 | 0 | |
1056 | // x | 0 |x*y| |
1057 | // |
1058 | // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored. |
1059 | { |
1060 | const ExprId e0 = expr.children.e0; |
1061 | const ExprId e1 = expr.children.e1; |
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1063 | } |
1064 | case TensorExp::Kind::kDivF: |
1065 | case TensorExp::Kind::kDivC: |
1066 | case TensorExp::Kind::kDivS: |
1067 | case TensorExp::Kind::kDivU: |
1068 | // A division is tricky, since 0/0, 0/c, c/0 all have |
1069 | // specific outcomes for floating-point and integers. |
1070 | // Thus, we need to traverse the full iteration space. |
1071 | // |
1072 | // x/y|!y | y | |
1073 | // ---+---+---+ |
1074 | // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero |
1075 | // x |x/0|x/y| INT: x/0=exception for any x |
1076 | // |
1077 | // TODO: for now we "fixed" this by only accepting x/c cases |
1078 | // during expression building, so that the conjunction |
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
1080 | // construction is concerned). |
1081 | { |
1082 | const ExprId e0 = expr.children.e0; |
1083 | const ExprId e1 = expr.children.e1; |
1084 | assert(!maybeZero(e1)); |
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1086 | } |
1087 | case TensorExp::Kind::kAddF: |
1088 | case TensorExp::Kind::kAddC: |
1089 | case TensorExp::Kind::kAddI: |
1090 | case TensorExp::Kind::kSubF: |
1091 | case TensorExp::Kind::kSubC: |
1092 | case TensorExp::Kind::kSubI: |
1093 | case TensorExp::Kind::kOrI: |
1094 | case TensorExp::Kind::kXorI: |
1095 | // An additive operation needs to be performed |
1096 | // for the disjunction of sparse iteration spaces. |
1097 | // |
1098 | // x+y|!y | y | x-y|!y | y | |
1099 | // ---+---+---+ ---+---+---+ |
1100 | // !x | 0 | y | !x | 0 |-y | |
1101 | // x | x |x+y| x | x |x-y| |
1102 | { |
1103 | const ExprId e0 = expr.children.e0; |
1104 | const ExprId e1 = expr.children.e1; |
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1106 | } |
1107 | case TensorExp::Kind::kCmpF: |
1108 | case TensorExp::Kind::kCmpI: |
1109 | // A comparison operation needs to be performed |
1110 | // for the disjunction of sparse iteration spaces. |
1111 | // |
1112 | // x < y | !y | y | |
1113 | // -------+-------+-------+ |
1114 | // !x | 0 | 0 < y | |
1115 | // x | x < 0 | x < y | |
1116 | { |
1117 | const ExprId e0 = expr.children.e0; |
1118 | const ExprId e1 = expr.children.e1; |
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1120 | } |
1121 | case TensorExp::Kind::kShrS: |
1122 | case TensorExp::Kind::kShrU: |
1123 | case TensorExp::Kind::kShlI: |
1124 | // A shift operation by an invariant amount (viz. tensor expressions |
1125 | // can only occur at the left-hand-side of the operator) can be handled |
1126 | // with the conjunction rule. |
1127 | { |
1128 | const ExprId e0 = expr.children.e0; |
1129 | const ExprId e1 = expr.children.e1; |
1130 | assert(isInvariant(e1)); |
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1132 | } |
1133 | case TensorExp::Kind::kBinary: |
1134 | // A custom binary operation. |
1135 | // |
1136 | // x op y| !y | y | |
1137 | // ------+---------+--------------+ |
1138 | // !x | empty | right(y) | |
1139 | // x | left(x) | overlap(x,y) | |
1140 | { |
1141 | const ExprId e0 = expr.children.e0; |
1142 | const ExprId e1 = expr.children.e1; |
1143 | BinaryOp binop = cast<BinaryOp>(expr.op); |
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
1146 | Region &leftRegion = binop.getLeftRegion(); |
1147 | Region &rightRegion = binop.getRightRegion(); |
1148 | // Left Region. |
1149 | Operation *leftYield = nullptr; |
1150 | if (!leftRegion.empty()) { |
1151 | Block &leftBlock = leftRegion.front(); |
1152 | leftYield = leftBlock.getTerminator(); |
1153 | } |
1154 | // Right Region. |
1155 | Operation *rightYield = nullptr; |
1156 | if (!rightRegion.empty()) { |
1157 | Block &rightBlock = rightRegion.front(); |
1158 | rightYield = rightBlock.getTerminator(); |
1159 | } |
1160 | bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty(); |
1161 | bool includeRight = binop.getRightIdentity() || !rightRegion.empty(); |
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
1165 | } |
1166 | case TensorExp::Kind::kReduce: |
1167 | // A custom reduce operation. |
1168 | { |
1169 | const ExprId e0 = expr.children.e0; |
1170 | const ExprId e1 = expr.children.e1; |
1171 | Operation *const op = expr.op; |
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1173 | } |
1174 | case TensorExp::Kind::kDenseOp: { |
    // It does not really matter whether we use a conjunctive or disjunctive
    // set here: since all operands of kDenseOp must be dense, the disjunctive
    // set will be optimized into a conjunctive set eventually.
1178 | if (expr.children.e1 == detail::kInvalidId) { |
1179 | const ExprId e0 = expr.children.e0; |
1180 | Operation *const op = expr.op; |
      return mapSet(kind, buildLattices(e0, i), Value(), op);
1182 | } |
1183 | |
1184 | const ExprId e0 = expr.children.e0; |
1185 | const ExprId e1 = expr.children.e1; |
1186 | Operation *const op = expr.op; |
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1188 | } |
1189 | } |
1190 | llvm_unreachable("unexpected expression kind" ); |
1191 | } |
1192 | |
1193 | std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { |
1194 | // Build the linalg semantics backward from yield. |
1195 | Operation *yield = op.getRegion().front().getTerminator(); |
1196 | assert(isa<linalg::YieldOp>(yield)); |
  return buildTensorExp(op, yield->getOperand(0)).first;
1198 | } |
1199 | |
1200 | /// Only returns true if we are certain this is a zero. |
1201 | static bool isCertainZero(Value val) { |
1202 | if (auto c = val.getDefiningOp<complex::ConstantOp>()) { |
1203 | ArrayAttr arrayAttr = c.getValue(); |
1204 | return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
1205 | cast<FloatAttr>(arrayAttr[1]).getValue().isZero(); |
1206 | } |
1207 | if (auto c = val.getDefiningOp<arith::ConstantIntOp>()) |
1208 | return c.value() == 0; |
1209 | if (auto c = val.getDefiningOp<arith::ConstantFloatOp>()) |
1210 | return c.value().isZero(); |
1211 | return false; |
1212 | } |
1213 | |
1214 | /// Only returns false if we are certain this is a nonzero. |
1215 | bool Merger::maybeZero(ExprId e) const { |
1216 | const auto &expr = exp(e); |
1217 | if (expr.kind == TensorExp::Kind::kInvariant) { |
1218 | // Note that this is different from isCertainZero() in a subtle |
1219 | // way by always returning true for non-constants. |
1220 | if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) { |
1221 | ArrayAttr arrayAttr = c.getValue(); |
1222 | return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
1223 | cast<FloatAttr>(arrayAttr[1]).getValue().isZero(); |
1224 | } |
1225 | if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>()) |
1226 | return c.value() == 0; |
1227 | if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>()) |
1228 | return c.value().isZero(); |
1229 | } |
1230 | return true; |
1231 | } |
1232 | |
1233 | Type Merger::inferType(ExprId e, Value src) const { |
1234 | // Obtain the destination type from the cast node. |
1235 | Type dtp = exp(e).val.getType(); |
1236 | // Inspect source type. For vector types, apply the same |
1237 | // vectorization to the destination type. |
1238 | if (auto vtp = dyn_cast<VectorType>(src.getType())) |
1239 | return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims()); |
1240 | return dtp; |
1241 | } |
1242 | |
1243 | /// Ensures that the sparsifier can generate code for expression. |
1244 | static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { |
1245 | // Arguments are always admissible. |
1246 | if (isa<BlockArgument>(Val: v)) |
1247 | return true; |
1248 | // Accept index anywhere. |
1249 | Operation *def = v.getDefiningOp(); |
1250 | if (isa<linalg::IndexOp>(def)) |
1251 | return true; |
1252 | // Operation defined outside branch. |
1253 | if (def->getBlock() != block) |
1254 | return def->getBlock() != op->getBlock(); // invariant? |
1255 | // Operation defined within branch. Anything is accepted, |
1256 | // as long as all subexpressions are admissible. |
1257 | for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) |
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1259 | return false; |
1260 | return true; |
1261 | } |
1262 | |
1263 | /// Ensures that the sparsifier can generate code for branch. |
1264 | static bool isAdmissibleBranch(Operation *op, Region ®ion) { |
1265 | if (region.empty()) |
1266 | return true; |
1267 | // Build the semi-ring branch semantics backward from yield. |
1268 | Operation *yield = region.front().getTerminator(); |
1269 | assert(isa<YieldOp>(yield)); |
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1271 | } |
1272 | |
1273 | // Recognizes a direct GT comparison. |
1274 | static bool isGreater(TensorExp::Kind kind, Attribute attr) { |
1275 | if (kind == TensorExp::Kind::kCmpI) { |
1276 | auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue(); |
1277 | return pred == arith::CmpIPredicate::ugt || |
1278 | pred == arith::CmpIPredicate::sgt; |
1279 | } |
1280 | if (kind == TensorExp::Kind::kCmpF) { |
1281 | auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue(); |
1282 | return pred == arith::CmpFPredicate::UGT || |
1283 | pred == arith::CmpFPredicate::OGT; |
1284 | } |
1285 | return false; |
1286 | } |
1287 | |
1288 | std::pair<std::optional<ExprId>, bool> |
1289 | Merger::buildTensorExp(linalg::GenericOp op, Value v) { |
1290 | // Recursion leaves. |
1291 | if (auto arg = dyn_cast<BlockArgument>(Val&: v)) { |
1292 | const TensorId tid = makeTensorId(t: arg.getArgNumber()); |
1293 | // Any argument of the generic op that is not marked as a scalar |
1294 | // argument is considered a tensor, indexed by the implicit loop |
1295 | // bounds. This includes rank-0 tensor arguments. |
1296 | if (arg.getOwner()->getParentOp() == op) { |
1297 | OpOperand &t = op->getOpOperand(tid); |
1298 | bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr; |
1299 | if (!op.isScalar(&t)) |
        return {addTensorExp(tid), hasSpDep};
1301 | v = t.get(); // get scalar value |
1302 | } |
1303 | // Any other argument (marked as scalar argument for the generic op |
1304 | // or belonging to an enveloping op) is considered invariant. |
1305 | return {addInvariantExp(v), /*hasSpDep=*/false}; |
1306 | } |
1307 | |
1308 | // Something defined outside is invariant. |
1309 | Operation *def = v.getDefiningOp(); |
1310 | if (def->getBlock() != &op.getRegion().front()) |
1311 | return {addInvariantExp(v), /*hasSpDep=*/false}; |
1312 | // Construct index operations. |
1313 | if (def->getNumOperands() == 0) { |
1314 | if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) |
      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1316 | } |
1317 | |
1318 | // Construct unary operations if subexpression can be built. |
1319 | if (def->getNumOperands() == 1) { |
    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1321 | if (x.has_value()) { |
1322 | const ExprId e = *x; |
1323 | if (isa<math::AbsFOp>(def)) |
1324 | return {addExp(k: TensorExp::Kind::kAbsF, e0: e), hasSpDep}; |
1325 | if (isa<complex::AbsOp>(def)) |
1326 | return {addExp(k: TensorExp::Kind::kAbsC, e0: e), hasSpDep}; |
1327 | if (isa<math::AbsIOp>(def)) |
1328 | return {addExp(k: TensorExp::Kind::kAbsI, e0: e), hasSpDep}; |
1329 | if (isa<math::CeilOp>(def)) |
1330 | return {addExp(k: TensorExp::Kind::kCeilF, e0: e), hasSpDep}; |
1331 | if (isa<math::FloorOp>(def)) |
1332 | return {addExp(k: TensorExp::Kind::kFloorF, e0: e), hasSpDep}; |
1333 | if (isa<math::SqrtOp>(def)) |
1334 | return {addExp(k: TensorExp::Kind::kSqrtF, e0: e), hasSpDep}; |
1335 | if (isa<complex::SqrtOp>(def)) |
1336 | return {addExp(k: TensorExp::Kind::kSqrtC, e0: e), hasSpDep}; |
1337 | if (isa<math::ExpM1Op>(def)) |
1338 | return {addExp(k: TensorExp::Kind::kExpm1F, e0: e), hasSpDep}; |
1339 | if (isa<complex::Expm1Op>(def)) |
1340 | return {addExp(k: TensorExp::Kind::kExpm1C, e0: e), hasSpDep}; |
1341 | if (isa<math::Log1pOp>(def)) |
1342 | return {addExp(k: TensorExp::Kind::kLog1pF, e0: e), hasSpDep}; |
1343 | if (isa<complex::Log1pOp>(def)) |
1344 | return {addExp(k: TensorExp::Kind::kLog1pC, e0: e), hasSpDep}; |
1345 | if (isa<math::SinOp>(def)) |
1346 | return {addExp(k: TensorExp::Kind::kSinF, e0: e), hasSpDep}; |
1347 | if (isa<complex::SinOp>(def)) |
1348 | return {addExp(k: TensorExp::Kind::kSinC, e0: e), hasSpDep}; |
1349 | if (isa<math::TanhOp>(def)) |
1350 | return {addExp(k: TensorExp::Kind::kTanhF, e0: e), hasSpDep}; |
1351 | if (isa<complex::TanhOp>(def)) |
1352 | return {addExp(k: TensorExp::Kind::kTanhC, e0: e), hasSpDep}; |
1353 | if (isa<arith::NegFOp>(def)) |
1354 | return {addExp(k: TensorExp::Kind::kNegF, e0: e), hasSpDep}; // no negi in std |
1355 | if (isa<complex::NegOp>(def)) |
1356 | return {addExp(k: TensorExp::Kind::kNegC, e0: e), hasSpDep}; |
1357 | if (isa<arith::TruncFOp>(def)) |
1358 | return {addExp(k: TensorExp::Kind::kTruncF, e, v), hasSpDep}; |
1359 | if (isa<arith::ExtFOp>(def)) |
1360 | return {addExp(k: TensorExp::Kind::kExtF, e, v), hasSpDep}; |
1361 | if (isa<arith::FPToSIOp>(def)) |
1362 | return {addExp(k: TensorExp::Kind::kCastFS, e, v), hasSpDep}; |
1363 | if (isa<arith::FPToUIOp>(def)) |
1364 | return {addExp(k: TensorExp::Kind::kCastFU, e, v), hasSpDep}; |
1365 | if (isa<arith::SIToFPOp>(def)) |
1366 | return {addExp(k: TensorExp::Kind::kCastSF, e, v), hasSpDep}; |
1367 | if (isa<arith::UIToFPOp>(def)) |
1368 | return {addExp(k: TensorExp::Kind::kCastUF, e, v), hasSpDep}; |
1369 | if (isa<arith::ExtSIOp>(def)) |
1370 | return {addExp(k: TensorExp::Kind::kCastS, e, v), hasSpDep}; |
1371 | if (isa<arith::ExtUIOp>(def)) |
1372 | return {addExp(k: TensorExp::Kind::kCastU, e, v), hasSpDep}; |
1373 | if (isa<arith::IndexCastOp>(def)) |
1374 | return {addExp(k: TensorExp::Kind::kCastIdx, e, v), hasSpDep}; |
1375 | if (isa<arith::TruncIOp>(def)) |
1376 | return {addExp(k: TensorExp::Kind::kTruncI, e, v), hasSpDep}; |
1377 | if (isa<complex::ImOp>(def)) |
1378 | return {addExp(k: TensorExp::Kind::kCIm, e0: e), hasSpDep}; |
1379 | if (isa<complex::ReOp>(def)) |
1380 | return {addExp(k: TensorExp::Kind::kCRe, e0: e), hasSpDep}; |
1381 | if (isa<arith::BitcastOp>(def)) |
1382 | return {addExp(k: TensorExp::Kind::kBitCast, e, v), hasSpDep}; |
1383 | if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) { |
1384 | if (isAdmissibleBranch(unop, unop.getPresentRegion()) && |
1385 | isAdmissibleBranch(unop, unop.getAbsentRegion())) |
1386 | return {addExp(k: TensorExp::Kind::kUnary, e, v: Value(), op: def), hasSpDep}; |
1387 | } |
1388 | if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) { |
1389 | if (isAdmissibleBranch(selop, selop.getRegion())) |
1390 | return {addExp(k: TensorExp::Kind::kSelect, e, v: Value(), op: def), hasSpDep}; |
1391 | } |
1392 | } |
1393 | } |
1394 | |
1395 | // Construct binary operations if subexpressions can be built. |
1396 | // See buildLattices() for an explanation of rejecting certain |
1397 | // division and shift operations. |
1398 | if (def->getNumOperands() == 2) { |
    const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
    const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
1401 | // For a conjunctive operation, it yields a "sparse" result if any operand |
1402 | // is sparse. For a disjunctive operation, it yields a "sparse" result if |
1403 | // all operands are sparse. |
1404 | bool conjSpVals = xSpVals || ySpVals; |
1405 | bool disjSpVals = xSpVals && ySpVals; |
1406 | if (x.has_value() && y.has_value()) { |
1407 | const ExprId e0 = *x; |
1408 | const ExprId e1 = *y; |
1409 | if (isa<arith::MulFOp>(def)) |
1410 | return {addExp(k: TensorExp::Kind::kMulF, e0, e1), conjSpVals}; |
1411 | if (isa<complex::MulOp>(def)) |
1412 | return {addExp(k: TensorExp::Kind::kMulC, e0, e1), conjSpVals}; |
1413 | if (isa<arith::MulIOp>(def)) |
1414 | return {addExp(k: TensorExp::Kind::kMulI, e0, e1), conjSpVals}; |
1415 | if (isa<arith::DivFOp>(def) && !maybeZero(e1)) |
1416 | return {addExp(k: TensorExp::Kind::kDivF, e0, e1), conjSpVals}; |
1417 | if (isa<complex::DivOp>(def) && !maybeZero(e1)) |
1418 | return {addExp(k: TensorExp::Kind::kDivC, e0, e1), conjSpVals}; |
1419 | if (isa<arith::DivSIOp>(def) && !maybeZero(e1)) |
1420 | return {addExp(k: TensorExp::Kind::kDivS, e0, e1), conjSpVals}; |
1421 | if (isa<arith::DivUIOp>(def) && !maybeZero(e1)) |
1422 | return {addExp(k: TensorExp::Kind::kDivU, e0, e1), conjSpVals}; |
1423 | if (isa<arith::AddFOp>(def)) |
1424 | return {addExp(k: TensorExp::Kind::kAddF, e0, e1), disjSpVals}; |
1425 | if (isa<complex::AddOp>(def)) |
1426 | return {addExp(k: TensorExp::Kind::kAddC, e0, e1), disjSpVals}; |
1427 | if (isa<arith::AddIOp>(def)) |
1428 | return {addExp(k: TensorExp::Kind::kAddI, e0, e1), disjSpVals}; |
1429 | if (isa<arith::SubFOp>(def)) |
1430 | return {addExp(k: TensorExp::Kind::kSubF, e0, e1), disjSpVals}; |
1431 | if (isa<complex::SubOp>(def)) |
1432 | return {addExp(k: TensorExp::Kind::kSubC, e0, e1), disjSpVals}; |
1433 | if (isa<arith::SubIOp>(def)) |
1434 | return {addExp(k: TensorExp::Kind::kSubI, e0, e1), disjSpVals}; |
1435 | if (isa<arith::AndIOp>(def)) |
1436 | return {addExp(k: TensorExp::Kind::kAndI, e0, e1), conjSpVals}; |
1437 | if (isa<arith::OrIOp>(def)) |
1438 | return {addExp(k: TensorExp::Kind::kOrI, e0, e1), disjSpVals}; |
1439 | if (isa<arith::XOrIOp>(def)) |
1440 | return {addExp(k: TensorExp::Kind::kXorI, e0, e1), disjSpVals}; |
1441 | if (isa<arith::ShRSIOp>(def) && isInvariant(e1)) |
1442 | return {addExp(k: TensorExp::Kind::kShrS, e0, e1), conjSpVals}; |
1443 | if (isa<arith::ShRUIOp>(def) && isInvariant(e1)) |
1444 | return {addExp(k: TensorExp::Kind::kShrU, e0, e1), conjSpVals}; |
1445 | if (isa<arith::ShLIOp>(def) && isInvariant(e1)) |
1446 | return {addExp(k: TensorExp::Kind::kShlI, e0, e1), conjSpVals}; |
      if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
        if (ci.getPredicate() == arith::CmpIPredicate::eq ||
            ci.getPredicate() == arith::CmpIPredicate::sle ||
            ci.getPredicate() == arith::CmpIPredicate::sge ||
            ci.getPredicate() == arith::CmpIPredicate::ule ||
            ci.getPredicate() == arith::CmpIPredicate::uge) {
          // We cannot sparsify such comparisons, since, e.g., 0 <= 0 yields
          // true and thus densifies the result.
          return {std::nullopt, false};
        }

        auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
                        ci.getPredicateAttr());
        return {e, conjSpVals};
      }
      if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
        if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
            cf.getPredicate() == arith::CmpFPredicate::OGE ||
            cf.getPredicate() == arith::CmpFPredicate::OLE ||
            cf.getPredicate() == arith::CmpFPredicate::ONE ||
            cf.getPredicate() == arith::CmpFPredicate::UEQ ||
            cf.getPredicate() == arith::CmpFPredicate::UGE ||
            cf.getPredicate() == arith::CmpFPredicate::ULE ||
            cf.getPredicate() == arith::CmpFPredicate::ORD ||
            cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify such comparisons, since, e.g., 0 <= 0 yields
          // true and thus densifies the result.
          return {std::nullopt, false};
        }
        auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
                        cf.getPredicateAttr());
        return {e, conjSpVals};
      }
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
      }
    }
  }

  // Construct ternary operations if subexpressions can be built.
  if (def->getNumOperands() == 3) {
    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
    const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
    bool hasSpDep = xDepSp || yDepSp || zDepSp;
    if (x.has_value() && y.has_value() && z.has_value()) {
      const ExprId e0 = *x;
      const ExprId e1 = *y;
      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
      }
      if (auto selop = dyn_cast<arith::SelectOp>(def)) {
        // Recognize an integral or floating-point ReLu(x) = Max(x, 0)
        // operation inside a very specific ternary select operation.
        // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way
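        // For example, a pattern of the schematic form
        //   %c = arith.cmpf ugt, %x, %zero : f32
        //   %r = arith.select %c, %x, %zero : f32
        // (with %zero an all-zero constant) is folded into a single kRelu
        // expression on %x, keeping the comparison predicate in the attr.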
        const auto &cnd = exp(*x);
        if (isGreater(cnd.kind, cnd.attr) &&
            exp(*y).kind == TensorExp::Kind::kTensor &&
            exp(*z).kind == TensorExp::Kind::kInvariant &&
            isCertainZero(exp(*z).val)) {
          const auto &a = exp(cnd.children.e0);
          const auto &b = exp(cnd.children.e1);
          if (a.kind == TensorExp::Kind::kTensor &&
              a.tensor == exp(*y).tensor &&
              b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) {
            return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId,
                           nullptr, cnd.attr),
                    yDepSp};
          }
        }
      }
    }
  }

  // If we reach here, we are dealing with an operation that is not currently
  // sparsifiable. We can still generate code for it if all its operands only
  // have dense dependencies (i.e., all the values are loaded from dense
  // tensors).
  if (def->getNumResults() != 1) // only handle single-result operations
    return {std::nullopt, false};
  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
  // Build all the subexpressions.
  for (Value operand : def->getOperands())
    subExp.push_back(buildTensorExp(op, operand));

  if (llvm::all_of(subExp,
                   [](auto e) { return e.first.has_value() && !e.second; })) {
    // All the subexpressions can be built and have *no* sparse dependencies.
    if (subExp.size() == 2) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      *subExp[1].first, def);
      return {e, false};
    }
    if (subExp.size() == 1) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      detail::kInvalidId, def);
      return {e, false};
    }
  }

  // Cannot build.
  return {std::nullopt, false};
}

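// Clones the block of the given semi-ring region, inlines it at the current
// insertion point with `vals` as block arguments, and returns the value
// produced by its terminating sparse_tensor.yield.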
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of the given region.
  Region tmpRegion;
  IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getSingleResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

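// Builds the "present" branch of a sparse_tensor.unary operation for a stored
// value `v0`, returning an empty Value when either the input or the present
// region is missing.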
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

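// Builds the "overlap" branch of a sparse_tensor.binary operation for stored
// values `v0` and `v1`, returning an empty Value when either input or the
// overlap region is missing.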
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

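// Materializes ReLU(v0) as a compare-against-zero followed by a select,
// reusing the comparison predicate in `attr` that was captured when the
// pattern was recognized.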
static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
                       Attribute attr) {
  Type tp = v0.getType();
  auto zero =
      rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
  Value cmp;
  if (isa<FloatType>(tp)) {
    auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr);
    cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero);
  } else {
    auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr);
    cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero);
  }
  return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero);
}

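// Lowers a single tensor expression to the corresponding arith, math, complex,
// or sparse_tensor operation, using the values `v0` and `v1` already generated
// for its child expressions.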
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case TensorExp::Kind::kAbsF:
    return rewriter.create<math::AbsFOp>(loc, v0);
  case TensorExp::Kind::kAbsC: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kAbsI:
    return rewriter.create<math::AbsIOp>(loc, v0);
  case TensorExp::Kind::kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case TensorExp::Kind::kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case TensorExp::Kind::kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case TensorExp::Kind::kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case TensorExp::Kind::kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case TensorExp::Kind::kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case TensorExp::Kind::kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case TensorExp::Kind::kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case TensorExp::Kind::kRelu:
    return buildRelu(rewriter, loc, v0, expr.attr);
  case TensorExp::Kind::kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case TensorExp::Kind::kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case TensorExp::Kind::kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case TensorExp::Kind::kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case TensorExp::Kind::kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case TensorExp::Kind::kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case TensorExp::Kind::kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case TensorExp::Kind::kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case TensorExp::Kind::kCIm: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kCRe: {
    auto type = cast<ComplexType>(v0.getType());
    auto eltType = cast<FloatType>(type.getElementType());
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case TensorExp::Kind::kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case TensorExp::Kind::kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case TensorExp::Kind::kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case TensorExp::Kind::kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case TensorExp::Kind::kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case TensorExp::Kind::kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case TensorExp::Kind::kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case TensorExp::Kind::kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case TensorExp::Kind::kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case TensorExp::Kind::kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case TensorExp::Kind::kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case TensorExp::Kind::kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case TensorExp::Kind::kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case TensorExp::Kind::kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case TensorExp::Kind::kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case TensorExp::Kind::kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case TensorExp::Kind::kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case TensorExp::Kind::kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case TensorExp::Kind::kCmpI: {
    auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kCmpF: {
    auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
    return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
  }
  case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
  case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
  case TensorExp::Kind::kSelect:
    return insertYieldOp(rewriter, loc,
                         cast<sparse_tensor::SelectOp>(expr.op).getRegion(),
                         {v0});
  case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
  case TensorExp::Kind::kReduce: {
    ReduceOp redOp = cast<ReduceOp>(expr.op);
    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
  }
  case TensorExp::Kind::kDenseOp: {
    Operation *actualOp = expr.op;
    IRMapping mapping;
    mapping.map(actualOp->getOperand(0), v0);
    if (actualOp->getNumOperands() == 2)
      mapping.map(actualOp->getOperand(1), v1);
    return rewriter.clone(*actualOp, mapping)->getResult(0);
  }
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir