//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

enum class ExpArity {
  kNullary,
  kUnary,
  kBinary,
};

static ExpArity getExpArity(TensorExp::Kind k) {
  switch (k) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return ExpArity::kNullary;
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return ExpArity::kUnary;
  // Binary operations.
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
    return ExpArity::kBinary;
  }
  llvm_unreachable("unexpected kind");
}
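
// Note that kDenseOp reports binary arity even though it may carry only a
// single operand; callers that walk its children (e.g. dumpExp and
// hasNegateOnOut below) therefore still check children.e1 against
// detail::kInvalidId before recursing.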

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o, Attribute a)
    : kind(k), val(v), op(o), attr(a) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    tensor = x;
    return;
  case TensorExp::Kind::kSynZero:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    return;
  case TensorExp::Kind::kInvariant:
    assert(x == detail::kInvalidId && y == detail::kInvalidId && v && !o);
    return;
  case TensorExp::Kind::kLoopVar:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    loop = x;
    return;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    assert(x != detail::kInvalidId && y == detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (`mapSet`) and binary (`disjSet`) pathway.
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && !o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
    assert(x != detail::kInvalidId && y != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  case TensorExp::Kind::kDenseOp:
    assert(x != detail::kInvalidId && !v && o);
    children.e0 = x;
    children.e1 = y;
    return;
  }
  llvm_unreachable("unexpected kind");
}

Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    : outTensor(numInputOutputTensors - 1),
      syntheticTensor(numInputOutputTensors),
      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
      hasSparseOut(false),
      lvlTypes(numTensors,
               std::vector<LevelType>(numLoops, LevelFormat::Undef)),
      loopToLvl(numTensors,
                std::vector<std::optional<Level>>(numLoops, std::nullopt)),
      lvlToLoop(numTensors,
                std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)),
      loopToUnresolvedLvls(numLoops, std::vector<std::optional<LvlLTPair>>(
                                         numTensors, std::nullopt)),
      levelToDependentLoop(numTensors,
                           std::vector<std::vector<LoopCoeffPair>>(
                               maxLvlRank, std::vector<LoopCoeffPair>())),
      loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {}

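// By construction above, the tensor with id `numInputOutputTensors - 1` is
// the output tensor, and the extra id `numInputOutputTensors` denotes the
// synthetic tensor used to keep otherwise empty iteration conditions alive.
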
//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

ExprId Merger::addTensorExp(TensorId t) {
  assert(isValidTensorId(t));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addLoopVarExp(LoopId i) {
  assert(isValidLoopId(i));
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId,
                          Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addInvariantExp(Value v) {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId,
                          detail::kInvalidId, v, nullptr, nullptr);
  return eNew;
}

ExprId Merger::addSynZeroExp() {
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(TensorExp::Kind::kSynZero, detail::kInvalidId,
                          detail::kInvalidId, Value(), nullptr, nullptr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e0, e1, Value(), op, attr);
  return eNew;
}

ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
                      Attribute attr) {
  assert(k > TensorExp::Kind::kLoopVar);
  const ExprId eNew(tensorExps.size());
  tensorExps.emplace_back(k, e, detail::kInvalidId, v, op, attr);
  return eNew;
}

LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
  const LatPointId pNew(latPoints.size());
  const unsigned size = numLoops * numTensors;
  const TensorLoopId b = makeTensorLoopId(t, i);
  latPoints.emplace_back(size, e);
  latPoints[pNew].bits.set(b);
  return pNew;
}

LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
  assert(bits.size() == numLoops * numTensors);
  const LatPointId pNew(latPoints.size());
  latPoints.emplace_back(bits, e);
  return pNew;
}

LatSetId Merger::addSet() {
  const LatSetId sNew(latSets.size());
  latSets.emplace_back();
  return sNew;
}

LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
                           Operation *op) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute attr = exp(e).attr;
  const LatPointId pNew(latPoints.size());
  const auto &point0 = lat(p0);
  const auto &point1 = lat(p1);
  BitVector bits(point0.bits);
  bits |= point1.bits;
  const ExprId ne = addExp(kind, point0.exp, point1.exp, op, attr);
  latPoints.emplace_back(bits, ne);
  return pNew;
}

LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p0 : set(s0))
    for (const LatPointId p1 : set(s1))
      setNew.push_back(conjLat(e, p0, p1, op));
  return sNew;
}

LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {
  const LatSetId sNew = conjSet(e, s0, s1, op);
  TensorExp::Kind kind = exp(e).kind;
  // Followed by all in s0.
  latSets[sNew].append(latSets[s0]);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == TensorExp::Kind::kSubF)
    s1 = mapSet(TensorExp::Kind::kNegF, s1);
  else if (kind == TensorExp::Kind::kSubC)
    s1 = mapSet(TensorExp::Kind::kNegC, s1);
  else if (kind == TensorExp::Kind::kSubI)
    s1 = mapSet(TensorExp::Kind::kNegI, s1);
  // Followed by all in s1.
  latSets[sNew].append(latSets[s1]);
  return sNew;
}
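
// For example, for a subtraction x(i) - y(i), the disjunction built above
// yields, in order: the conjunctions where both operands are present and
// x - y is computed, the points of s0 computing x alone, and the points of
// s1 rewritten through mapSet() into the unary negation -y.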

LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {
  assert(exp(e).kind == TensorExp::Kind::kCmpI ||
         exp(e).kind == TensorExp::Kind::kCmpF);
  const LatSetId sNew = conjSet(e, s0, s1, nullptr);

  ExprId e0 = exp(e).children.e0;
  ExprId e1 = exp(e).children.e1;
  if (exp(e0).kind == TensorExp::Kind::kSynZero ||
      exp(e1).kind == TensorExp::Kind::kSynZero) {
    // lhs and rhs can't be synthetic zero at the same time.
    assert(exp(e0).kind != exp(e1).kind);
    // If one of the operands has already been assigned to zero (the
    // element is absent in the corresponding operand), then we do not
    // need to build a disjunctive set for it.
    return sNew;
  }

  auto lhsSet = mapBinWithSynZeroSet(e, s0, /*lhsZero=*/false);
  auto rhsSet = mapBinWithSynZeroSet(e, s1, /*lhsZero=*/true);
  latSets[sNew].append(latSets[lhsSet]);
  latSets[sNew].append(latSets[rhsSet]);
  return sNew;
}
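
// For a comparison x(i) cmp y(i), this produces lattice points for the
// overlap (x cmp y) plus, via mapBinWithSynZeroSet(), points in which the
// absent side is replaced by a synthetic zero, i.e. x cmp 0 and 0 cmp y.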

LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                          bool includeLeft, TensorExp::Kind ltrans,
                          Operation *opleft, bool includeRight,
                          TensorExp::Kind rtrans, Operation *opright) {
  Attribute a = exp(e).attr;
  const LatSetId sNew = conjSet(e, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft, a);
    latSets[sNew].append(latSets[s0]);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright, a);
    latSets[sNew].append(latSets[s1]);
  }
  return sNew;
}

LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op, Attribute a) {
  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
         TensorExp::Kind::kDenseOp == kind);
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op, a)));
  }
  return sNew;
}

LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {
  TensorExp::Kind kind = exp(e).kind;
  Attribute a = exp(e).attr;
  assert(TensorExp::Kind::kMulF <= kind && kind <= TensorExp::Kind::kShlI);
  // Must be a binary operation.
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const ExprId zeroExp = addSynZeroExp();
  for (const LatPointId p : set(s0)) {
    const auto &point = latPoints[p];
    ExprId newExp = lhsZero ? addExp(kind, zeroExp, point.exp, nullptr, a)
                            : addExp(kind, point.exp, zeroExp, nullptr, a);
    setNew.push_back(addLat(point.bits, newExp));
  }
  return sNew;
}

LatSetId Merger::optimizeSet(LatSetId s0) {
  const LatSetId sNew = addSet();
  auto &setNew = latSets[sNew];
  const auto &set0 = set(s0);
  assert(!set0.empty());
  const LatPointId p0 = set0[0];
  for (const LatPointId p1 : set0) {
    bool add = true;
    if (p0 != p1) {
      // Check whether this is a straightforward copy.
      if (expIsTensor(latPoints[p1].exp, outTensor))
        continue;
      // Check whether this conjunction is already covered.
      for (const LatPointId p2 : setNew) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      setNew.push_back(p1);
  }
  for (const LatPointId p : setNew)
    latPoints[p].simple = simplifyCond(sNew, p);
  return sNew;
}

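// Simplifies the conditions in a conjunction of a given lattice point within
// the given set using two basic rules: (1) multiple dense conditions are
// reduced to a single dense condition, and (2) a *singleton* sparse/dense
// condition is reduced to sparse/random access.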
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (const LatPointId p1 : set(s0)) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }

  BitVector simple(latPoints[p0].bits);
  bool reset = isSingleton && hasAnySparse(simple);
  const TensorLoopId be = simple.size();
  TensorLoopId offset = 0; // relative to the end
  if (!reset)
    // Start resetting from a dense level, so that the first bit (if kept)
    // is not of undefined level-type.
    for (unsigned b = 0; b < be; b++) {
      if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
        offset = be - b - 1; // relative to the end
        break;
      }
    }

  // Now apply the two basic rules. We also iterate the bits in reverse to
  // always keep the rightmost bit (which could possibly be a synthetic
  // tensor).
  for (unsigned b = be - 1 - offset, i = 0; i < be;
       b = b == 0 ? be - 1 : b - 1, i++) {
    // A slice on a dense level has the `locate` property as well, and can be
    // optimized.
    if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
      const auto lt = getLvlType(b);
      if (!lt.hasSparseSemantic()) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
  }
  return simple;
}

bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

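// Returns true if the two lattice points differ only in dense conditions,
// in which case the second point adds no sparse information and can be
// pruned by optimizeSet().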
bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  BitVector tmp(latPoints[j].bits);
  tmp ^= latPoints[i].bits;
  return !hasAnySparse(tmp);
}

bool Merger::expContainsTensor(ExprId e, TensorId t) const {
  const auto &expr = exp(e);
  // First we check `expIsTensor`.
  if (expr.kind == TensorExp::Kind::kTensor)
    return expr.tensor == t;

  switch (getExpArity(expr.kind)) {
  case ExpArity::kNullary:
    return false;
  case ExpArity::kUnary: {
    const ExprId e0 = expr.children.e0;
    return expContainsTensor(e0, t);
  }
  case ExpArity::kBinary: {
    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    return expContainsTensor(e0, t) || expContainsTensor(e1, t);
  }
  }
  llvm_unreachable("unexpected arity");
}

bool Merger::hasNegateOnOut(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return expContainsTensor(expr.children.e0, outTensor);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return expContainsTensor(expr.children.e1, outTensor) ||
           hasNegateOnOut(expr.children.e0);
  case TensorExp::Kind::kDenseOp: {
    bool lhsNeg = hasNegateOnOut(expr.children.e0);
    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
      return hasNegateOnOut(expr.children.e1);
    return lhsNeg;
  }
  default: {
    switch (getExpArity(expr.kind)) {
    case ExpArity::kNullary:
      return false;
    case ExpArity::kUnary:
      return hasNegateOnOut(expr.children.e0);
    case ExpArity::kBinary:
      return hasNegateOnOut(expr.children.e0) ||
             hasNegateOnOut(expr.children.e1);
    }
  }
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::isSingleCondition(TensorId t, ExprId e) const {
  assert(isValidTensorId(t));
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return expr.tensor == t;
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kLoopVar:
  case TensorExp::Kind::kSynZero:
    return false;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kUnary:
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    return false;
  // Binary operations.
  case TensorExp::Kind::kDivF: // note: x / c only
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    assert(!maybeZero(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kShrS: // note: x >> inv only
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    assert(isInvariant(expr.children.e1));
    return isSingleCondition(t, expr.children.e0);
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kReduce:
    if (isSingleCondition(t, expr.children.e0))
      return isSingleCondition(t, expr.children.e1) ||
             isInvariant(expr.children.e1);
    if (isSingleCondition(t, expr.children.e1))
      return isInvariant(expr.children.e0);
    return false;
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return isSingleCondition(t, expr.children.e0) &&
           isSingleCondition(t, expr.children.e1);
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
    return false;
  case TensorExp::Kind::kDenseOp:
    // Since the Merger guarantees that all operands of a kDenseOp are dense,
    // the operation must be single-condition.
    return true;
  }
  llvm_unreachable("unexpected kind");
}

bool Merger::hasAnySparse(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits()) {
    const auto lt = getLvlType(b);
    if (lt.hasSparseSemantic())
      return true;
  }
  return hasSparseIdxReduction(bits);
}

bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b : bits.set_bits())
    if (isSparseLvlWithNonTrivialIdxExp(b))
      return true;
  return false;
}
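
// A minimal sketch (simplified, and not the actual driver) of how a client
// such as the sparsifier typically uses the Merger: the tensor expression is
// built backward from the linalg.generic yield, lattices are built per loop,
// and each lattice set is optimized before code emission. Identifiers such
// as `numTensors`, `numLoops`, `maxLvlRank`, and `topLoop` are placeholders.
//
//   Merger merger(numTensors, numLoops, maxLvlRank);
//   // ... per-tensor level types registered on the merger ...
//   if (std::optional<ExprId> e = merger.buildTensorExpFromLinalg(op)) {
//     LatSetId s = merger.buildLattices(*e, topLoop);
//     s = merger.optimizeSet(s); // prune lattice points made redundant
//   }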

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kRelu:
    return "relu";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  case TensorExp::Kind::kDenseOp:
    return "dense";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kDenseOp:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
      llvm::dbgs() << ")";
    } else {
      assert(expr.kind == TensorExp::Kind::kDenseOp);
    }
    break;
  }
}

void Merger::dumpLat(LatPointId p) const {
  const auto &point = lat(p);
  llvm::dbgs() << "lat(";
  dumpBits(point.bits);
  llvm::dbgs() << " :";
  dumpBits(point.simple);
  llvm::dbgs() << " : ";
  dumpExp(point.exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(LatSetId s) const {
  const auto &ss = set(s);
  llvm::dbgs() << "{ #" << ss.size() << "\n";
  for (const LatPointId p : ss) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto lt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

LatSetId Merger::buildLattices(ExprId e, LoopId i) {
  // NOTE: The `expr` reference will be invalidated by recursive calls
  // (and any other method that may add new expressions); therefore, the
  // code below must make sure to copy fields of `expr` into local variables
  // before making any recursive calls.
  const auto &expr = exp(e);
  const TensorExp::Kind kind = expr.kind;
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
  case TensorExp::Kind::kInvariant:
  case TensorExp::Kind::kSynZero:
  case TensorExp::Kind::kLoopVar: {
    // Either the loop-var is really used in the tensor expression, or it is
    // set to the undefined loop-var in that level. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    const LatSetId s = addSet();
    TensorId t = syntheticTensor;
    if (kind == TensorExp::Kind::kTensor) {
      t = expr.tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    {
      const ExprId e0 = expr.children.e0;
      const Value v = expr.val;
      Attribute a = expr.attr;
      return mapSet(kind, buildLattices(e0, i), v, nullptr, a);
    }
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kSelect:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }
  case TensorExp::Kind::kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      const ExprId e0 = expr.children.e0;
      UnaryOp unop = cast<UnaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      Region &absentRegion = unop.getAbsentRegion();
      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      const Value absentVal = absentYield.getSingleResult();
      const ExprId rhs = addInvariantExp(absentVal);
      return disjSet(e, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|   INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(!maybeZero(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    // A comparison operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //   x < y |  !y   |   y   |
    //  -------+-------+-------+
    //    !x   |   0   | 0 < y |
    //     x   | x < 0 | x < y |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      return disjSetWithZero(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      assert(isInvariant(e1));
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i));
    }
  case TensorExp::Kind::kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |      y       |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      BinaryOp binop = cast<BinaryOp>(expr.op);
      const LatSetId child0 = buildLattices(e0, i);
      const LatSetId child1 = buildLattices(e1, i);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return combiSet(e, child0, child1, binop, includeLeft,
                      TensorExp::Kind::kBinaryBranch, leftYield, includeRight,
                      TensorExp::Kind::kBinaryBranch, rightYield);
    }
  case TensorExp::Kind::kReduce:
    // A custom reduce operation.
    {
      const ExprId e0 = expr.children.e0;
      const ExprId e1 = expr.children.e1;
      Operation *const op = expr.op;
      return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
    }
  case TensorExp::Kind::kDenseOp: {
    // It does not really matter whether we use a conjunctive or disjunctive
    // set here: since all the operands of a kDenseOp are dense, the
    // disjunctive set will be optimized into a conjunctive set eventually.
    if (expr.children.e1 == detail::kInvalidId) {
      const ExprId e0 = expr.children.e0;
      Operation *const op = expr.op;
      return mapSet(kind, buildLattices(e0, i), Value(), op);
    }

    const ExprId e0 = expr.children.e0;
    const ExprId e1 = expr.children.e1;
    Operation *const op = expr.op;
    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
  }
  }
  llvm_unreachable("unexpected expression kind");
}
| 1192 | |
| 1193 | std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { |
| 1194 | // Build the linalg semantics backward from yield. |
| 1195 | Operation *yield = op.getRegion().front().getTerminator(); |
| 1196 | assert(isa<linalg::YieldOp>(yield)); |
| 1197 | return buildTensorExp(op: op, v: yield->getOperand(idx: 0)).first; |
| 1198 | } |
| 1199 | |
| 1200 | /// Only returns true if we are certain this is a zero. |
| 1201 | static bool isCertainZero(Value val) { |
| 1202 | if (auto c = val.getDefiningOp<complex::ConstantOp>()) { |
| 1203 | ArrayAttr arrayAttr = c.getValue(); |
| 1204 | return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
| 1205 | cast<FloatAttr>(arrayAttr[1]).getValue().isZero(); |
| 1206 | } |
| 1207 | if (auto c = val.getDefiningOp<arith::ConstantIntOp>()) |
| 1208 | return c.value() == 0; |
| 1209 | if (auto c = val.getDefiningOp<arith::ConstantFloatOp>()) |
| 1210 | return c.value().isZero(); |
| 1211 | return false; |
| 1212 | } |
| 1213 | |
| 1214 | /// Only returns false if we are certain this is a nonzero. |
| 1215 | bool Merger::maybeZero(ExprId e) const { |
| 1216 | const auto &expr = exp(e); |
| 1217 | if (expr.kind == TensorExp::Kind::kInvariant) { |
| 1218 | // Note that this is different from isCertainZero() in a subtle |
| 1219 | // way by always returning true for non-constants. |
| 1220 | if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) { |
| 1221 | ArrayAttr arrayAttr = c.getValue(); |
| 1222 | return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
| 1223 | cast<FloatAttr>(arrayAttr[1]).getValue().isZero(); |
| 1224 | } |
| 1225 | if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>()) |
| 1226 | return c.value() == 0; |
| 1227 | if (auto c = expr.val.getDefiningOp<arith::ConstantFloatOp>()) |
| 1228 | return c.value().isZero(); |
| 1229 | } |
| 1230 | return true; |
| 1231 | } |
| 1232 | |
| 1233 | Type Merger::inferType(ExprId e, Value src) const { |
| 1234 | // Obtain the destination type from the cast node. |
| 1235 | Type dtp = exp(e).val.getType(); |
| 1236 | // Inspect source type. For vector types, apply the same |
| 1237 | // vectorization to the destination type. |
| 1238 | if (auto vtp = dyn_cast<VectorType>(src.getType())) |
| 1239 | return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims()); |
| 1240 | return dtp; |
| 1241 | } |
| 1242 | |
| 1243 | /// Ensures that the sparsifier can generate code for expression. |
| 1244 | static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { |
| 1245 | // Arguments are always admissible. |
| 1246 | if (isa<BlockArgument>(Val: v)) |
| 1247 | return true; |
| 1248 | // Accept index anywhere. |
| 1249 | Operation *def = v.getDefiningOp(); |
| 1250 | if (isa<linalg::IndexOp>(def)) |
| 1251 | return true; |
| 1252 | // Operation defined outside branch. |
| 1253 | if (def->getBlock() != block) |
| 1254 | return def->getBlock() != op->getBlock(); // invariant? |
| 1255 | // Operation defined within branch. Anything is accepted, |
| 1256 | // as long as all subexpressions are admissible. |
| 1257 | for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) |
| 1258 | if (!isAdmissibleBranchExp(op, block, v: def->getOperand(idx: i))) |
| 1259 | return false; |
| 1260 | return true; |
| 1261 | } |
| 1262 | |
| 1263 | /// Ensures that the sparsifier can generate code for branch. |
| 1264 | static bool isAdmissibleBranch(Operation *op, Region ®ion) { |
| 1265 | if (region.empty()) |
| 1266 | return true; |
| 1267 | // Build the semi-ring branch semantics backward from yield. |
| 1268 | Operation *yield = region.front().getTerminator(); |
| 1269 | assert(isa<YieldOp>(yield)); |
| 1270 | return isAdmissibleBranchExp(op, block: ®ion.front(), v: yield->getOperand(idx: 0)); |
| 1271 | } |
| 1272 | |
| 1273 | // Recognizes a direct GT comparison. |
| 1274 | static bool isGreater(TensorExp::Kind kind, Attribute attr) { |
| 1275 | if (kind == TensorExp::Kind::kCmpI) { |
| 1276 | auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr).getValue(); |
| 1277 | return pred == arith::CmpIPredicate::ugt || |
| 1278 | pred == arith::CmpIPredicate::sgt; |
| 1279 | } |
| 1280 | if (kind == TensorExp::Kind::kCmpF) { |
| 1281 | auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr).getValue(); |
| 1282 | return pred == arith::CmpFPredicate::UGT || |
| 1283 | pred == arith::CmpFPredicate::OGT; |
| 1284 | } |
| 1285 | return false; |
| 1286 | } |
| 1287 | |
| 1288 | std::pair<std::optional<ExprId>, bool> |
| 1289 | Merger::buildTensorExp(linalg::GenericOp op, Value v) { |
| 1290 | // Recursion leaves. |
| 1291 | if (auto arg = dyn_cast<BlockArgument>(Val&: v)) { |
| 1292 | const TensorId tid = makeTensorId(t: arg.getArgNumber()); |
| 1293 | // Any argument of the generic op that is not marked as a scalar |
| 1294 | // argument is considered a tensor, indexed by the implicit loop |
| 1295 | // bounds. This includes rank-0 tensor arguments. |
| 1296 | if (arg.getOwner()->getParentOp() == op) { |
| 1297 | OpOperand &t = op->getOpOperand(tid); |
| 1298 | bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr; |
| 1299 | if (!op.isScalar(&t)) |
| 1300 | return {addTensorExp(t: tid), hasSpDep}; |
| 1301 | v = t.get(); // get scalar value |
| 1302 | } |
| 1303 | // Any other argument (marked as scalar argument for the generic op |
| 1304 | // or belonging to an enveloping op) is considered invariant. |
| 1305 | return {addInvariantExp(v), /*hasSpDep=*/false}; |
| 1306 | } |
| 1307 | |
| 1308 | // Something defined outside is invariant. |
| 1309 | Operation *def = v.getDefiningOp(); |
| 1310 | if (def->getBlock() != &op.getRegion().front()) |
| 1311 | return {addInvariantExp(v), /*hasSpDep=*/false}; |
| 1312 | // Construct index operations. |
| 1313 | if (def->getNumOperands() == 0) { |
| 1314 | if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) |
| 1315 | return {addLoopVarExp(i: makeLoopId(i: indexOp.getDim())), /*hasSpDep=*/false}; |
| 1316 | } |
| 1317 | |
| 1318 | // Construct unary operations if subexpression can be built. |
| 1319 | if (def->getNumOperands() == 1) { |
| 1320 | const auto [x, hasSpDep] = buildTensorExp(op: op, v: def->getOperand(idx: 0)); |
| 1321 | if (x.has_value()) { |
| 1322 | const ExprId e = *x; |
| 1323 | if (isa<math::AbsFOp>(def)) |
| 1324 | return {addExp(k: TensorExp::Kind::kAbsF, e0: e), hasSpDep}; |
| 1325 | if (isa<complex::AbsOp>(def)) |
| 1326 | return {addExp(k: TensorExp::Kind::kAbsC, e0: e), hasSpDep}; |
| 1327 | if (isa<math::AbsIOp>(def)) |
| 1328 | return {addExp(k: TensorExp::Kind::kAbsI, e0: e), hasSpDep}; |
| 1329 | if (isa<math::CeilOp>(def)) |
| 1330 | return {addExp(k: TensorExp::Kind::kCeilF, e0: e), hasSpDep}; |
| 1331 | if (isa<math::FloorOp>(def)) |
| 1332 | return {addExp(k: TensorExp::Kind::kFloorF, e0: e), hasSpDep}; |
| 1333 | if (isa<math::SqrtOp>(def)) |
| 1334 | return {addExp(k: TensorExp::Kind::kSqrtF, e0: e), hasSpDep}; |
| 1335 | if (isa<complex::SqrtOp>(def)) |
| 1336 | return {addExp(k: TensorExp::Kind::kSqrtC, e0: e), hasSpDep}; |
| 1337 | if (isa<math::ExpM1Op>(def)) |
| 1338 | return {addExp(k: TensorExp::Kind::kExpm1F, e0: e), hasSpDep}; |
| 1339 | if (isa<complex::Expm1Op>(def)) |
| 1340 | return {addExp(k: TensorExp::Kind::kExpm1C, e0: e), hasSpDep}; |
| 1341 | if (isa<math::Log1pOp>(def)) |
| 1342 | return {addExp(k: TensorExp::Kind::kLog1pF, e0: e), hasSpDep}; |
| 1343 | if (isa<complex::Log1pOp>(def)) |
| 1344 | return {addExp(k: TensorExp::Kind::kLog1pC, e0: e), hasSpDep}; |
| 1345 | if (isa<math::SinOp>(def)) |
| 1346 | return {addExp(k: TensorExp::Kind::kSinF, e0: e), hasSpDep}; |
| 1347 | if (isa<complex::SinOp>(def)) |
| 1348 | return {addExp(k: TensorExp::Kind::kSinC, e0: e), hasSpDep}; |
| 1349 | if (isa<math::TanhOp>(def)) |
| 1350 | return {addExp(k: TensorExp::Kind::kTanhF, e0: e), hasSpDep}; |
| 1351 | if (isa<complex::TanhOp>(def)) |
| 1352 | return {addExp(k: TensorExp::Kind::kTanhC, e0: e), hasSpDep}; |
| 1353 | if (isa<arith::NegFOp>(def)) |
| 1354 | return {addExp(k: TensorExp::Kind::kNegF, e0: e), hasSpDep}; // no negi in std |
| 1355 | if (isa<complex::NegOp>(def)) |
| 1356 | return {addExp(k: TensorExp::Kind::kNegC, e0: e), hasSpDep}; |
| 1357 | if (isa<arith::TruncFOp>(def)) |
| 1358 | return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
| 1359 | if (isa<arith::ExtFOp>(def))
| 1360 | return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
| 1361 | if (isa<arith::FPToSIOp>(def))
| 1362 | return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
| 1363 | if (isa<arith::FPToUIOp>(def))
| 1364 | return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
| 1365 | if (isa<arith::SIToFPOp>(def))
| 1366 | return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
| 1367 | if (isa<arith::UIToFPOp>(def))
| 1368 | return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
| 1369 | if (isa<arith::ExtSIOp>(def))
| 1370 | return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
| 1371 | if (isa<arith::ExtUIOp>(def))
| 1372 | return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
| 1373 | if (isa<arith::IndexCastOp>(def))
| 1374 | return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
| 1375 | if (isa<arith::TruncIOp>(def))
| 1376 | return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
| 1377 | if (isa<complex::ImOp>(def))
| 1378 | return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
| 1379 | if (isa<complex::ReOp>(def))
| 1380 | return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
| 1381 | if (isa<arith::BitcastOp>(def))
| 1382 | return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
| 1383 | if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) { |
| 1384 | if (isAdmissibleBranch(unop, unop.getPresentRegion()) && |
| 1385 | isAdmissibleBranch(unop, unop.getAbsentRegion())) |
| 1386 | return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
| 1387 | } |
| 1388 | if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) { |
| 1389 | if (isAdmissibleBranch(selop, selop.getRegion())) |
| 1390 | return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
| 1391 | } |
| 1392 | } |
| 1393 | } |
| 1394 | |
| 1395 | // Construct binary operations if subexpressions can be built. |
| 1396 | // See buildLattices() for an explanation of rejecting certain |
| 1397 | // division and shift operations. |
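| | // For example, divisions are only accepted when the divisor cannot be
| | // zero (see maybeZero below), and shifts only when the shift amount is
| | // invariant (see isInvariant below).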
| 1398 | if (def->getNumOperands() == 2) { |
| 1399 | const auto [x, xSpVals] = buildTensorExp(op, def->getOperand(0));
| 1400 | const auto [y, ySpVals] = buildTensorExp(op, def->getOperand(1));
| 1401 | // For a conjunctive operation, it yields a "sparse" result if any operand |
| 1402 | // is sparse. For a disjunctive operation, it yields a "sparse" result if |
| 1403 | // all operands are sparse. |
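| | // E.g. a * b is zero wherever either operand is zero, so one sparse
| | // operand already makes the product sparse, whereas a + b is nonzero
| | // wherever either operand is, so the sum stays sparse only if all
| | // operands are sparse.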
| 1404 | bool conjSpVals = xSpVals || ySpVals; |
| 1405 | bool disjSpVals = xSpVals && ySpVals; |
| 1406 | if (x.has_value() && y.has_value()) { |
| 1407 | const ExprId e0 = *x; |
| 1408 | const ExprId e1 = *y; |
| 1409 | if (isa<arith::MulFOp>(def)) |
| 1410 | return {addExp(TensorExp::Kind::kMulF, e0, e1), conjSpVals};
| 1411 | if (isa<complex::MulOp>(def))
| 1412 | return {addExp(TensorExp::Kind::kMulC, e0, e1), conjSpVals};
| 1413 | if (isa<arith::MulIOp>(def))
| 1414 | return {addExp(TensorExp::Kind::kMulI, e0, e1), conjSpVals};
| 1415 | if (isa<arith::DivFOp>(def) && !maybeZero(e1))
| 1416 | return {addExp(TensorExp::Kind::kDivF, e0, e1), conjSpVals};
| 1417 | if (isa<complex::DivOp>(def) && !maybeZero(e1))
| 1418 | return {addExp(TensorExp::Kind::kDivC, e0, e1), conjSpVals};
| 1419 | if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
| 1420 | return {addExp(TensorExp::Kind::kDivS, e0, e1), conjSpVals};
| 1421 | if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
| 1422 | return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
| 1423 | if (isa<arith::AddFOp>(def))
| 1424 | return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
| 1425 | if (isa<complex::AddOp>(def))
| 1426 | return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
| 1427 | if (isa<arith::AddIOp>(def))
| 1428 | return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
| 1429 | if (isa<arith::SubFOp>(def))
| 1430 | return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
| 1431 | if (isa<complex::SubOp>(def))
| 1432 | return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
| 1433 | if (isa<arith::SubIOp>(def))
| 1434 | return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
| 1435 | if (isa<arith::AndIOp>(def))
| 1436 | return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
| 1437 | if (isa<arith::OrIOp>(def))
| 1438 | return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
| 1439 | if (isa<arith::XOrIOp>(def))
| 1440 | return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
| 1441 | if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
| 1442 | return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
| 1443 | if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
| 1444 | return {addExp(TensorExp::Kind::kShrU, e0, e1), conjSpVals};
| 1445 | if (isa<arith::ShLIOp>(def) && isInvariant(e1))
| 1446 | return {addExp(TensorExp::Kind::kShlI, e0, e1), conjSpVals};
| 1447 | if (auto ci = dyn_cast<arith::CmpIOp>(def)) { |
| 1448 | if (ci.getPredicate() == arith::CmpIPredicate::eq ||
| 1449 | ci.getPredicate() == arith::CmpIPredicate::sle ||
| 1450 | ci.getPredicate() == arith::CmpIPredicate::sge ||
| 1451 | ci.getPredicate() == arith::CmpIPredicate::ule ||
| 1452 | ci.getPredicate() == arith::CmpIPredicate::uge) {
| 1453 | // We cannot sparsify such comparisons because, for example, 0 <= 0
| 1454 | // yields true and thus densifies the result.
| 1455 | return {std::nullopt, false};
| 1456 | } |
| 1457 | |
| 1458 | auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr, |
| 1459 | ci.getPredicateAttr()); |
| 1460 | return {e, conjSpVals}; |
| 1461 | } |
| 1462 | if (auto cf = dyn_cast<arith::CmpFOp>(def)) { |
| 1463 | if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
| 1464 | cf.getPredicate() == arith::CmpFPredicate::OGE ||
| 1465 | cf.getPredicate() == arith::CmpFPredicate::OLE ||
| 1466 | cf.getPredicate() == arith::CmpFPredicate::ONE ||
| 1467 | cf.getPredicate() == arith::CmpFPredicate::UEQ ||
| 1468 | cf.getPredicate() == arith::CmpFPredicate::UGE ||
| 1469 | cf.getPredicate() == arith::CmpFPredicate::ULE ||
| 1470 | cf.getPredicate() == arith::CmpFPredicate::ORD ||
| 1471 | cf.getPredicate() == arith::CmpFPredicate::UNO) {
| 1472 | // We cannot sparsify such comparisons because, for example, 0 <= 0
| 1473 | // yields true and thus densifies the result.
| 1474 | return {std::nullopt, false};
| 1475 | } |
| 1476 | auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr, |
| 1477 | cf.getPredicateAttr()); |
| 1478 | return {e, conjSpVals}; |
| 1479 | } |
| 1480 | if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) { |
| 1481 | if (isAdmissibleBranch(binop, binop.getOverlapRegion()) && |
| 1482 | (binop.getLeftIdentity() || |
| 1483 | isAdmissibleBranch(binop, binop.getLeftRegion())) && |
| 1484 | (binop.getRightIdentity() || |
| 1485 | isAdmissibleBranch(binop, binop.getRightRegion()))) |
| 1486 | return {addExp(TensorExp::Kind::kBinary, e0, e1, def), conjSpVals};
| 1487 | } |
| 1488 | } |
| 1489 | } |
| 1490 | |
| 1491 | // Construct ternary operations if subexpressions can be built. |
| 1492 | if (def->getNumOperands() == 3) { |
| 1493 | const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
| 1494 | const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
| 1495 | const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
| 1496 | bool hasSpDep = xDepSp || yDepSp || zDepSp; |
| 1497 | if (x.has_value() && y.has_value() && z.has_value()) { |
| 1498 | const ExprId e0 = *x; |
| 1499 | const ExprId e1 = *y; |
| 1500 | if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) { |
| 1501 | if (isAdmissibleBranch(redop, redop.getRegion())) |
| 1502 | return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
| 1503 | } |
| 1504 | if (auto selop = dyn_cast<arith::SelectOp>(def)) { |
| 1505 | // Recognize an integral or floating-point ReLu(x) = Max(x, 0) |
| 1506 | // operation inside a very specific ternary select operation. |
| 1507 | // TODO: capture MIN/MAX/ABS/RELU structure in a more generic way |
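| | // For instance, with a greater-than predicate and %zero a constant 0.0:
| | //   %c = arith.cmpf ugt, %x, %zero : f32
| | //   %r = arith.select %c, %x, %zero : f32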
| 1508 | const auto &cnd = exp(*x);
| 1509 | if (isGreater(cnd.kind, cnd.attr) &&
| 1510 | exp(*y).kind == TensorExp::Kind::kTensor &&
| 1511 | exp(*z).kind == TensorExp::Kind::kInvariant &&
| 1512 | isCertainZero(exp(*z).val)) {
| 1513 | const auto &a = exp(cnd.children.e0);
| 1514 | const auto &b = exp(cnd.children.e1);
| 1515 | if (a.kind == TensorExp::Kind::kTensor &&
| 1516 | a.tensor == exp(*y).tensor &&
| 1517 | b.kind == TensorExp::Kind::kInvariant && isCertainZero(b.val)) { |
| 1518 | return {addExp(TensorExp::Kind::kRelu, *y, detail::kInvalidId, |
| 1519 | nullptr, cnd.attr), |
| 1520 | yDepSp}; |
| 1521 | } |
| 1522 | } |
| 1523 | } |
| 1524 | } |
| 1525 | } |
| 1526 | |
| 1527 | // If we reach here, we are dealing with an operation that is not currently |
| 1528 | // sparsifiable. We can still generate code for it if all its operands only |
| 1529 | // have dense dependencies (i.e., all the values are loaded from dense |
| 1530 | // tensors). |
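| | // For example, an operation not recognized above (say, math.atan2 applied
| | // to values with only dense dependencies) can still be wrapped in a
| | // kDenseOp expression and cloned verbatim into the generated loops.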
| 1531 | if (def->getNumResults() != 1) // only handle single result operation. |
| 1532 | return {std::nullopt, false}; |
| 1533 | SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp; |
| 1534 | // Build all the sub-expressions.
| 1535 | for (Value operand : def->getOperands())
| 1536 | subExp.push_back(buildTensorExp(op, operand));
| 1537 | |
| 1538 | if (llvm::all_of(subExp,
| 1539 | [](auto e) { return e.first.has_value() && !e.second; })) {
| 1540 | // All the subexpressions can be built and have *no* sparse dependencies.
| 1541 | if (subExp.size() == 2) { |
| 1542 | auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
| 1543 | *subExp[1].first, def);
| 1544 | return {e, false}; |
| 1545 | } |
| 1546 | if (subExp.size() == 1) { |
| 1547 | auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
| 1548 | detail::kInvalidId, def);
| 1549 | return {e, false}; |
| 1550 | } |
| 1551 | } |
| 1552 | |
| 1553 | // Cannot build. |
| 1554 | return {std::nullopt, false}; |
| 1555 | } |
| 1556 | |
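| | /// Clones the given region, inlines the clone at the current insertion
| | /// point with `vals` as the block arguments, and returns the value yielded
| | /// by the cloned block; the cloned yield and the temporary placeholder op
| | /// are erased afterwards.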
| 1557 | static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, |
| 1558 | ValueRange vals) { |
| 1559 | // Make a clone of overlap region. |
| 1560 | Region tmpRegion; |
| 1561 | IRMapping mapper; |
| 1562 | region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
| 1563 | Block &clonedBlock = tmpRegion.front(); |
| 1564 | YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator()); |
| 1565 | // Merge cloned block and return yield value. |
| 1566 | Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
| 1567 | rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
| 1568 | Value val = clonedYield.getSingleResult(); |
| 1569 | rewriter.eraseOp(clonedYield);
| 1570 | rewriter.eraseOp(placeholder);
| 1571 | return val; |
| 1572 | } |
| 1573 | |
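| | /// Emits the value of the present branch of a sparse_tensor.unary op by
| | /// inlining its present region on `v0`. Returns an empty Value when the
| | /// input is empty or the present region is missing, which denotes missing
| | /// data in the output.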
| 1574 | static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, |
| 1575 | Operation *op, Value v0) { |
| 1576 | if (!v0) |
| 1577 | // Empty input value must be propagated. |
| 1578 | return Value(); |
| 1579 | UnaryOp unop = cast<UnaryOp>(op); |
| 1580 | Region &presentRegion = unop.getPresentRegion(); |
| 1581 | if (presentRegion.empty()) |
| 1582 | // Uninitialized Value() will be interpreted as missing data in the |
| 1583 | // output. |
| 1584 | return Value(); |
| 1585 | return insertYieldOp(rewriter, loc, presentRegion, {v0});
| 1586 | } |
| 1587 | |
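| | /// Emits the value of the overlap of a sparse_tensor.binary op by inlining
| | /// its overlap region on `v0` and `v1`. Returns an empty Value when either
| | /// input is empty or the overlap region is missing.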
| 1588 | static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, |
| 1589 | Operation *op, Value v0, Value v1) { |
| 1590 | if (!v0 || !v1) |
| 1591 | // Empty input values must be propagated. |
| 1592 | return Value(); |
| 1593 | BinaryOp binop = cast<BinaryOp>(op); |
| 1594 | Region &overlapRegion = binop.getOverlapRegion(); |
| 1595 | if (overlapRegion.empty()) |
| 1596 | // Uninitialized Value() will be interpreted as missing data in the |
| 1597 | // output. |
| 1598 | return Value(); |
| 1599 | return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
| 1600 | } |
| 1601 | |
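| | /// Lowers a recognized ReLU expression to a compare-and-select against a
| | /// zero constant of the operand type, using the comparison predicate that
| | /// was recorded in the expression attribute.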
| 1602 | static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0, |
| 1603 | Attribute attr) { |
| 1604 | Type tp = v0.getType(); |
| 1605 | auto zero = |
| 1606 | rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp)); |
| 1607 | Value cmp; |
| 1608 | if (isa<FloatType>(tp)) {
| 1609 | auto pred = llvm::cast<arith::CmpFPredicateAttr>(attr); |
| 1610 | cmp = rewriter.create<arith::CmpFOp>(loc, pred, v0, zero); |
| 1611 | } else { |
| 1612 | auto pred = llvm::cast<arith::CmpIPredicateAttr>(attr); |
| 1613 | cmp = rewriter.create<arith::CmpIOp>(loc, pred, v0, zero); |
| 1614 | } |
| 1615 | return rewriter.create<arith::SelectOp>(loc, cmp, v0, zero); |
| 1616 | } |
| 1617 | |
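| | // Emits the IR for a single tensor expression node; the operand values v0
| | // and (for binary kinds) v1 must already have been materialized by the
| | // caller.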
| 1618 | Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, |
| 1619 | Value v1) const { |
| 1620 | const auto &expr = exp(e); |
| 1621 | switch (expr.kind) { |
| 1622 | // Leaf. |
| 1623 | case TensorExp::Kind::kTensor: |
| 1624 | case TensorExp::Kind::kInvariant: |
| 1625 | case TensorExp::Kind::kLoopVar: |
| 1626 | case TensorExp::Kind::kSynZero: |
| 1627 | llvm_unreachable("unexpected non-op");
| 1628 | // Unary operations. |
| 1629 | case TensorExp::Kind::kAbsF: |
| 1630 | return rewriter.create<math::AbsFOp>(loc, v0); |
| 1631 | case TensorExp::Kind::kAbsC: { |
| 1632 | auto type = cast<ComplexType>(v0.getType()); |
| 1633 | auto eltType = cast<FloatType>(type.getElementType()); |
| 1634 | return rewriter.create<complex::AbsOp>(loc, eltType, v0); |
| 1635 | } |
| 1636 | case TensorExp::Kind::kAbsI: |
| 1637 | return rewriter.create<math::AbsIOp>(loc, v0); |
| 1638 | case TensorExp::Kind::kCeilF: |
| 1639 | return rewriter.create<math::CeilOp>(loc, v0); |
| 1640 | case TensorExp::Kind::kFloorF: |
| 1641 | return rewriter.create<math::FloorOp>(loc, v0); |
| 1642 | case TensorExp::Kind::kSqrtF: |
| 1643 | return rewriter.create<math::SqrtOp>(loc, v0); |
| 1644 | case TensorExp::Kind::kSqrtC: |
| 1645 | return rewriter.create<complex::SqrtOp>(loc, v0); |
| 1646 | case TensorExp::Kind::kExpm1F: |
| 1647 | return rewriter.create<math::ExpM1Op>(loc, v0); |
| 1648 | case TensorExp::Kind::kExpm1C: |
| 1649 | return rewriter.create<complex::Expm1Op>(loc, v0); |
| 1650 | case TensorExp::Kind::kLog1pF: |
| 1651 | return rewriter.create<math::Log1pOp>(loc, v0); |
| 1652 | case TensorExp::Kind::kLog1pC: |
| 1653 | return rewriter.create<complex::Log1pOp>(loc, v0); |
| 1654 | case TensorExp::Kind::kRelu: |
| 1655 | return buildRelu(rewriter, loc, v0, expr.attr);
| 1656 | case TensorExp::Kind::kSinF: |
| 1657 | return rewriter.create<math::SinOp>(loc, v0); |
| 1658 | case TensorExp::Kind::kSinC: |
| 1659 | return rewriter.create<complex::SinOp>(loc, v0); |
| 1660 | case TensorExp::Kind::kTanhF: |
| 1661 | return rewriter.create<math::TanhOp>(loc, v0); |
| 1662 | case TensorExp::Kind::kTanhC: |
| 1663 | return rewriter.create<complex::TanhOp>(loc, v0); |
| 1664 | case TensorExp::Kind::kNegF: |
| 1665 | return rewriter.create<arith::NegFOp>(loc, v0); |
| 1666 | case TensorExp::Kind::kNegC: |
| 1667 | return rewriter.create<complex::NegOp>(loc, v0); |
| 1668 | case TensorExp::Kind::kNegI: // no negi in std |
| 1669 | return rewriter.create<arith::SubIOp>( |
| 1670 | loc, |
| 1671 | rewriter.create<arith::ConstantOp>(loc, v0.getType(), |
| 1672 | rewriter.getZeroAttr(v0.getType())), |
| 1673 | v0); |
| 1674 | case TensorExp::Kind::kTruncF: |
| 1675 | return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0); |
| 1676 | case TensorExp::Kind::kExtF: |
| 1677 | return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0); |
| 1678 | case TensorExp::Kind::kCastFS: |
| 1679 | return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0); |
| 1680 | case TensorExp::Kind::kCastFU: |
| 1681 | return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0); |
| 1682 | case TensorExp::Kind::kCastSF: |
| 1683 | return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0); |
| 1684 | case TensorExp::Kind::kCastUF: |
| 1685 | return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0); |
| 1686 | case TensorExp::Kind::kCastS: |
| 1687 | return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0); |
| 1688 | case TensorExp::Kind::kCastU: |
| 1689 | return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0); |
| 1690 | case TensorExp::Kind::kCastIdx: |
| 1691 | return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0); |
| 1692 | case TensorExp::Kind::kTruncI: |
| 1693 | return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0); |
| 1694 | case TensorExp::Kind::kCIm: { |
| 1695 | auto type = cast<ComplexType>(v0.getType()); |
| 1696 | auto eltType = cast<FloatType>(type.getElementType()); |
| 1697 | return rewriter.create<complex::ImOp>(loc, eltType, v0); |
| 1698 | } |
| 1699 | case TensorExp::Kind::kCRe: { |
| 1700 | auto type = cast<ComplexType>(v0.getType()); |
| 1701 | auto eltType = cast<FloatType>(type.getElementType()); |
| 1702 | return rewriter.create<complex::ReOp>(loc, eltType, v0); |
| 1703 | } |
| 1704 | case TensorExp::Kind::kBitCast: |
| 1705 | return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0); |
| 1706 | // Binary operations. |
| 1707 | case TensorExp::Kind::kMulF: |
| 1708 | return rewriter.create<arith::MulFOp>(loc, v0, v1); |
| 1709 | case TensorExp::Kind::kMulC: |
| 1710 | return rewriter.create<complex::MulOp>(loc, v0, v1); |
| 1711 | case TensorExp::Kind::kMulI: |
| 1712 | return rewriter.create<arith::MulIOp>(loc, v0, v1); |
| 1713 | case TensorExp::Kind::kDivF: |
| 1714 | return rewriter.create<arith::DivFOp>(loc, v0, v1); |
| 1715 | case TensorExp::Kind::kDivC: |
| 1716 | return rewriter.create<complex::DivOp>(loc, v0, v1); |
| 1717 | case TensorExp::Kind::kDivS: |
| 1718 | return rewriter.create<arith::DivSIOp>(loc, v0, v1); |
| 1719 | case TensorExp::Kind::kDivU: |
| 1720 | return rewriter.create<arith::DivUIOp>(loc, v0, v1); |
| 1721 | case TensorExp::Kind::kAddF: |
| 1722 | return rewriter.create<arith::AddFOp>(loc, v0, v1); |
| 1723 | case TensorExp::Kind::kAddC: |
| 1724 | return rewriter.create<complex::AddOp>(loc, v0, v1); |
| 1725 | case TensorExp::Kind::kAddI: |
| 1726 | return rewriter.create<arith::AddIOp>(loc, v0, v1); |
| 1727 | case TensorExp::Kind::kSubF: |
| 1728 | return rewriter.create<arith::SubFOp>(loc, v0, v1); |
| 1729 | case TensorExp::Kind::kSubC: |
| 1730 | return rewriter.create<complex::SubOp>(loc, v0, v1); |
| 1731 | case TensorExp::Kind::kSubI: |
| 1732 | return rewriter.create<arith::SubIOp>(loc, v0, v1); |
| 1733 | case TensorExp::Kind::kAndI: |
| 1734 | return rewriter.create<arith::AndIOp>(loc, v0, v1); |
| 1735 | case TensorExp::Kind::kOrI: |
| 1736 | return rewriter.create<arith::OrIOp>(loc, v0, v1); |
| 1737 | case TensorExp::Kind::kXorI: |
| 1738 | return rewriter.create<arith::XOrIOp>(loc, v0, v1); |
| 1739 | case TensorExp::Kind::kShrS: |
| 1740 | return rewriter.create<arith::ShRSIOp>(loc, v0, v1); |
| 1741 | case TensorExp::Kind::kShrU: |
| 1742 | return rewriter.create<arith::ShRUIOp>(loc, v0, v1); |
| 1743 | case TensorExp::Kind::kShlI: |
| 1744 | return rewriter.create<arith::ShLIOp>(loc, v0, v1); |
| 1745 | case TensorExp::Kind::kCmpI: { |
| 1746 | auto predicate = llvm::cast<arith::CmpIPredicateAttr>(expr.attr); |
| 1747 | return rewriter.create<arith::CmpIOp>(loc, predicate, v0, v1); |
| 1748 | } |
| 1749 | case TensorExp::Kind::kCmpF: { |
| 1750 | auto predicate = llvm::cast<arith::CmpFPredicateAttr>(expr.attr); |
| 1751 | return rewriter.create<arith::CmpFOp>(loc, predicate, v0, v1); |
| 1752 | } |
| 1753 | case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic. |
| 1754 | return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(),
| 1755 | {v0});
| 1756 | case TensorExp::Kind::kUnary: |
| 1757 | return buildUnaryPresent(rewriter, loc, op: expr.op, v0); |
| 1758 | case TensorExp::Kind::kSelect: |
| 1759 | return insertYieldOp(rewriter, loc, |
| 1760 | cast<sparse_tensor::SelectOp>(expr.op).getRegion(), |
| 1761 | {v0}); |
| 1762 | case TensorExp::Kind::kBinary: |
| 1763 | return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1);
| 1764 | case TensorExp::Kind::kReduce: { |
| 1765 | ReduceOp redOp = cast<ReduceOp>(expr.op); |
| 1766 | return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1}); |
| 1767 | } |
| 1768 | case TensorExp::Kind::kDenseOp: { |
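| | // Clone the original dense operation, remapping its operands to the
| | // values materialized for the current loop iteration.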
| 1769 | Operation *actualOp = expr.op; |
| 1770 | IRMapping mapping; |
| 1771 | mapping.map(actualOp->getOperand(0), v0);
| 1772 | if (actualOp->getNumOperands() == 2)
| 1773 | mapping.map(actualOp->getOperand(1), v1);
| 1774 | return rewriter.clone(*actualOp, mapping)->getResult(0);
| 1775 | } |
| 1776 | } |
| 1777 | llvm_unreachable("unexpected expression kind in build");
| 1778 | } |
| 1779 | |
| 1780 | } // namespace sparse_tensor |
| 1781 | } // namespace mlir |
| 1782 | |