//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <numeric>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//

bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if the op has no loops.
    return linalgOp.getNumLoops() == 0;
  }
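  // The operands can be dropped only if the remaining indexing maps still
  // fully determine every loop dimension, i.e. their concatenation is
  // invertible.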
  return inversePermutation(concatAffineMaps(
             indexingMaps, linalgOp.getContext())) != AffineMap();
}

//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

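/// A LinalgOp matches the copy interface when all loops are parallel, it has
/// exactly one input and one output with identity indexing maps, and the body
/// only yields the input. A minimal illustrative generic (names made up for
/// the example):
///
///   #id = affine_map<(d0, d1) -> (d0, d1)>
///   linalg.generic {indexing_maps = [#id, #id],
///                   iterator_types = ["parallel", "parallel"]}
///       ins(%in : tensor<4x8xf32>) outs(%out : tensor<4x8xf32>) {
///   ^bb0(%a: f32, %b: f32):
///     linalg.yield %a : f32
///   } -> tensor<4x8xf32>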
bool linalg::isaCopyOpInterface(LinalgOp op) {
  // Check that all loops are parallel and that the op has a single input and
  // a single output.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
    return false;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
    return false;
  }
  // Region: the body must consist of a single (yield) operation.
  return llvm::hasSingleElement(op.getBlock()->getOperations());
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
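/// A fill is a generic that broadcasts a scalar input into the init operand.
/// A minimal illustrative form (assumed, not taken from a test):
///
///   linalg.generic {indexing_maps = [affine_map<(d0) -> ()>,
///                                    affine_map<(d0) -> (d0)>],
///                   iterator_types = ["parallel"]}
///       ins(%cst : f32) outs(%out : tensor<16xf32>) {
///   ^bb0(%in: f32, %o: f32):
///     linalg.yield %in : f32
///   } -> tensor<16xf32>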
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  // Input should be referenced and init should not.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
      op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
    return std::nullopt;

  OpOperand *value = op.getDpsInputOperand(0);
  if (!op.isScalar(value))
    return std::nullopt;
  return value->get();
}

//===----------------------------------------------------------------------===//
// BroadcastOpInterface implementation
//===----------------------------------------------------------------------===//
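/// For example (illustrative), broadcasting a 1-D tensor along dimension 0 of
/// a 2-D result uses maps [(d0, d1) -> (d1), (d0, d1) -> (d0, d1)]; the
/// function below then returns the broadcasted dimensions, here [0].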
std::optional<SmallVector<int64_t>>
linalg::isaBroadcastOpInterface(GenericOp op) {
  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto srcTy = op.getDpsInputOperand(0)->get().getType();
  auto dstTy = op.getDpsInitOperand(0)->get().getType();
  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
      !isa<MemRefType, RankedTensorType>(dstTy))
    return std::nullopt;

  // Check that the output map is the identity. A broadcast could additionally
  // employ a permutation of indices; that is expressible in linalg.generic
  // but not in the named broadcast op.
  auto dstMap = op.getIndexingMapsArray()[1];
  if (!dstMap.isIdentity())
    return std::nullopt;

  SmallVector<int64_t> position;
  auto srcMap = op.getIndexingMapsArray()[0];

  if (srcMap.getResults().size() >= dstMap.getResults().size())
    return std::nullopt;

  // Check that the input map consists of strictly increasing AffineDimExprs.
  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
    if (!expr)
      return std::nullopt;
    int64_t pos = expr.getPosition();
    if (i > 0 && pos <= position[i - 1])
      return std::nullopt;
    position.push_back(pos);
  }

  SmallVector<int64_t> broadcastedDims;
  auto numDims = srcMap.getNumDims();
  // This is quadratic, but the number of items is generally small.
  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
    if (!llvm::is_contained(position, dim))
      broadcastedDims.push_back(dim);
  }
  return broadcastedDims;
}

//===----------------------------------------------------------------------===//
// TransposeOpInterface implementation
//===----------------------------------------------------------------------===//
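/// For example (illustrative), with an input map (d0, d1) -> (d1, d0) and an
/// identity output map, the recovered permutation is [1, 0]: the named op
/// linalg.transpose with permutation = [1, 0] computes the same result.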
std::optional<SmallVector<int64_t>>
linalg::isaTransposeOpInterface(GenericOp op) {
  // To specialize as a transpose op, the generic op must have all-parallel
  // loops, a single input and a single output, and a body that is just a
  // yield op, yielding the input as-is (no compute).
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2)
    return std::nullopt;

  auto mapOfInput = mapRange.front();
  auto mapOfResult = mapRange.back();

  // linalg.transpose permutes the dimensions of the input using the rule:
  // dim(result, i) = dim(input, permutation[i]).
  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
    return std::nullopt;

  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
    permutation[expr.getPosition()] = i;
  }
  return permutation;
}

//===----------------------------------------------------------------------===//
// Elementwise Single Unary/Binary-OpInterface implementation
//===----------------------------------------------------------------------===//
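/// For example (illustrative), a single-unary elementwise generic has one
/// input, one output, identity maps, and a two-op body such as:
///
///   ^bb0(%in: f32, %out: f32):
///     %0 = math.exp %in : f32
///     linalg.yield %0 : f32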
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
                                                      unsigned arity) {
  // Check all loops are parallel.
  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
    return false;

  // Check that there are `arity` inputs, one output, and that all indexing
  // maps are identities.
  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
      !llvm::all_of(op.getIndexingMapsArray(),
                    [](AffineMap map) { return map.isIdentity(); }))
    return false;

  // Init should not be referenced for elementwise operations.
  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
    return false;

  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
  // such as resulting from producer-consumer fusion. Here, we restrict to two
  // ops in the body, where the first is the single elementwise op and the
  // second a yield.
  Block *body = op.getBody();
  if (body->getOperations().size() != 2)
    return false;

  Operation *oper = &body->front();
  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
    return false;

  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0).getDefiningOp() != oper)
    return false;
  return true;
}

bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
  // All basic elemwise checks.
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
    return false;

  // Check that the input is actually used.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
    return false;
  return true;
}

bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
    return false;

  // Check that both inputs are used (elementwise).
  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
      !op.payloadUsesValueFromOperand(inputOpOperand1))
    return false;
  return true;
}

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// If the value is defined by a chain of unary side effect-free operations,
/// go up the use-def chain until the first value that isn't defined by such
/// an op.
// TODO: relax to multi-operands with constants, which are technically unary
// ops as needed (e.g. add5).
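// For example (illustrative), given %1 = arith.negf %0 and
// %0 = arith.extf %arg, getSourceSkipUnary(%1) returns %arg: both ops are
// unary and side effect-free.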
static Value getSourceSkipUnary(Value value) {
  Operation *op = value.getDefiningOp();
  while (op && op->getNumOperands() == 1) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface || !iface.hasNoEffect())
      break;
    value = op->getOperand(0);
    op = value.getDefiningOp();
  }
  return value;
}

bool mlir::linalg::detail::isContractionBody(
    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
    llvm::raw_ostream &errs) {
  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
    errs << "no terminator in the block";
    return false;
  }

  if (block.getNumArguments() != 3) {
    errs << "expected block with 3 arguments";
    return false;
  }

  Operation *terminator = block.getTerminator();
  if (terminator->getNumOperands() != 1) {
    errs << "expected terminator with 1 operand";
    return false;
  }

  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  if (!reductionOp || reductionOp->getNumResults() != 1 ||
      reductionOp->getNumOperands() != 2) {
    errs << "expected reduction op to be binary";
    return false;
  }

  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));

  if (reductionLHS != block.getArgument(2) &&
      reductionRHS != block.getArgument(2)) {
    errs << "expected reduction to take block argument #2 as one of the "
            "operands (modulo unary casts)";
    return false;
  }

  Value contributed = getSourceSkipUnary(
      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
  Operation *elementwiseOp = contributed.getDefiningOp();
  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
      elementwiseOp->getNumOperands() != 2) {
    errs << "expected elementwise op to be binary";
    return false;
  }

  if (!isaPair(elementwiseOp, reductionOp)) {
    errs << "expected reduction/elementwise op kind not satisfied";
    return false;
  }

  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
  if ((elementwiseLHS == block.getArgument(0) &&
       elementwiseRHS == block.getArgument(1)) ||
      (elementwiseLHS == block.getArgument(1) &&
       elementwiseRHS == block.getArgument(0))) {
    return true;
  }

  errs << "expected elementwise op to apply to block arguments (modulo unary "
          "casts)";
  return false;
}

/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}

/// Returns true if the block is a body of a contraction with the kinds of
/// operations given pairwise by template arguments.
template <typename... Args>
static bool isContractionBody(Block &block) {
  return linalg::detail::isContractionBody(block,
                                           &isPairTemplateImpl<Args...>);
}

/// Given an `indexingMap` and its corresponding `iterators`, returns
/// the positions of the iterators of type `iter` that are indexed by
/// the `indexingMap` as a permutation. This is useful to infer various
/// subcomputations on a `LinalgOp`. This is performed by looking up
/// each result in the `indexingMap` and determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
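/// For example (illustrative), for indexingMap (d0, d1, d2) -> (d0, d2) with
/// iterators [parallel, parallel, reduction], querying `parallel` returns
/// {0} and querying `reduction` returns {2}.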
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
                                ArrayRef<utils::IteratorType> iterators,
                                utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      if (iterators[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

/// Infer the iterator types from the init affine map. This looks at which dims
/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
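/// E.g. (illustrative) for an init map (d0, d1, d2) -> (d0, d1), the inferred
/// iterators are [parallel, parallel, reduction].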
static FailureOr<SmallVector<utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();
  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = par;
  return iterators;
}

/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that
/// form a matmul subcomputation within `linalgOp`. These dimensions are such
/// that:
/// 1. The m dimension is involved in an outer-product along LHS
///    (i.e. it is a permutation on RES and LHS and does not appear in RHS).
/// 2. The n dimension is involved in an outer-product along RHS
///    (i.e. it is a permutation on RES and RHS and does not appear in LHS).
/// 3. The k dimension appears as a permutation on LHS and RHS.
/// 4. m, n and k appear only once in any given indexing.
/// 5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
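/// For a plain matmul (illustrative), with maps
///   A: (m, n, k) -> (m, k), B: (m, n, k) -> (k, n), C: (m, n, k) -> (m, n)
/// this infers batch = {}, m = {0}, n = {1}, k = {2}.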
static FailureOr<ContractionDimensions>
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
                         ArrayRef<utils::IteratorType> iterators) {
  llvm::SmallDenseSet<int64_t> a =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
  llvm::SmallDenseSet<int64_t> b =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
  llvm::SmallDenseSet<int64_t> c =
      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);

  // A & C - B are the iterators involved in an outer-product along A (the
  // LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the
  // RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // The reduction dimensions shared by A and B are the contraction's k dims.
  llvm::SmallDenseSet<int64_t> ra =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
  llvm::SmallDenseSet<int64_t> rb =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
  llvm::set_intersect(ra, rb);

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.m.begin(), dimensions.m.end());
  llvm::sort(dimensions.n.begin(), dimensions.n.end());
  llvm::sort(dimensions.k.begin(), dimensions.k.end());
  return dimensions;
}

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();
  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
                                  linalgOp.getIteratorTypesArray());
}

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
  if (indexingMaps.size() != 3)
    return failure();
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, iterators.value());
}

namespace mlir::linalg::detail {
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
    Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  // clang-format off
  if (!::isContractionBody<
        arith::MulFOp, arith::AddFOp,
        arith::MulIOp, arith::AddIOp,
        complex::MulOp, complex::AddOp,
        arith::AndIOp, arith::OrIOp>(
      *linalgOp.getBlock())) {
    return MatchContractionResult::NotAddMul;
  }
  // clang-format on

  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
  return MatchContractionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
  switch (res) {
  case MatchContractionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchContractionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchContractionResult::NoReduction:
    return "expected at least 1 reduction";
  case MatchContractionResult::NotProjectedPermutations:
    return "expected indexing maps to be projected permutations";
  case MatchContractionResult::NotAddMul:
    return "expected add/mul op in the body";
  case MatchContractionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
          mlir::linalg::detail::MatchContractionResult::Success);
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
/// 1. Has 2 input and 1 output shapes.
/// 2. Has at least one reduction dimension.
/// 3. Has only projected permutation indexing maps.
/// 4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
/// operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise
/// and constant operations that do not involve the reduction dimension(s).
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res != MatchContractionResult::Success)
    return op->emitError(getMatchContractionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the two given expressions, returns the one that is of type T (`lhs`
/// gets preference over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///      (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// classifies the AffineDimExprs as convolved dimensions or unconvolved
/// dimensions and verifies each dimension occurs only once.
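/// E.g. (illustrative) for a strided 1-D convolution input map
///   (d0, d1) -> (d0 * 2 + d1)
/// d0 and d1 become convolved dims paired with each other, with stride
/// coefficients 2 and 1 recorded in `strideAndDilationMapping`.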
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single-use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the duplicate
        // dim must be removed from the map as well.
        auto it = convolvedDimMapping.find(dimPos);
        if (it != convolvedDimMapping.end()) {
          int64_t pairedDim = it->second;
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In a pre-order visit, the top-level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      strideAndDilationMapping[dim] =
          getAffineConstantExpr(1, expr.getContext());
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
} // namespace

static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  llvm::SmallDenseSet<int64_t> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
  return preservedDims;
}

static SmallVector<int64_t, 2>
getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
  SmallVector<int64_t, 2> vals;
  for (auto e : exprs) {
    auto constantExpr = dyn_cast<AffineConstantExpr>(e);
    assert(constantExpr && "Found non-constant stride/dilation");
    vals.push_back(constantExpr.getValue());
  }
  return vals;
}

/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
/// least one convolved dimension pair (output image + filter loop).
/// Convolution dimensions are specified in sorted order, and strides match
/// the order of the filter loop dimensions, while the dilations match the
/// order of the output image dimensions.
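/// For an NHWC convolution such as linalg.conv_2d_nhwc_hwcf (illustrative),
/// this classifies batch = {0}, outputImage = {1, 2}, outputChannel = {3},
/// filterLoop = {4, 5}, inputChannel = {6}, depth = {}.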
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel
  // iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(filterMap,
                                      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());

  // Use the op carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides =
        llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}

/// Find at least 1 parallel (output_image) and reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed as the filter.
/// These dimensions are such that:
/// 1. Optional batch dimensions that appear in the input and filter.
/// 2. The output_image dimension is involved in a cross-correlation along LHS
///    (i.e. it is a permutation on RES and LHS and has an associated
///    filter_loop in RHS).
/// 3. Optional output_channel dimension is involved in an outer-product along
///    RHS (i.e. it is a permutation on RES and RHS and does not appear in
///    LHS).
/// 4. Optional input_channel dimension appears as a permutation on LHS and
///    RHS.
/// 5. The filter_loop dimension appears as a permutation on the RHS and
///    represents the shape of the kernel cross-correlated along a
///    corresponding output_image dim.
/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
/// 7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is
/// empty.
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  for (AffineExpr expr : indexingMaps[0].getResults())
    (void)inputExprWalker.visit(expr);
  inputExprWalker.clearMultiUseDims(indexingMaps[0]);

  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                  /*allowEmptyConvolvedDims=*/false);
}

namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions,
    bool allowEmptyConvolvedDims) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutations.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present
  //   in filter.
  // - Output image dimension : present in output, convolved dims in input,
  //   not present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. Already seen, continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
    return MatchConvolutionResult::EmptyConvolvedDims;

  if (dimensions) {
    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
  switch (res) {
  case MatchConvolutionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchConvolutionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchConvolutionResult::WrongInputIndexingMap:
    return "unexpected input index map for convolutions";
  case MatchConvolutionResult::NotProjectedPermutations:
    return "expected output/filter indexing maps to be projected permutations";
  case MatchConvolutionResult::NonConvolutionLoop:
    return "unexpected loop dimension for convolution op";
  case MatchConvolutionResult::OutputDimsNotParallel:
    return "expected all iterators used to access outputs to be parallel";
  case MatchConvolutionResult::NonOutputDimNotReduction:
    return "expected all iterators not used to access outputs to be reduction";
  case MatchConvolutionResult::EmptyConvolvedDims:
    return "expected convolved dim to be non-empty";
  case MatchConvolutionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchConvolutionResult case");
}

bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
                                             bool allowEmptyConvolvedDims) {
  return linalg::detail::isConvolutionInterfaceImpl(
             linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
         linalg::detail::MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
  if (res != MatchConvolutionResult::Success)
    return op->emitError(getMatchConvolutionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};

static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return MatchFillResult::WrongNumOperands;

  OpOperand *value = linalgOp.getDpsInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> res;
  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    llvm::append_range(res, getShape(&opOperand));
  return res;
}

SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result)) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] =
          Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
    }
  }
  return res;
}

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};

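/// Returns the half-open range [inputRankSum, inputRankSum + outputRankSum)
/// of result positions in the loops-to-shapes map that correspond to the
/// shapes of the outputs of `op`.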
static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
  int64_t inputRankSum = 0;
  int64_t outputRankSum = 0;
  for (OpOperand *input : op.getDpsInputOperands())
    inputRankSum += op.getRank(input);
  for (OpOperand &output : op.getDpsInitsMutable())
    outputRankSum += op.getRank(&output);
  return {inputRankSum, inputRankSum + outputRankSum};
}

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j).
  // We want to express the shape of dim 0 of O in terms of shape of the
  // inputs. This is achieved as follows (with loops (d0, d1, d2) = (i, j, k)):
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
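/// For example, for a hypothetical op whose operands are two DPS inputs at
/// positions 0 and 1, a non-DPS operand at position 2, and a DPS init at
/// position 3, the init's indexing map index is
/// getNumDpsInputs() + 3 - 3 == 2.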
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  if (!dpsIface.isDpsInit(opOperand))
    return operandNumber;
  assert(!dpsIface.isDpsInput(opOperand));
  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`.
  return dpsIface.getNumDpsInputs() + operandNumber - start;
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError("expected to have pure tensor or buffer semantics");

  // Before checking the indexing maps, make sure the attributes they reference
  // are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  // Check that each indexing map is consistent with the iteration domain and
  // with the rank of its operand: no symbols, one dim per loop, and one result
  // per operand dimension.
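  // For example (using a plain matmul as an illustration), with loops
  // (d0, d1, d2) = (i, j, k), the map (d0, d1, d2) -> (d0, d2) for A(i, k)
  // must have 3 dims to match the 3 loops and 2 results to match A's rank.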
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Check that the given shapes match the inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
  // Verify only the static cases, since exact dimension sizes and loop ranges
  // cannot be computed for dynamic cases at this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
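    // A sketch of the check below, assuming three loops with static ranges
    // (4, 5, 6): endLoopRangeValues becomes (3, 4, 5), and for an operand of
    // shape 4x6 with indexing map (d0, d1, d2) -> (d0, d2) the composed end
    // indices are (3, 5), so each inferred size (end index + 1) must match
    // the corresponding dimension of the operand's shape.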
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // The first or last index should be the minimum or the maximum of the
        // inferred index range, since the range is monotonically increasing or
        // decreasing along each loop. Ideally, the size of each input/output
        // dimension would equal the maximum inferred value + 1, but for
        // complicated affine expressions such as d0 * 3 + d1 we only check
        // that the inferred range stays within the bounds of the operand's
        // size. There are cases this check cannot catch, e.g.,
        // (d0, d1) -> (d1 - d0).
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with the element types of the
  // shaped operands.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, which can never be read from. This will be subject
  // to consistency discussions (i.e. what to do with output tensors whose
  // bbarg is not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

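  // For example, a memref<4x8xf32> operand must correspond to an f32 bbarg,
  // while a scalar f32 operand corresponds to an f32 bbarg directly.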
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}
