1//===- Merger.cpp - Implementation of iteration lattices ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Complex/IR/Complex.h"
12#include "mlir/Dialect/Math/IR/Math.h"
13#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14
15#include "mlir/IR/Operation.h"
16#include "llvm/Support/Debug.h"
17#include <optional>
18
19namespace mlir {
20namespace sparse_tensor {
21
22enum class ExpArity {
23 kNullary,
24 kUnary,
25 kBinary,
26};
27
28static ExpArity getExpArity(TensorExp::Kind k) {
29 switch (k) {
30 // Leaf.
31 case TensorExp::Kind::kTensor:
32 case TensorExp::Kind::kInvariant:
33 case TensorExp::Kind::kLoopVar:
34 case TensorExp::Kind::kSynZero:
35 return ExpArity::kNullary;
36 case TensorExp::Kind::kAbsF:
37 case TensorExp::Kind::kAbsC:
38 case TensorExp::Kind::kAbsI:
39 case TensorExp::Kind::kCeilF:
40 case TensorExp::Kind::kFloorF:
41 case TensorExp::Kind::kSqrtF:
42 case TensorExp::Kind::kSqrtC:
43 case TensorExp::Kind::kExpm1F:
44 case TensorExp::Kind::kExpm1C:
45 case TensorExp::Kind::kLog1pF:
46 case TensorExp::Kind::kLog1pC:
47 case TensorExp::Kind::kSinF:
48 case TensorExp::Kind::kSinC:
49 case TensorExp::Kind::kTanhF:
50 case TensorExp::Kind::kTanhC:
51 case TensorExp::Kind::kTruncF:
52 case TensorExp::Kind::kExtF:
53 case TensorExp::Kind::kCastFS:
54 case TensorExp::Kind::kCastFU:
55 case TensorExp::Kind::kCastSF:
56 case TensorExp::Kind::kCastUF:
57 case TensorExp::Kind::kCastS:
58 case TensorExp::Kind::kCastU:
59 case TensorExp::Kind::kCastIdx:
60 case TensorExp::Kind::kTruncI:
61 case TensorExp::Kind::kCIm:
62 case TensorExp::Kind::kCRe:
63 case TensorExp::Kind::kBitCast:
64 case TensorExp::Kind::kBinaryBranch:
65 case TensorExp::Kind::kUnary:
66 case TensorExp::Kind::kSelect:
67 case TensorExp::Kind::kNegF:
68 case TensorExp::Kind::kNegC:
69 case TensorExp::Kind::kNegI:
70 return ExpArity::kUnary;
71 // Binary operations.
72 case TensorExp::Kind::kDivF:
73 case TensorExp::Kind::kDivC:
74 case TensorExp::Kind::kDivS:
75 case TensorExp::Kind::kDivU:
76 case TensorExp::Kind::kShrS:
77 case TensorExp::Kind::kShrU:
78 case TensorExp::Kind::kShlI:
79 case TensorExp::Kind::kMulF:
80 case TensorExp::Kind::kMulC:
81 case TensorExp::Kind::kMulI:
82 case TensorExp::Kind::kAndI:
83 case TensorExp::Kind::kAddF:
84 case TensorExp::Kind::kAddC:
85 case TensorExp::Kind::kAddI:
86 case TensorExp::Kind::kOrI:
87 case TensorExp::Kind::kXorI:
88 case TensorExp::Kind::kBinary:
89 case TensorExp::Kind::kReduce:
90 case TensorExp::Kind::kSubF:
91 case TensorExp::Kind::kSubC:
92 case TensorExp::Kind::kSubI:
93 case TensorExp::Kind::kCmpF:
94 case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kDenseOp: // kDenseOp can have *at most* two operands
96 return ExpArity::kBinary;
97 }
98 llvm_unreachable("unexpected kind");
99}
100
101//===----------------------------------------------------------------------===//
102// Constructors.
103//===----------------------------------------------------------------------===//
104
105TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
106 Operation *o, Attribute a)
107 : kind(k), val(v), op(o) {
108 switch (kind) {
109 // Leaf.
110 case TensorExp::Kind::kTensor:
111 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
112 tensor = x;
113 return;
114 case TensorExp::Kind::kSynZero:
115 assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
116 return;
117 case TensorExp::Kind::kInvariant:
118 assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
119 return;
120 case TensorExp::Kind::kLoopVar:
121 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
122 loop = x;
123 return;
124 // Unary operations.
125 case TensorExp::Kind::kAbsF:
126 case TensorExp::Kind::kAbsC:
127 case TensorExp::Kind::kAbsI:
128 case TensorExp::Kind::kCeilF:
129 case TensorExp::Kind::kFloorF:
130 case TensorExp::Kind::kSqrtF:
131 case TensorExp::Kind::kSqrtC:
132 case TensorExp::Kind::kExpm1F:
133 case TensorExp::Kind::kExpm1C:
134 case TensorExp::Kind::kLog1pF:
135 case TensorExp::Kind::kLog1pC:
136 case TensorExp::Kind::kSinF:
137 case TensorExp::Kind::kSinC:
138 case TensorExp::Kind::kTanhF:
139 case TensorExp::Kind::kTanhC:
140 case TensorExp::Kind::kNegF:
141 case TensorExp::Kind::kNegC:
142 case TensorExp::Kind::kNegI:
143 case TensorExp::Kind::kCIm:
144 case TensorExp::Kind::kCRe:
145 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
146 children.e0 = x;
147 children.e1 = y;
148 return;
149 case TensorExp::Kind::kTruncF:
150 case TensorExp::Kind::kExtF:
151 case TensorExp::Kind::kCastFS:
152 case TensorExp::Kind::kCastFU:
153 case TensorExp::Kind::kCastSF:
154 case TensorExp::Kind::kCastUF:
155 case TensorExp::Kind::kCastS:
156 case TensorExp::Kind::kCastU:
157 case TensorExp::Kind::kCastIdx:
158 case TensorExp::Kind::kTruncI:
159 case TensorExp::Kind::kBitCast:
160 assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
161 children.e0 = x;
162 children.e1 = y;
163 return;
164 case TensorExp::Kind::kBinaryBranch:
165 case TensorExp::Kind::kSelect:
166 assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
167 children.e0 = x;
168 children.e1 = y;
169 return;
170 case TensorExp::Kind::kUnary:
171 // No assertion on y can be made, as the branching paths involve both
172 // a unary (`mapSet`) and binary (`disjSet`) pathway.
173 assert(x != detail::kInvalidId && !v && o);
174 children.e0 = x;
175 children.e1 = y;
176 return;
177 // Binary operations.
178 case TensorExp::Kind::kMulF:
179 case TensorExp::Kind::kMulC:
180 case TensorExp::Kind::kMulI:
181 case TensorExp::Kind::kDivF:
182 case TensorExp::Kind::kDivC:
183 case TensorExp::Kind::kDivS:
184 case TensorExp::Kind::kDivU:
185 case TensorExp::Kind::kAddF:
186 case TensorExp::Kind::kAddC:
187 case TensorExp::Kind::kAddI:
188 case TensorExp::Kind::kSubF:
189 case TensorExp::Kind::kSubC:
190 case TensorExp::Kind::kSubI:
191 case TensorExp::Kind::kAndI:
192 case TensorExp::Kind::kOrI:
193 case TensorExp::Kind::kXorI:
194 case TensorExp::Kind::kShrS:
195 case TensorExp::Kind::kShrU:
196 case TensorExp::Kind::kShlI:
197 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
198 children.e0 = x;
199 children.e1 = y;
200 return;
201 case TensorExp::Kind::kCmpF:
202 case TensorExp::Kind::kCmpI:
203 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
204 attr = a;
205 children.e0 = x;
206 children.e1 = y;
207 return;
208 case TensorExp::Kind::kBinary:
209 case TensorExp::Kind::kReduce:
210 assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
211 children.e0 = x;
212 children.e1 = y;
213 return;
214 case TensorExp::Kind::kDenseOp:
215 assert(x != detail::kInvalidId && !v && o);
216 children.e0 = x;
217 children.e1 = y;
218 return;
219 }
220 llvm_unreachable("unexpected kind");
221}
222
Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<LevelType>(numLoops, LevelFormat::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}
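
// Illustrative usage sketch (an assumption-laden example, not part of the
// implementation): for a simple elementwise kernel such as x(i) = a(i) * b(i),
// with two inputs and one output, a client could build and optimize the
// iteration lattices roughly as follows (level/type setup is elided):
//
//   Merger merger(/*numInputOutputTensors=*/3, /*numLoops=*/1,
//                 /*maxLvlRank=*/1);
//   ExprId a = merger.addTensorExp(0);                        // tensor a
//   ExprId b = merger.addTensorExp(1);                        // tensor b
//   ExprId m = merger.addExp(TensorExp::Kind::kMulF, a, b);   // a(i) * b(i)
//   LatSetId s = merger.optimizeSet(merger.buildLattices(m, /*i=*/0));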
241
242//===----------------------------------------------------------------------===//
243// Lattice methods.
244//===----------------------------------------------------------------------===//
245
246ExprId Merger::addTensorExp(TensorId t) {
247 assert(isValidTensorId(t));
248 const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
251 return eNew;
252}
253
254ExprId Merger::addLoopVarExp(LoopId i) {
255 assert(isValidLoopId(i));
256 const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
259 return eNew;
260}
261
262ExprId Merger::addInvariantExp(Value v) {
263 const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
266 return eNew;
267}
268
269ExprId Merger::addSynZeroExp() {
270 const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
273 return eNew;
274}
275
276ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
277 Attribute attr) {
278 assert(k > TensorExp::Kind::kLoopVar);
279 const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
281 return eNew;
282}
283
284ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
285 Attribute attr) {
286 assert(k > TensorExp::Kind::kLoopVar);
287 const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
289 return eNew;
290}
291
292LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
293 const LatPointId pNew(latPoints.size());
294 const unsigned size = numLoops * numTensors;
295 const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
297 latPoints[pNew].bits.set(b);
298 return pNew;
299}
300
301LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
302 assert(bits.size() == numLoops * numTensors);
303 const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
305 return pNew;
306}
307
308LatSetId Merger::addSet() {
309 const LatSetId sNew(latSets.size());
310 latSets.emplace_back();
311 return sNew;
312}
313
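
// Computes the conjunction of two lattice points: the loop-index bits of the
// two points are merged by bitwise OR, and a new tensor expression of the
// given kind is formed from the two sub-expressions.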
314LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
315 Operation *op) {
316 TensorExp::Kind kind = exp(e).kind;
317 Attribute attr = exp(e).attr;
318 const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
  latPoints.emplace_back(bits, ne);
325 return pNew;
326}
327
328LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
329 const LatSetId sNew = addSet();
330 auto &setNew = latSets[sNew];
331 for (const LatPointId p0 : set(s0))
332 for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(e, p0, p1, op));
334 return sNew;
335}
336
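
// Computes the disjunction of two lattice sets: all pairwise conjunctions of
// points in s0 and s1 (the overlap case), followed by all points of s0 alone,
// and then all points of s1 alone. For subtraction, the "y only" points are
// first mapped to a unary negation.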
337LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
338 const LatSetId sNew = conjSet(e, s0, s1, op);
339 TensorExp::Kind kind = exp(e).kind;
340
  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
353 return sNew;
354}
355
356LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
357 assert(exp(e).kind == TensorExp::Kind::kCmpI ||
358 exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);
360
361 ExprId e0 = exp(e).children.e0;
362 ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // The lhs and rhs can't both be synthetic zero.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands is already a synthetic zero (i.e., the element
    // is absent in the corresponding operand), then there is no need to
    // build a disjunctive set for it.
    return sNew;
371 }
372
  auto lhsSet = mapBinWithSynZeroSet(e, s0, false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
377 return sNew;
378}
379
380LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
381 bool includeLeft, TensorExp::Kind ltrans,
382 Operation *opleft, bool includeRight,
383 TensorExp::Kind rtrans, Operation *opright) {
  const LatSetId sNew = conjSet(e, s0, s1, orig);
385 // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    latSets[sNew].append(latSets[s1]);
396 }
397 return sNew;
398}
399
400LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
401 Operation *op) {
402 assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
403 TensorExp::Kind::kDenseOp == kind);
404 const LatSetId sNew = addSet();
405 auto &setNew = latSets[sNew];
406 for (const LatPointId p : set(s0)) {
407 const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op)));
409 }
410 return sNew;
411}
412
413LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
414 TensorExp::Kind kind = exp(e).kind;
415 Attribute a = exp(e).attr;
416 assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
417 // Must be a binary operation.
418 const LatSetId sNew = addSet();
419 auto &setNew = latSets[sNew];
420 const ExprId zeroExp = addSynZeroExp();
421 for (const LatPointId p : set(s0)) {
422 const auto &point = latPoints[p];
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
426 }
427 return sNew;
428}
429
430LatSetId Merger::optimizeSet(LatSetId s0) {
431 const LatSetId sNew = addSet();
432 auto &setNew = latSets[sNew];
433 const auto &set0 = set(s0);
434 assert(!set0.empty());
435 const LatPointId p0 = set0[0];
436 for (const LatPointId p1 : set0) {
437 bool add = true;
438 if (p0 != p1) {
439 // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
441 continue;
442 // Check whether this conjunction is already covered.
443 for (const LatPointId p2 : setNew) {
444 assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
446 add = false;
447 break;
448 }
449 }
450 assert(!add || latGT(p0, p1));
451 }
452 if (add)
      setNew.push_back(p1);
454 }
455 for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
457 return sNew;
458}
459
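
// Simplifies the conditions in a conjunction of a given lattice point using
// two basic rules: (1) multiple dense conditions are reduced to a single
// dense condition, and (2) for a *singleton* point, a sparse/dense mix is
// reduced to the sparse (random-access) conditions only.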
460BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e., the last
  // point in the lattice: no other point is strictly below it.
463 bool isSingleton = true;
464 for (const LatPointId p1 : set(s0)) {
465 if (p0 != p1 && latGT(p0, p1)) {
466 isSingleton = false;
467 break;
468 }
469 }
470
471 BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
473 const TensorLoopId be = simple.size();
474 TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // does not have an undefined level type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }
484
  // Now apply the two basic rules. We also iterate over the bits in reverse
  // to always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!lt.hasSparseSemantic()) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
499 return simple;
500}
501
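
// Returns true if lattice point i strictly contains lattice point j, i.e.,
// the bit set of i has more bits set and covers every bit that is set in j.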
502bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
505 assert(bitsi.size() == bitsj.size());
506 if (bitsi.count() > bitsj.count()) {
507 for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
508 if (bitsj[b] && !bitsi[b])
509 return false;
510 return true;
511 }
512 return false;
513}
514
515bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
516 BitVector tmp(latPoints[j].bits);
517 tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
519}
520
521bool Merger::expContainsTensor(ExprId e, TensorId t) const {
522 const auto &expr = exp(e);
523 // First we check `expIsTensor`.
524 if (expr.kind == TensorExp::Kind::kTensor)
525 return expr.tensor == t;
526
  switch (getExpArity(expr.kind)) {
528 case ExpArity::kNullary:
529 return false;
530 case ExpArity::kUnary: {
531 const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
533 }
534 case ExpArity::kBinary: {
535 const ExprId e0 = expr.children.e0;
536 const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
538 }
539 }
540 llvm_unreachable("unexpected arity");
541}
542
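
// Returns true if the expression negates the output tensor, i.e., it contains
// a unary negation applied to a use of the output, or a subtraction whose
// right-hand operand contains a use of the output.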
543bool Merger::hasNegateOnOut(ExprId e) const {
544 const auto &expr = exp(e);
545 switch (expr.kind) {
546 case TensorExp::Kind::kNegF:
547 case TensorExp::Kind::kNegC:
548 case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
550 case TensorExp::Kind::kSubF:
551 case TensorExp::Kind::kSubC:
552 case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
555 case TensorExp::Kind::kDenseOp: {
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
      return hasNegateOnOut(expr.children.e1);
559 return lhsNeg;
560 }
561 default: {
    switch (getExpArity(expr.kind)) {
563 case ExpArity::kNullary:
564 return false;
565 case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
570 }
571 }
572 }
573 llvm_unreachable("unexpected kind");
574}
575
576bool Merger::isSingleCondition(TensorId t, ExprId e) const {
577 assert(isValidTensorId(t));
578 const auto &expr = exp(e);
579 switch (expr.kind) {
580 // Leaf.
581 case TensorExp::Kind::kTensor:
582 return expr.tensor == t;
583 case TensorExp::Kind::kInvariant:
584 case TensorExp::Kind::kLoopVar:
585 case TensorExp::Kind::kSynZero:
586 return false;
587 // Unary operations.
588 case TensorExp::Kind::kAbsF:
589 case TensorExp::Kind::kAbsC:
590 case TensorExp::Kind::kAbsI:
591 case TensorExp::Kind::kCeilF:
592 case TensorExp::Kind::kFloorF:
593 case TensorExp::Kind::kSqrtF:
594 case TensorExp::Kind::kSqrtC:
595 case TensorExp::Kind::kExpm1F:
596 case TensorExp::Kind::kExpm1C:
597 case TensorExp::Kind::kLog1pF:
598 case TensorExp::Kind::kLog1pC:
599 case TensorExp::Kind::kSinF:
600 case TensorExp::Kind::kSinC:
601 case TensorExp::Kind::kTanhF:
602 case TensorExp::Kind::kTanhC:
603 case TensorExp::Kind::kNegF:
604 case TensorExp::Kind::kNegC:
605 case TensorExp::Kind::kNegI:
606 case TensorExp::Kind::kTruncF:
607 case TensorExp::Kind::kExtF:
608 case TensorExp::Kind::kCastFS:
609 case TensorExp::Kind::kCastFU:
610 case TensorExp::Kind::kCastSF:
611 case TensorExp::Kind::kCastUF:
612 case TensorExp::Kind::kCastS:
613 case TensorExp::Kind::kCastU:
614 case TensorExp::Kind::kCastIdx:
615 case TensorExp::Kind::kTruncI:
616 case TensorExp::Kind::kCIm:
617 case TensorExp::Kind::kCRe:
618 case TensorExp::Kind::kBitCast:
619 case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
621 case TensorExp::Kind::kBinaryBranch:
622 case TensorExp::Kind::kSelect:
623 return false;
624 // Binary operations.
625 case TensorExp::Kind::kDivF: // note: x / c only
626 case TensorExp::Kind::kDivC:
627 case TensorExp::Kind::kDivS:
628 case TensorExp::Kind::kDivU:
629 assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
631 case TensorExp::Kind::kShrS: // note: x >> inv only
632 case TensorExp::Kind::kShrU:
633 case TensorExp::Kind::kShlI:
634 assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
636 case TensorExp::Kind::kMulF:
637 case TensorExp::Kind::kMulC:
638 case TensorExp::Kind::kMulI:
639 case TensorExp::Kind::kAndI:
640 case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
646 return false;
647 case TensorExp::Kind::kAddF:
648 case TensorExp::Kind::kAddC:
649 case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
652 case TensorExp::Kind::kSubF:
653 case TensorExp::Kind::kSubC:
654 case TensorExp::Kind::kSubI:
655 case TensorExp::Kind::kOrI:
656 case TensorExp::Kind::kXorI:
657 case TensorExp::Kind::kCmpF:
658 case TensorExp::Kind::kCmpI:
659 case TensorExp::Kind::kBinary:
660 return false;
661 case TensorExp::Kind::kDenseOp:
    // Since the Merger guarantees that all operands of a kDenseOp are dense,
    // the operation must be single-condition.
664 return true;
665 }
666 llvm_unreachable("unexpected kind");
667}
668
669bool Merger::hasAnySparse(const BitVector &bits) const {
670 for (TensorLoopId b : bits.set_bits()) {
671 const auto lt = getLvlType(b);
672 if (lt.hasSparseSemantic())
673 return true;
674 }
675 return hasSparseIdxReduction(bits);
676}
677
678bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
679 for (TensorLoopId b : bits.set_bits())
680 if (isSparseLvlWithNonTrivialIdxExp(b))
681 return true;
682 return false;
683}
684
685#ifndef NDEBUG
686
687//===----------------------------------------------------------------------===//
688// Print methods (for debugging).
689//===----------------------------------------------------------------------===//
690
691static const char *kindToOpSymbol(TensorExp::Kind kind) {
692 switch (kind) {
693 // Leaf.
694 case TensorExp::Kind::kTensor:
695 return "tensor";
696 case TensorExp::Kind::kInvariant:
697 return "invariant";
698 case TensorExp::Kind::kLoopVar:
699 return "index";
700 case TensorExp::Kind::kSynZero:
701 return "0";
702 // Unary operations.
703 case TensorExp::Kind::kAbsF:
704 case TensorExp::Kind::kAbsC:
705 case TensorExp::Kind::kAbsI:
706 return "abs";
707 case TensorExp::Kind::kCeilF:
708 return "ceil";
709 case TensorExp::Kind::kFloorF:
710 return "floor";
711 case TensorExp::Kind::kSqrtF:
712 case TensorExp::Kind::kSqrtC:
713 return "sqrt";
714 case TensorExp::Kind::kExpm1F:
715 case TensorExp::Kind::kExpm1C:
716 return "expm1";
717 case TensorExp::Kind::kLog1pF:
718 case TensorExp::Kind::kLog1pC:
719 return "log1p";
720 case TensorExp::Kind::kSinF:
721 case TensorExp::Kind::kSinC:
722 return "sin";
723 case TensorExp::Kind::kTanhF:
724 case TensorExp::Kind::kTanhC:
725 return "tanh";
726 case TensorExp::Kind::kNegF:
727 case TensorExp::Kind::kNegC:
728 case TensorExp::Kind::kNegI:
729 return "-";
730 case TensorExp::Kind::kTruncF:
731 case TensorExp::Kind::kExtF:
732 case TensorExp::Kind::kCastFS:
733 case TensorExp::Kind::kCastFU:
734 case TensorExp::Kind::kCastSF:
735 case TensorExp::Kind::kCastUF:
736 case TensorExp::Kind::kCastS:
737 case TensorExp::Kind::kCastU:
738 case TensorExp::Kind::kCastIdx:
739 case TensorExp::Kind::kTruncI:
740 case TensorExp::Kind::kCIm:
741 return "complex.im";
742 case TensorExp::Kind::kCRe:
743 return "complex.re";
744 case TensorExp::Kind::kBitCast:
745 return "cast";
746 case TensorExp::Kind::kBinaryBranch:
747 return "binary_branch";
748 case TensorExp::Kind::kUnary:
749 return "unary";
750 case TensorExp::Kind::kSelect:
751 return "select";
752 // Binary operations.
753 case TensorExp::Kind::kMulF:
754 case TensorExp::Kind::kMulC:
755 case TensorExp::Kind::kMulI:
756 return "*";
757 case TensorExp::Kind::kDivF:
758 case TensorExp::Kind::kDivC:
759 case TensorExp::Kind::kDivS:
760 case TensorExp::Kind::kDivU:
761 return "/";
762 case TensorExp::Kind::kAddF:
763 case TensorExp::Kind::kAddC:
764 case TensorExp::Kind::kAddI:
765 return "+";
766 case TensorExp::Kind::kSubF:
767 case TensorExp::Kind::kSubC:
768 case TensorExp::Kind::kSubI:
769 return "-";
770 case TensorExp::Kind::kAndI:
771 return "&";
772 case TensorExp::Kind::kOrI:
773 return "|";
774 case TensorExp::Kind::kXorI:
775 return "^";
776 case TensorExp::Kind::kShrS:
777 return "a>>";
778 case TensorExp::Kind::kShrU:
779 return ">>";
780 case TensorExp::Kind::kShlI:
781 return "<<";
782 case TensorExp::Kind::kCmpF:
783 case TensorExp::Kind::kCmpI:
784 return "cmp";
785 case TensorExp::Kind::kBinary:
786 return "binary";
787 case TensorExp::Kind::kReduce:
788 return "reduce";
789 case TensorExp::Kind::kDenseOp:
790 return "dense";
791 }
792 llvm_unreachable("unexpected kind for symbol");
793}
794
795void Merger::dumpExp(ExprId e) const {
796 const auto &expr = exp(e);
797 switch (expr.kind) {
798 // Leaf.
799 case TensorExp::Kind::kTensor:
800 if (expr.tensor == syntheticTensor)
801 llvm::dbgs() << "synthetic_";
802 else if (expr.tensor == outTensor)
803 llvm::dbgs() << "output_";
804 llvm::dbgs() << "tensor_" << expr.tensor;
805 break;
806 case TensorExp::Kind::kInvariant:
807 llvm::dbgs() << "invariant";
808 break;
809 case TensorExp::Kind::kSynZero:
810 llvm::dbgs() << "0";
811 break;
812 case TensorExp::Kind::kLoopVar:
813 llvm::dbgs() << "loopvar_" << expr.loop;
814 break;
815 // Unary operations.
816 case TensorExp::Kind::kAbsF:
817 case TensorExp::Kind::kAbsC:
818 case TensorExp::Kind::kAbsI:
819 case TensorExp::Kind::kCeilF:
820 case TensorExp::Kind::kFloorF:
821 case TensorExp::Kind::kSqrtF:
822 case TensorExp::Kind::kSqrtC:
823 case TensorExp::Kind::kExpm1F:
824 case TensorExp::Kind::kExpm1C:
825 case TensorExp::Kind::kLog1pF:
826 case TensorExp::Kind::kLog1pC:
827 case TensorExp::Kind::kSinF:
828 case TensorExp::Kind::kSinC:
829 case TensorExp::Kind::kTanhF:
830 case TensorExp::Kind::kTanhC:
831 case TensorExp::Kind::kNegF:
832 case TensorExp::Kind::kNegC:
833 case TensorExp::Kind::kNegI:
834 case TensorExp::Kind::kTruncF:
835 case TensorExp::Kind::kExtF:
836 case TensorExp::Kind::kCastFS:
837 case TensorExp::Kind::kCastFU:
838 case TensorExp::Kind::kCastSF:
839 case TensorExp::Kind::kCastUF:
840 case TensorExp::Kind::kCastS:
841 case TensorExp::Kind::kCastU:
842 case TensorExp::Kind::kCastIdx:
843 case TensorExp::Kind::kTruncI:
844 case TensorExp::Kind::kCIm:
845 case TensorExp::Kind::kCRe:
846 case TensorExp::Kind::kBitCast:
847 case TensorExp::Kind::kBinaryBranch:
848 case TensorExp::Kind::kUnary:
849 case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
852 break;
853 // Binary operations.
854 case TensorExp::Kind::kMulF:
855 case TensorExp::Kind::kMulC:
856 case TensorExp::Kind::kMulI:
857 case TensorExp::Kind::kDivF:
858 case TensorExp::Kind::kDivC:
859 case TensorExp::Kind::kDivS:
860 case TensorExp::Kind::kDivU:
861 case TensorExp::Kind::kAddF:
862 case TensorExp::Kind::kAddC:
863 case TensorExp::Kind::kAddI:
864 case TensorExp::Kind::kSubF:
865 case TensorExp::Kind::kSubC:
866 case TensorExp::Kind::kSubI:
867 case TensorExp::Kind::kAndI:
868 case TensorExp::Kind::kOrI:
869 case TensorExp::Kind::kXorI:
870 case TensorExp::Kind::kShrS:
871 case TensorExp::Kind::kShrU:
872 case TensorExp::Kind::kShlI:
873 case TensorExp::Kind::kCmpF:
874 case TensorExp::Kind::kCmpI:
875 case TensorExp::Kind::kBinary:
876 case TensorExp::Kind::kReduce:
877 case TensorExp::Kind::kDenseOp:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
      llvm::dbgs() << ")";
887 } else {
888 assert(expr.kind == TensorExp::Kind::kDenseOp);
889 }
890 break;
891 }
892}
893
894void Merger::dumpLat(LatPointId p) const {
895 const auto &point = lat(p);
896 llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
902 llvm::dbgs() << " )\n";
903}
904
905void Merger::dumpSet(LatSetId s) const {
906 const auto &ss = set(s);
907 llvm::dbgs() << "{ #" << ss.size() << "\n";
908 for (const LatPointId p : ss) {
909 llvm::dbgs() << " ";
910 dumpLat(p);
911 }
912 llvm::dbgs() << "}\n";
913}
914
915void Merger::dumpBits(const BitVector &bits) const {
916 for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
917 if (bits[b]) {
918 const TensorId t = tensor(b);
919 const LoopId i = loop(b);
920 const auto lt = lvlTypes[t][i];
921 if (isLvlWithNonTrivialIdxExp(b))
922 llvm::dbgs() << " DEP_" << t << "_" << i;
923 else
924 llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
925 }
926 }
927}
928
929#endif // NDEBUG
930
931//===----------------------------------------------------------------------===//
932// Builder methods.
933//===----------------------------------------------------------------------===//
934
935LatSetId Merger::buildLattices(ExprId e, LoopId i) {
936 // NOTE: The `expr` reference will be invalidated by recursive calls
937 // (and any other method that may add new expressions); therefore, the
938 // code below must make sure to copy fields of `expr` into local variables
939 // before making any recursive calls.
940 const auto &expr = exp(e);
941 const TensorExp::Kind kind = expr.kind;
942 switch (kind) {
943 // Leaf.
944 case TensorExp::Kind::kTensor:
945 case TensorExp::Kind::kInvariant:
946 case TensorExp::Kind::kSynZero:
947 case TensorExp::Kind::kLoopVar: {
948 // Either the loop-var is really used in the tensor expression, or it is
949 // set to the undefined loop-var in that level. An invariant expression,
950 // a proper index value, and a truly dynamic sparse output tensor are set
951 // to a synthetic tensor with undefined indices only to ensure the
952 // iteration space is not skipped as a result of their contents.
953 const LatSetId s = addSet();
954 TensorId t = syntheticTensor;
955 if (kind == TensorExp::Kind::kTensor) {
956 t = expr.tensor;
957 if (hasSparseOut && t == outTensor)
958 t = syntheticTensor;
959 }
    latSets[s].push_back(addLat(t, i, e));
961 return s;
962 }
963 // Unary operations.
964 case TensorExp::Kind::kAbsF:
965 case TensorExp::Kind::kAbsC:
966 case TensorExp::Kind::kAbsI:
967 case TensorExp::Kind::kCeilF:
968 case TensorExp::Kind::kFloorF:
969 case TensorExp::Kind::kSqrtF:
970 case TensorExp::Kind::kSqrtC:
971 case TensorExp::Kind::kExpm1F:
972 case TensorExp::Kind::kExpm1C:
973 case TensorExp::Kind::kLog1pF:
974 case TensorExp::Kind::kLog1pC:
975 case TensorExp::Kind::kSinF:
976 case TensorExp::Kind::kSinC:
977 case TensorExp::Kind::kTanhF:
978 case TensorExp::Kind::kTanhC:
979 case TensorExp::Kind::kNegF:
980 case TensorExp::Kind::kNegC:
981 case TensorExp::Kind::kNegI:
982 case TensorExp::Kind::kTruncF:
983 case TensorExp::Kind::kExtF:
984 case TensorExp::Kind::kCastFS:
985 case TensorExp::Kind::kCastFU:
986 case TensorExp::Kind::kCastSF:
987 case TensorExp::Kind::kCastUF:
988 case TensorExp::Kind::kCastS:
989 case TensorExp::Kind::kCastU:
990 case TensorExp::Kind::kCastIdx:
991 case TensorExp::Kind::kTruncI:
992 case TensorExp::Kind::kCIm:
993 case TensorExp::Kind::kCRe:
994 case TensorExp::Kind::kBitCast:
995 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
996 // lattice set of the operand through the operator into a new set.
997 //
998 // -y|!y | y |
999 // --+---+---+
1000 // | 0 |-y |
1001 {
1002 const ExprId e0 = expr.children.e0;
1003 const Value v = expr.val;
      return mapSet(kind, buildLattices(e0, i), v);
1005 }
1006 case TensorExp::Kind::kBinaryBranch:
1007 case TensorExp::Kind::kSelect:
1008 // The left or right half of a binary operation which has already
1009 // been split into separate operations for each region.
1010 {
1011 const ExprId e0 = expr.children.e0;
1012 Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
1014 }
1015 case TensorExp::Kind::kUnary:
1016 // A custom unary operation.
1017 //
1018 // op y| !y | y |
1019 // ----+----------+------------+
1020 // | absent() | present(y) |
1021 {
1022 const ExprId e0 = expr.children.e0;
1023 UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
1025 Region &absentRegion = unop.getAbsentRegion();
1026 if (absentRegion.empty()) {
1027 // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
1029 }
1030 // Use a disjunction with `unop` on the left and the absent value as an
1031 // invariant on the right.
1032 Block &absentBlock = absentRegion.front();
1033 YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
1034 const Value absentVal = absentYield.getSingleResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
1037 }
1038 // Binary operations.
1039 case TensorExp::Kind::kMulF:
1040 case TensorExp::Kind::kMulC:
1041 case TensorExp::Kind::kMulI:
1042 case TensorExp::Kind::kAndI:
1043 // A multiplicative operation only needs to be performed
1044 // for the conjunction of sparse iteration spaces.
1045 //
1046 // x*y|!y | y |
1047 // ---+---+---+
1048 // !x | 0 | 0 |
1049 // x | 0 |x*y|
1050 //
1051 // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
1052 {
1053 const ExprId e0 = expr.children.e0;
1054 const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1056 }
1057 case TensorExp::Kind::kDivF:
1058 case TensorExp::Kind::kDivC:
1059 case TensorExp::Kind::kDivS:
1060 case TensorExp::Kind::kDivU:
1061 // A division is tricky, since 0/0, 0/c, c/0 all have
1062 // specific outcomes for floating-point and integers.
1063 // Thus, we need to traverse the full iteration space.
1064 //
1065 // x/y|!y | y |
1066 // ---+---+---+
1067 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
1068 // x |x/0|x/y| INT: x/0=exception for any x
1069 //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
1074 {
1075 const ExprId e0 = expr.children.e0;
1076 const ExprId e1 = expr.children.e1;
1077 assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1079 }
1080 case TensorExp::Kind::kAddF:
1081 case TensorExp::Kind::kAddC:
1082 case TensorExp::Kind::kAddI:
1083 case TensorExp::Kind::kSubF:
1084 case TensorExp::Kind::kSubC:
1085 case TensorExp::Kind::kSubI:
1086 case TensorExp::Kind::kOrI:
1087 case TensorExp::Kind::kXorI:
1088 // An additive operation needs to be performed
1089 // for the disjunction of sparse iteration spaces.
1090 //
1091 // x+y|!y | y | x-y|!y | y |
1092 // ---+---+---+ ---+---+---+
1093 // !x | 0 | y | !x | 0 |-y |
1094 // x | x |x+y| x | x |x-y|
1095 {
1096 const ExprId e0 = expr.children.e0;
1097 const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1099 }
1100 case TensorExp::Kind::kCmpF:
1101 case TensorExp::Kind::kCmpI:
1102 // A comparison operation needs to be performed
1103 // for the disjunction of sparse iteration spaces.
1104 //
1105 // x < y | !y | y |
1106 // -------+-------+-------+
1107 // !x | 0 | 0 < y |
1108 // x | x < 0 | x < y |
1109 {
1110 const ExprId e0 = expr.children.e0;
1111 const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
1113 }
1114 case TensorExp::Kind::kShrS:
1115 case TensorExp::Kind::kShrU:
1116 case TensorExp::Kind::kShlI:
1117 // A shift operation by an invariant amount (viz. tensor expressions
1118 // can only occur at the left-hand-side of the operator) can be handled
1119 // with the conjunction rule.
1120 {
1121 const ExprId e0 = expr.children.e0;
1122 const ExprId e1 = expr.children.e1;
1123 assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
1125 }
1126 case TensorExp::Kind::kBinary:
1127 // A custom binary operation.
1128 //
1129 // x op y| !y | y |
1130 // ------+---------+--------------+
1131 // !x | empty | right(y) |
1132 // x | left(x) | overlap(x,y) |
1133 {
1134 const ExprId e0 = expr.children.e0;
1135 const ExprId e1 = expr.children.e1;
1136 BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
1139 Region &leftRegion = binop.getLeftRegion();
1140 Region &rightRegion = binop.getRightRegion();
1141 // Left Region.
1142 Operation *leftYield = nullptr;
1143 if (!leftRegion.empty()) {
1144 Block &leftBlock = leftRegion.front();
1145 leftYield = leftBlock.getTerminator();
1146 }
1147 // Right Region.
1148 Operation *rightYield = nullptr;
1149 if (!rightRegion.empty()) {
1150 Block &rightBlock = rightRegion.front();
1151 rightYield = rightBlock.getTerminator();
1152 }
1153 bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
1154 bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
1158 }
1159 case TensorExp::Kind::kReduce:
1160 // A custom reduce operation.
1161 {
1162 const ExprId e0 = expr.children.e0;
1163 const ExprId e1 = expr.children.e1;
1164 Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1166 }
1167 case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use a conjunctive or disjunctive
    // set here: since all the operands of kDenseOp must be dense, a
    // disjunctive set would eventually be optimized into a conjunctive set
    // anyway.
    if (expr.children.e1 == detail::kInvalidId) {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
1181 }
1182 }
1183 llvm_unreachable("unexpected expression kind");
1184}
1185
1186std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
1187 // Build the linalg semantics backward from yield.
1188 Operation *yield = op.getRegion().front().getTerminator();
1189 assert(isa<linalg::YieldOp>(yield));
1190 return buildTensorExp(op, yield->getOperand(0)).first;
1191}
1192
1193/// Only returns false if we are certain this is a nonzero.
1194bool Merger::maybeZero(ExprId e) const {
1195 const auto &expr = exp(e);
1196 if (expr.kind == TensorExp::Kind::kInvariant) {
1197 if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
1198 ArrayAttr arrayAttr = c.getValue();
1199 return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
1200 cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
1201 }
1202 if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
1203 return c.value() == 0;
1204 if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>())
1205 return c.value().isZero();
1206 }
1207 return true;
1208}
1209
1210Type Merger::inferType(ExprId e, Value src) const {
1211 // Obtain the destination type from the cast node.
1212 Type dtp = exp(e).val.getType();
1213 // Inspect source type. For vector types, apply the same
1214 // vectorization to the destination type.
1215 if (auto vtp = dyn_cast<VectorType>(src.getType()))
1216 return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
1217 return dtp;
1218}
1219
/// Ensures that the sparsifier can generate code for the given expression.
1221static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
1222 // Arguments are always admissible.
  if (isa<BlockArgument>(v))
1224 return true;
1225 // Accept index anywhere.
1226 Operation *def = v.getDefiningOp();
1227 if (isa<linalg::IndexOp>(def))
1228 return true;
1229 // Operation defined outside branch.
1230 if (def->getBlock() != block)
1231 return def->getBlock() != op->getBlock(); // invariant?
1232 // Operation defined within branch. Anything is accepted,
1233 // as long as all subexpressions are admissible.
1234 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissibleBranchExp(op, block, def->getOperand(i)))
1236 return false;
1237 return true;
1238}
1239
1240/// Ensures that the sparsifier can generate code for branch.
1241static bool isAdmissibleBranch(Operation *op, Region &region) {
1242 if (region.empty())
1243 return true;
1244 // Build the semi-ring branch semantics backward from yield.
1245 Operation *yield = region.front().getTerminator();
1246 assert(isa<YieldOp>(yield));
  return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
1248}
1249
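
// Recursively builds a tensor expression for the given value within the
// linalg.generic region. Returns the expression id (or std::nullopt if the
// value cannot be handled) together with a flag indicating whether the
// subexpression depends on a sparse tensor.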
1250std::pair<std::optional<ExprId>, bool>
1251Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1252 // Recursion leaves.
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    const TensorId tid = makeTensorId(arg.getArgNumber());
1255 // Any argument of the generic op that is not marked as a scalar
1256 // argument is considered a tensor, indexed by the implicit loop
1257 // bounds. This includes rank-0 tensor arguments.
1258 if (arg.getOwner()->getParentOp() == op) {
1259 OpOperand &t = op->getOpOperand(tid);
1260 bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
1261 if (!op.isScalar(&t))
        return {addTensorExp(tid), hasSpDep};
1263 v = t.get(); // get scalar value
1264 }
1265 // Any other argument (marked as scalar argument for the generic op
1266 // or belonging to an enveloping op) is considered invariant.
1267 return {addInvariantExp(v), /*hasSpDep=*/false};
1268 }
1269 // Something defined outside is invariant.
1270 Operation *def = v.getDefiningOp();
1271 if (def->getBlock() != &op.getRegion().front())
1272 return {addInvariantExp(v), /*hasSpDep=*/false};
1273 // Construct index operations.
1274 if (def->getNumOperands() == 0) {
1275 if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
1277 }
1278
1279 // Construct unary operations if subexpression can be built.
1280 if (def->getNumOperands() == 1) {
1281 const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
1282 if (x.has_value()) {
1283 const ExprId e = *x;
      if (isa<math::AbsFOp>(def))
        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
      if (isa<complex::AbsOp>(def))
        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
      if (isa<math::AbsIOp>(def))
        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
      if (isa<math::CeilOp>(def))
        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
      if (isa<math::FloorOp>(def))
        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
      if (isa<math::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
      if (isa<complex::SqrtOp>(def))
        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
      if (isa<math::ExpM1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
      if (isa<complex::Expm1Op>(def))
        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
      if (isa<math::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
      if (isa<complex::Log1pOp>(def))
        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
      if (isa<math::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
      if (isa<complex::SinOp>(def))
        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
      if (isa<math::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
      if (isa<complex::TanhOp>(def))
        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
      if (isa<arith::NegFOp>(def))
        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
      if (isa<complex::NegOp>(def))
        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
      if (isa<arith::TruncFOp>(def))
        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
      if (isa<arith::ExtFOp>(def))
        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
      if (isa<arith::FPToSIOp>(def))
        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
      if (isa<arith::FPToUIOp>(def))
        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
      if (isa<arith::SIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
      if (isa<arith::UIToFPOp>(def))
        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
      if (isa<arith::ExtSIOp>(def))
        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
      if (isa<arith::ExtUIOp>(def))
        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
      if (isa<arith::IndexCastOp>(def))
        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
      if (isa<arith::TruncIOp>(def))
        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
      if (isa<complex::ImOp>(def))
        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
      if (isa<complex::ReOp>(def))
        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
      if (isa<arith::BitcastOp>(def))
        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
            isAdmissibleBranch(unop, unop.getAbsentRegion()))
          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
      }
      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
        if (isAdmissibleBranch(selop, selop.getRegion()))
          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
      }
1353 }
1354 }
1355 // Construct binary operations if subexpressions can be built.
1356 // See buildLattices() for an explanation of rejecting certain
1357 // division and shift operations.
1358 if (def->getNumOperands() == 2) {
1359 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1360 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1361 bool hasSpDep = xDepSp || yDepSp;
1362 if (x.has_value() && y.has_value()) {
1363 const ExprId e0 = *x;
1364 const ExprId e1 = *y;
      if (isa<arith::MulFOp>(def))
        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
      if (isa<complex::MulOp>(def))
        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
      if (isa<arith::MulIOp>(def))
        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
      if (isa<arith::AddFOp>(def))
        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
      if (isa<complex::AddOp>(def))
        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
      if (isa<arith::AddIOp>(def))
        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
      if (isa<arith::SubFOp>(def))
        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
      if (isa<complex::SubOp>(def))
        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
      if (isa<arith::SubIOp>(def))
        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
      if (isa<arith::AndIOp>(def))
        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
      if (isa<arith::OrIOp>(def))
        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
      if (isa<arith::XOrIOp>(def))
        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
1403 if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1404 if (ci.getPredicate() == arith::CmpIPredicate::eq &&
1405 ci.getPredicate() == arith::CmpIPredicate::sle &&
1406 ci.getPredicate() == arith::CmpIPredicate::sge &&
1407 ci.getPredicate() == arith::CmpIPredicate::ule &&
1408 ci.getPredicate() == arith::CmpIPredicate::uge) {
          // We cannot sparsify a comparison that includes equality, because
          // 0 <= 0 yields true and thus densifies the result.
1411 return {std::nullopt, false};
1412 }
1413
1414 auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
1415 ci.getPredicateAttr());
1416 return {e, hasSpDep};
1417 }
1418 if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1419 if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
1420 cf.getPredicate() == arith::CmpFPredicate::OGE &&
1421 cf.getPredicate() == arith::CmpFPredicate::OLE &&
1422 cf.getPredicate() == arith::CmpFPredicate::ONE &&
1423 cf.getPredicate() == arith::CmpFPredicate::UEQ &&
1424 cf.getPredicate() == arith::CmpFPredicate::UGE &&
1425 cf.getPredicate() == arith::CmpFPredicate::ULE &&
1426 cf.getPredicate() == arith::CmpFPredicate::ORD &&
1427 cf.getPredicate() == arith::CmpFPredicate::UNO) {
          // We cannot sparsify a comparison that includes equality, because
          // 0 <= 0 yields true and thus densifies the result.
1430 return {std::nullopt, false};
1431 }
1432 auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
1433 cf.getPredicateAttr());
1434 return {e, hasSpDep};
1435 }
1436 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1437 if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
1438 (binop.getLeftIdentity() ||
1439 isAdmissibleBranch(binop, binop.getLeftRegion())) &&
1440 (binop.getRightIdentity() ||
1441 isAdmissibleBranch(binop, binop.getRightRegion())))
          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
1443 }
1444 }
1445 }
1446 // Construct ternary operations if subexpressions can be built.
1447 if (def->getNumOperands() == 3) {
1448 const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
1449 const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
1450 const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
1451 bool hasSpDep = xDepSp || yDepSp || zDepSp;
1452 if (x.has_value() && y.has_value() && z.has_value()) {
1453 const ExprId e0 = *x;
1454 const ExprId e1 = *y;
1455 if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
1456 if (isAdmissibleBranch(redop, redop.getRegion()))
          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
1458 }
1459 }
1460 }
1461
1462 // If we reach here, we are dealing with an operation that is not currently
1463 // sparsifiable. We can still generate code for it if all its operands only
1464 // have dense dependencies (i.e., all the values are loaded from dense
1465 // tensors).
1466 if (def->getNumResults() != 1) // only handle single result operation.
1467 return {std::nullopt, false};
1468
  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
  // Build all the subexpressions.
  for (Value operand : def->getOperands())
    subExp.push_back(buildTensorExp(op, operand));

  if (llvm::all_of(subExp,
                   [](auto e) { return e.first.has_value() && !e.second; })) {
    // All the subexpressions can be built and have *no* sparse dependencies.
    if (subExp.size() == 2) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      *subExp[1].first, def);
      return {e, false};
    }
    if (subExp.size() == 1) {
      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
                      detail::kInvalidId, def);
      return {e, false};
    }
1487 }
1488 // Cannot build.
1489 return {std::nullopt, false};
1490}
1491
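/// Clones the given semi-ring region, inlines the clone at the current
/// insertion point with `vals` substituted for the block arguments, and
/// returns the value produced by the region's yield.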
1492static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1493 ValueRange vals) {
1494 // Make a clone of overlap region.
1495 Region tmpRegion;
1496 IRMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1498 Block &clonedBlock = tmpRegion.front();
1499 YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1500 // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getSingleResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
1506 return val;
1507}
1508
1509static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1510 Operation *op, Value v0) {
1511 if (!v0)
1512 // Empty input value must be propagated.
1513 return Value();
1514 UnaryOp unop = cast<UnaryOp>(op);
1515 Region &presentRegion = unop.getPresentRegion();
1516 if (presentRegion.empty())
1517 // Uninitialized Value() will be interpreted as missing data in the
1518 // output.
1519 return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
1521}
1522
1523static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1524 Operation *op, Value v0, Value v1) {
1525 if (!v0 || !v1)
1526 // Empty input values must be propagated.
1527 return Value();
1528 BinaryOp binop = cast<BinaryOp>(op);
1529 Region &overlapRegion = binop.getOverlapRegion();
1530 if (overlapRegion.empty())
1531 // Uninitialized Value() will be interpreted as missing data in the
1532 // output.
1533 return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1535}
1536
1537Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
1538 Value v1) const {
1539 const auto &expr = exp(e);
1540 switch (expr.kind) {
1541 // Leaf.
1542 case TensorExp::Kind::kTensor:
1543 case TensorExp::Kind::kInvariant:
1544 case TensorExp::Kind::kLoopVar:
1545 case TensorExp::Kind::kSynZero:
1546 llvm_unreachable("unexpected non-op");
1547 // Unary operations.
1548 case TensorExp::Kind::kAbsF:
1549 return rewriter.create<math::AbsFOp>(loc, v0);
1550 case TensorExp::Kind::kAbsC: {
1551 auto type = cast<ComplexType>(v0.getType());
1552 auto eltType = cast<FloatType>(type.getElementType());
1553 return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1554 }
1555 case TensorExp::Kind::kAbsI:
1556 return rewriter.create<math::AbsIOp>(loc, v0);
1557 case TensorExp::Kind::kCeilF:
1558 return rewriter.create<math::CeilOp>(loc, v0);
1559 case TensorExp::Kind::kFloorF:
1560 return rewriter.create<math::FloorOp>(loc, v0);
1561 case TensorExp::Kind::kSqrtF:
1562 return rewriter.create<math::SqrtOp>(loc, v0);
1563 case TensorExp::Kind::kSqrtC:
1564 return rewriter.create<complex::SqrtOp>(loc, v0);
1565 case TensorExp::Kind::kExpm1F:
1566 return rewriter.create<math::ExpM1Op>(loc, v0);
1567 case TensorExp::Kind::kExpm1C:
1568 return rewriter.create<complex::Expm1Op>(loc, v0);
1569 case TensorExp::Kind::kLog1pF:
1570 return rewriter.create<math::Log1pOp>(loc, v0);
1571 case TensorExp::Kind::kLog1pC:
1572 return rewriter.create<complex::Log1pOp>(loc, v0);
1573 case TensorExp::Kind::kSinF:
1574 return rewriter.create<math::SinOp>(loc, v0);
1575 case TensorExp::Kind::kSinC:
1576 return rewriter.create<complex::SinOp>(loc, v0);
1577 case TensorExp::Kind::kTanhF:
1578 return rewriter.create<math::TanhOp>(loc, v0);
1579 case TensorExp::Kind::kTanhC:
1580 return rewriter.create<complex::TanhOp>(loc, v0);
1581 case TensorExp::Kind::kNegF:
1582 return rewriter.create<arith::NegFOp>(loc, v0);
1583 case TensorExp::Kind::kNegC:
1584 return rewriter.create<complex::NegOp>(loc, v0);
1585 case TensorExp::Kind::kNegI: // no negi in std
1586 return rewriter.create<arith::SubIOp>(
1587 loc,
1588 rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1589 rewriter.getZeroAttr(v0.getType())),
1590 v0);
1591 case TensorExp::Kind::kTruncF:
1592 return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1593 case TensorExp::Kind::kExtF:
1594 return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1595 case TensorExp::Kind::kCastFS:
1596 return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1597 case TensorExp::Kind::kCastFU:
1598 return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1599 case TensorExp::Kind::kCastSF:
1600 return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1601 case TensorExp::Kind::kCastUF:
1602 return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1603 case TensorExp::Kind::kCastS:
1604 return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1605 case TensorExp::Kind::kCastU:
1606 return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1607 case TensorExp::Kind::kCastIdx:
1608 return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1609 case TensorExp::Kind::kTruncI:
1610 return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1611 case TensorExp::Kind::kCIm: {
1612 auto type = cast<ComplexType>(v0.getType());
1613 auto eltType = cast<FloatType>(type.getElementType());
1614 return rewriter.create<complex::ImOp>(loc, eltType, v0);
1615 }
1616 case TensorExp::Kind::kCRe: {
1617 auto type = cast<ComplexType>(v0.getType());
1618 auto eltType = cast<FloatType>(type.getElementType());
1619 return rewriter.create<complex::ReOp>(loc, eltType, v0);
1620 }
1621 case TensorExp::Kind::kBitCast:
1622 return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1623 // Binary operations.
1624 case TensorExp::Kind::kMulF:
1625 return rewriter.create<arith::MulFOp>(loc, v0, v1);
1626 case TensorExp::Kind::kMulC:
1627 return rewriter.create<complex::MulOp>(loc, v0, v1);
1628 case TensorExp::Kind::kMulI:
1629 return rewriter.create<arith::MulIOp>(loc, v0, v1);
1630 case TensorExp::Kind::kDivF:
1631 return rewriter.create<arith::DivFOp>(loc, v0, v1);
1632 case TensorExp::Kind::kDivC:
1633 return rewriter.create<complex::DivOp>(loc, v0, v1);
1634 case TensorExp::Kind::kDivS:
1635 return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1636 case TensorExp::Kind::kDivU:
1637 return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1638 case TensorExp::Kind::kAddF:
1639 return rewriter.create<arith::AddFOp>(loc, v0, v1);
1640 case TensorExp::Kind::kAddC:
1641 return rewriter.create<complex::AddOp>(loc, v0, v1);
1642 case TensorExp::Kind::kAddI:
1643 return rewriter.create<arith::AddIOp>(loc, v0, v1);
1644 case TensorExp::Kind::kSubF:
1645 return rewriter.create<arith::SubFOp>(loc, v0, v1);
1646 case TensorExp::Kind::kSubC:
1647 return rewriter.create<complex::SubOp>(loc, v0, v1);
1648 case TensorExp::Kind::kSubI:
1649 return rewriter.create<arith::SubIOp>(loc, v0, v1);
1650 case TensorExp::Kind::kAndI:
1651 return rewriter.create<arith::AndIOp>(loc, v0, v1);
1652 case TensorExp::Kind::kOrI:
1653 return rewriter.create<arith::OrIOp>(loc, v0, v1);
1654 case TensorExp::Kind::kXorI:
1655 return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1656 case TensorExp::Kind::kShrS:
1657 return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1658 case TensorExp::Kind::kShrU:
1659 return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1660 case TensorExp::Kind::kShlI:
1661 return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1662 case TensorExp::Kind::kCmpI: {
1663 auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr);
1664 return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1);
1665 }
1666 case TensorExp::Kind::kCmpF: {
1667 auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr);
1668 return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1);
1669 }
1670 case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
                         {v0});
1673 case TensorExp::Kind::kUnary:
    return buildUnaryPresent(rewriter, loc, expr.op, v0);
1675 case TensorExp::Kind::kSelect:
1676 return insertYieldOp(rewriter, loc, cast<SelectOp>(expr.op).getRegion(),
1677 {v0});
1678 case TensorExp::Kind::kBinary:
    return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
1680 case TensorExp::Kind::kReduce: {
1681 ReduceOp redOp = cast<ReduceOp>(expr.op);
1682 return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
1683 }
1684 case TensorExp::Kind::kDenseOp: {
1685 Operation *actualOp = expr.op;
1686 IRMapping mapping;
    mapping.map(actualOp->getOperand(0), v0);
    if (actualOp->getNumOperands() == 2)
      mapping.map(actualOp->getOperand(1), v1);
    return rewriter.clone(*actualOp, mapping)->getResult(0);
1691 }
1692 }
1693 llvm_unreachable("unexpected expression kind in build");
1694}
1695
1696} // namespace sparse_tensor
1697} // namespace mlir
1698
