| 1 | //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 13 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 14 | #include "mlir/Dialect/Complex/IR/Complex.h" |
| 15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 16 | #include "mlir/IR/AffineExpr.h" |
| 17 | #include "mlir/IR/AffineExprVisitor.h" |
| 18 | #include "mlir/IR/AffineMap.h" |
| 19 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
| 20 | #include "mlir/IR/MLIRContext.h" |
| 21 | #include "mlir/IR/TypeUtilities.h" |
| 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | #include "llvm/ADT/SetOperations.h" |
| 24 | #include "llvm/ADT/SmallBitVector.h" |
| 25 | #include "llvm/ADT/SmallVector.h" |
| 26 | #include "llvm/Support/Casting.h" |
| 27 | #include "llvm/Support/raw_ostream.h" |
| 28 | #include <optional> |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::linalg; |
| 32 | |
| 33 | /// Include the definitions of the copy operation interface. |
| 34 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" |
| 35 | |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | // Interface utility functions |
| 38 | //===----------------------------------------------------------------------===// |
| 39 | |
| 40 | bool linalg::detail::canOpOperandsBeDroppedImpl( |
| 41 | linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) { |
| 42 | SmallVector<AffineMap> indexingMaps; |
| 43 | for (auto &opOperand : linalgOp->getOpOperands()) { |
| 44 | if (llvm::is_contained(Range&: droppedOperands, Element: &opOperand)) |
| 45 | continue; |
| 46 | indexingMaps.push_back(Elt: linalgOp.getMatchingIndexingMap(opOperand: &opOperand)); |
| 47 | } |
| 48 | if (indexingMaps.empty()) { |
| 49 | // If there are no indexing maps, the operand can only be dropped |
| 50 | // if the op has no loops. |
| 51 | return linalgOp.getNumLoops() == 0; |
| 52 | } |
| 53 | return inversePermutation(map: concatAffineMaps( |
| 54 | maps: indexingMaps, context: linalgOp.getContext())) != AffineMap(); |
| 55 | } |
| 56 | |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | // CopyOpInterface implementation |
| 59 | //===----------------------------------------------------------------------===// |
| 60 | |
| 61 | bool linalg::isaCopyOpInterface(LinalgOp op) { |
| 62 | // Check all loops are parallel and linalgOp is single input and output. |
| 63 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput()) |
| 64 | return false; |
| 65 | |
| 66 | auto mapRange = op.getIndexingMapsArray(); |
| 67 | if (mapRange.size() != 2 || !mapRange.front().isIdentity() || |
| 68 | !mapRange.back().isIdentity()) { |
| 69 | return false; |
| 70 | } |
| 71 | // Region. |
| 72 | return llvm::hasSingleElement(C&: op.getBlock()->getOperations()); |
| 73 | } |
| 74 | |
| 75 | //===----------------------------------------------------------------------===// |
| 76 | // FillOpInterface implementation |
| 77 | //===----------------------------------------------------------------------===// |
| 78 | /// Detects if a linalg.generic operation represents a fill with an inlined |
| 79 | /// constant. If so, returns the constant value. Otherwise, returns |
| 80 | /// std::nullopt. |
| 81 | static std::optional<Value> isaInlinedFillOp(GenericOp op) { |
| 82 | if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 || |
| 83 | op.getNumDpsInputs() != 0) |
| 84 | return std::nullopt; |
| 85 | |
| 86 | // Init should not be referenced. |
| 87 | if (op.payloadUsesValueFromOperand(opOperand: op.getDpsInitOperand(i: 0))) |
| 88 | return std::nullopt; |
| 89 | |
| 90 | Block *body = op.getBody(); |
| 91 | if (body->getOperations().size() != 1) |
| 92 | return std::nullopt; |
| 93 | |
| 94 | auto yieldOp = dyn_cast<linalg::YieldOp>(Val&: body->back()); |
| 95 | if (!yieldOp || yieldOp.getNumOperands() != 1) |
| 96 | return std::nullopt; |
| 97 | |
| 98 | Value yieldOperand = yieldOp->getOperand(idx: 0); |
| 99 | if (!yieldOperand.getDefiningOp<arith::ConstantOp>() && |
| 100 | !yieldOperand.getDefiningOp<complex::ConstantOp>()) |
| 101 | return std::nullopt; |
| 102 | |
| 103 | return yieldOperand; |
| 104 | } |
| 105 | |
| 106 | /// Detects if a linalg.generic operation represents an external scalar input. |
| 107 | /// If so, returns the constant value. Otherwise, returns std::nullopt. |
| 108 | static std::optional<Value> isaExternalFillOp(GenericOp op) { |
| 109 | // Structural. |
| 110 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput() || |
| 111 | !op.isSingleYieldOp()) |
| 112 | return std::nullopt; |
| 113 | |
| 114 | // Input should be referenced and init should not. |
| 115 | if (!op.payloadUsesValueFromOperand(opOperand: op.getDpsInputOperand(i: 0)) || |
| 116 | op.payloadUsesValueFromOperand(opOperand: op.getDpsInitOperand(i: 0))) |
| 117 | return std::nullopt; |
| 118 | |
| 119 | OpOperand *value = op.getDpsInputOperand(i: 0); |
| 120 | if (!op.isScalar(opOperand: value)) |
| 121 | return std::nullopt; |
| 122 | return value->get(); |
| 123 | } |
| 124 | |
| 125 | std::optional<Value> linalg::isaFillOpInterface(GenericOp op) { |
| 126 | if (auto fillVal = isaInlinedFillOp(op)) |
| 127 | return fillVal; |
| 128 | return isaExternalFillOp(op); |
| 129 | } |
| 130 | |
| 131 | //===----------------------------------------------------------------------===// |
| 132 | // BroadcastOpInterface implementation |
| 133 | //===----------------------------------------------------------------------===// |
| 134 | std::optional<SmallVector<int64_t>> |
| 135 | linalg::isaBroadcastOpInterface(GenericOp op) { |
| 136 | // Structural. |
| 137 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput() || |
| 138 | !op.isSingleYieldOp()) |
| 139 | return std::nullopt; |
| 140 | |
| 141 | auto srcTy = op.getDpsInputOperand(i: 0)->get().getType(); |
| 142 | auto dstTy = op.getDpsInitOperand(i: 0)->get().getType(); |
| 143 | if (!isa<MemRefType, RankedTensorType>(Val: srcTy) || |
| 144 | !isa<MemRefType, RankedTensorType>(Val: dstTy)) |
| 145 | return std::nullopt; |
| 146 | |
| 147 | // Check output is identity map. Broadcast could additionally be |
| 148 | // employing permutation of indices and that would be expressible |
| 149 | // in linalg.generic but is not expressible for named broadcast op. |
| 150 | auto dstMap = op.getIndexingMapsArray()[1]; |
| 151 | if (!dstMap.isIdentity()) |
| 152 | return std::nullopt; |
| 153 | |
| 154 | SmallVector<int64_t> position; |
| 155 | auto srcMap = op.getIndexingMapsArray()[0]; |
| 156 | |
| 157 | if (srcMap.getResults().size() >= dstMap.getResults().size()) |
| 158 | return std::nullopt; |
| 159 | |
| 160 | // Check input map is monotonically increasing DimIds. |
| 161 | for (unsigned i = 0; i < srcMap.getNumResults(); ++i) { |
| 162 | auto expr = llvm::dyn_cast<AffineDimExpr>(Val: srcMap.getResults()[i]); |
| 163 | if (!expr) |
| 164 | return std::nullopt; |
| 165 | int64_t pos = expr.getPosition(); |
| 166 | if (i > 0 && pos <= position[i - 1]) |
| 167 | return std::nullopt; |
| 168 | position.push_back(Elt: expr.getPosition()); |
| 169 | } |
| 170 | |
| 171 | SmallVector<int64_t> broadcastedDims; |
| 172 | auto numDims = srcMap.getNumDims(); |
| 173 | // This is quadratic but number of items is generally small. |
| 174 | for (auto dim : llvm::seq<int64_t>(Begin: 0, End: numDims)) { |
| 175 | if (!llvm::is_contained(Range&: position, Element: dim)) |
| 176 | broadcastedDims.push_back(Elt: dim); |
| 177 | } |
| 178 | return broadcastedDims; |
| 179 | } |
| 180 | |
| 181 | //===----------------------------------------------------------------------===// |
| 182 | // TransposeOpInterface implementation |
| 183 | //===----------------------------------------------------------------------===// |
| 184 | std::optional<SmallVector<int64_t>> |
| 185 | linalg::isaTransposeOpInterface(GenericOp op) { |
| 186 | // To specialize as a transpose op, the genericOp must be |
| 187 | // all parallel loops, single input, single output, and its body |
| 188 | // should be just a yield op, yielding input as output as is (no compute). |
| 189 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput() || |
| 190 | !op.isSingleYieldOp()) |
| 191 | return std::nullopt; |
| 192 | |
| 193 | auto mapRange = op.getIndexingMapsArray(); |
| 194 | if (mapRange.size() != 2) |
| 195 | return std::nullopt; |
| 196 | |
| 197 | auto mapOfInput = mapRange.front(); |
| 198 | auto mapOfResult = mapRange.back(); |
| 199 | |
| 200 | // linalg.transpose permutes the dimensions of input using this |
| 201 | // rule: dim(result, i) = dim(input, permutation[i]) |
| 202 | if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation()) |
| 203 | return std::nullopt; |
| 204 | |
| 205 | SmallVector<int64_t> permutation(mapOfInput.getNumDims()); |
| 206 | for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) { |
| 207 | auto expr = llvm::cast<AffineDimExpr>(Val: mapOfInput.getResults()[i]); |
| 208 | permutation[expr.getPosition()] = i; |
| 209 | } |
| 210 | return permutation; |
| 211 | } |
| 212 | |
| 213 | //===----------------------------------------------------------------------===// |
| 214 | // Elementwise Single Unary/Binary-OpInterface implementation |
| 215 | //===----------------------------------------------------------------------===// |
| 216 | static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, |
| 217 | unsigned arity) { |
| 218 | // Check all loops are parallel. |
| 219 | if (!op.isAllParallelLoops() || op.getNumLoops() < 1) |
| 220 | return false; |
| 221 | |
| 222 | // Check there are arity-inputs, 1-output and all are identity-maps. |
| 223 | if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 || |
| 224 | !llvm::all_of(Range: op.getIndexingMapsArray(), |
| 225 | P: [](AffineMap map) { return map.isIdentity(); })) |
| 226 | return false; |
| 227 | |
| 228 | // Init should not be referenced for elementwise operations. |
| 229 | if (op.payloadUsesValueFromOperand(opOperand: op.getDpsInitOperand(i: 0))) |
| 230 | return false; |
| 231 | |
| 232 | // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such |
| 233 | // as resulting from producer-consumer fusion. Here, we restrict to two ops in |
| 234 | // the body, where the first is the elementwise single op and the second a |
| 235 | // yield. |
| 236 | Block *body = op.getBody(); |
| 237 | if (body->getOperations().size() != 2) |
| 238 | return false; |
| 239 | |
| 240 | Operation *oper = &body->front(); |
| 241 | if (oper->getNumOperands() != arity || oper->getNumResults() != 1) |
| 242 | return false; |
| 243 | |
| 244 | auto yieldOp = dyn_cast<linalg::YieldOp>(Val&: body->back()); |
| 245 | if (!yieldOp || yieldOp.getNumOperands() != 1 || |
| 246 | yieldOp->getOperand(idx: 0).getDefiningOp() != oper) |
| 247 | return false; |
| 248 | return true; |
| 249 | } |
| 250 | |
| 251 | bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) { |
| 252 | // All basic elemwise checks. |
| 253 | if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, arity: 1)) |
| 254 | return false; |
| 255 | |
| 256 | // Check input is actully used. |
| 257 | if (!op.payloadUsesValueFromOperand(opOperand: op.getDpsInputOperand(i: 0))) |
| 258 | return false; |
| 259 | return true; |
| 260 | } |
| 261 | |
| 262 | bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) { |
| 263 | if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, arity: 2)) |
| 264 | return false; |
| 265 | |
| 266 | // Check both inputs are used (elementwise). |
| 267 | OpOperand *inputOpOperand0 = op.getDpsInputOperand(i: 0); |
| 268 | OpOperand *inputOpOperand1 = op.getDpsInputOperand(i: 1); |
| 269 | if (!op.payloadUsesValueFromOperand(opOperand: inputOpOperand0) || |
| 270 | !op.payloadUsesValueFromOperand(opOperand: inputOpOperand1)) |
| 271 | return false; |
| 272 | return true; |
| 273 | } |
| 274 | |
| 275 | //===----------------------------------------------------------------------===// |
| 276 | // ContractionOpInterface implementation |
| 277 | //===----------------------------------------------------------------------===// |
| 278 | |
| 279 | /// If the value is defined by a chain of unary side effect-free, go up the |
| 280 | /// use-def chain until the first value that isn't defined by such an op. |
| 281 | // TODO: relax to multi-operands with constants, which are technically unary ops |
| 282 | // as needed (e.g. add5). |
| 283 | static Value getSourceSkipUnary(Value value) { |
| 284 | Operation *op = value.getDefiningOp(); |
| 285 | while (op && op->getNumOperands() == 1) { |
| 286 | auto iface = dyn_cast<MemoryEffectOpInterface>(Val: op); |
| 287 | if (!iface || !iface.hasNoEffect()) |
| 288 | break; |
| 289 | value = op->getOperand(idx: 0); |
| 290 | op = value.getDefiningOp(); |
| 291 | } |
| 292 | return value; |
| 293 | } |
| 294 | |
| 295 | bool mlir::linalg::detail::isContractionBody( |
| 296 | Block &block, function_ref<bool(Operation *, Operation *)> isaPair, |
| 297 | llvm::raw_ostream &errs) { |
| 298 | if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) { |
| 299 | errs << "no terminator in the block" ; |
| 300 | return false; |
| 301 | } |
| 302 | |
| 303 | if (block.getNumArguments() != 3) { |
| 304 | errs << "expected block with 3 arguments" ; |
| 305 | return false; |
| 306 | } |
| 307 | |
| 308 | Operation *terminator = block.getTerminator(); |
| 309 | if (terminator->getNumOperands() != 1) { |
| 310 | errs << "expected terminator with 1 operand" ; |
| 311 | return false; |
| 312 | } |
| 313 | |
| 314 | Value yielded = getSourceSkipUnary(value: terminator->getOperand(idx: 0)); |
| 315 | Operation *reductionOp = yielded.getDefiningOp(); |
| 316 | if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { |
| 317 | errs << "expected reduction op to be binary" ; |
| 318 | return false; |
| 319 | } |
| 320 | |
| 321 | Value reductionLHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 0)); |
| 322 | Value reductionRHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 1)); |
| 323 | |
| 324 | if (reductionLHS != block.getArgument(i: 2) && |
| 325 | reductionRHS != block.getArgument(i: 2)) { |
| 326 | errs << "expected reduction to take block argument #2 as one of the " |
| 327 | "operands (modulo unary casts)" ; |
| 328 | return false; |
| 329 | } |
| 330 | |
| 331 | Value contributed = getSourceSkipUnary( |
| 332 | value: isa<BlockArgument>(Val: reductionLHS) ? reductionRHS : reductionLHS); |
| 333 | Operation *elementwiseOp = contributed.getDefiningOp(); |
| 334 | if (!elementwiseOp || elementwiseOp->getNumResults() != 1 || |
| 335 | elementwiseOp->getNumOperands() != 2) { |
| 336 | errs << "expected elementwise op to be binary" ; |
| 337 | return false; |
| 338 | } |
| 339 | |
| 340 | if (!isaPair(elementwiseOp, reductionOp)) { |
| 341 | errs << "expected reduction/elementwise op kind not satisfied" ; |
| 342 | return false; |
| 343 | } |
| 344 | |
| 345 | Value elementwiseLHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 0)); |
| 346 | Value elementwiseRHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 1)); |
| 347 | if ((elementwiseLHS == block.getArgument(i: 0) && |
| 348 | elementwiseRHS == block.getArgument(i: 1)) || |
| 349 | (elementwiseLHS == block.getArgument(i: 1) && |
| 350 | elementwiseRHS == block.getArgument(i: 0))) { |
| 351 | return true; |
| 352 | } |
| 353 | |
| 354 | errs << "expected elementwise op to apply to block arguments (modulo unary " |
| 355 | "casts)" ; |
| 356 | return false; |
| 357 | } |
| 358 | |
| 359 | /// Returns true if the two operations are of the kinds specified by a pair of |
| 360 | /// consecutive template arguments. |
/// Returns true if (`add`, `mul`) matches any consecutive (AddOpTy, MulOpTy)
/// pair in the template argument list; recurses through the remaining pairs
/// until one matches or the list is exhausted.
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments" );
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  // Tail-recurse on the remaining (Add, Mul) pairs, if any.
  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}
| 373 | |
| 374 | /// Returns true if the block is a body of a contraction with the kinds of |
| 375 | /// operations given pairwise by template arguments. |
| 376 | template <typename... Args> |
| 377 | static bool isContractionBody(Block &block) { |
| 378 | return linalg::detail::isContractionBody(block, isaPair: &isPairTemplateImpl<Args...>); |
| 379 | } |
| 380 | |
| 381 | /// Given an `indexingMap` and its corresponding `iterators`, returns |
| 382 | /// the positions of the iterators of type `iter` that are indexed by |
| 383 | /// the `indexingMap` as a permutation. This is useful to infer various |
| 384 | /// subcomputations on a `LinalgOp`. This is performed by looking up |
| 385 | /// each result in the `indexingMap` and determining whether: |
| 386 | /// - It is a single AffineDimExpr. |
| 387 | /// - It is the only result involving this AffineDimExpr. |
| 388 | static llvm::SmallDenseSet<int64_t> |
| 389 | findPermutationsIndexingOperand(AffineMap indexingMap, |
| 390 | ArrayRef<utils::IteratorType> iterators, |
| 391 | utils::IteratorType iter) { |
| 392 | assert(iterators.size() == indexingMap.getNumDims()); |
| 393 | llvm::SmallDenseSet<int64_t> res; |
| 394 | for (AffineExpr e : indexingMap.getResults()) { |
| 395 | if (auto d = dyn_cast<AffineDimExpr>(Val&: e)) { |
| 396 | if (iterators[d.getPosition()] == iter && |
| 397 | llvm::count_if(Range: indexingMap.getResults(), P: [d](AffineExpr e) { |
| 398 | return e.isFunctionOfDim(position: d.getPosition()); |
| 399 | }) == 1) |
| 400 | res.insert(V: d.getPosition()); |
| 401 | } |
| 402 | } |
| 403 | return res; |
| 404 | } |
| 405 | |
namespace {
// Shorthands for the iterator kinds used by the contraction/convolution
// matchers below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
| 410 | |
| 411 | /// Infer the iterator types from the init affine map. This looks at which dims |
| 412 | /// are present in the map results, and returns an iterator types array with |
| 413 | /// parallel types for dims that are present, and reduction types for dims that |
| 414 | /// are not present. |
| 415 | static FailureOr<SmallVector<utils::IteratorType>> |
| 416 | inferIteratorsFromOutMap(AffineMap map) { |
| 417 | if (!map.isProjectedPermutation()) |
| 418 | return failure(); |
| 419 | SmallVector<utils::IteratorType> iterators(map.getNumDims(), red); |
| 420 | for (auto expr : map.getResults()) |
| 421 | if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr)) |
| 422 | iterators[dim.getPosition()] = par; |
| 423 | return iterators; |
| 424 | } |
| 425 | |
| 426 | /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form |
| 427 | /// a matmul subcomputation within `linalgOp`. These dimensions are such that: |
| 428 | /// 1. The m dimension is involved in an outer-product along LHS |
| 429 | /// (i.e. it is a permutation on RES and LHS and does not appear in RHS). |
| 430 | /// 2. The n dimension is involved in an outer-product along RHS |
| 431 | /// (i.e. it is a permutation on RES and RHS and does not appear in LHS). |
| 432 | /// 3. The k dimension appears as a permutation on LHS and RHS. |
| 433 | /// 4. m, n and k appear only once in any given indexing. |
| 434 | /// 5. Optional batch dimensions that appear in all operands are captured. |
| 435 | /// This allows e.g. detecting that some contraction is embedded within |
| 436 | /// `linalgOp` with some orthogonal heuristic. |
| 437 | static FailureOr<ContractionDimensions> |
| 438 | inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, |
| 439 | ArrayRef<utils::IteratorType> iterators) { |
| 440 | llvm::SmallDenseSet<int64_t> a = |
| 441 | findPermutationsIndexingOperand(indexingMap: indexingMaps[0], iterators, iter: par); |
| 442 | llvm::SmallDenseSet<int64_t> b = |
| 443 | findPermutationsIndexingOperand(indexingMap: indexingMaps[1], iterators, iter: par); |
| 444 | llvm::SmallDenseSet<int64_t> c = |
| 445 | findPermutationsIndexingOperand(indexingMap: indexingMaps[2], iterators, iter: par); |
| 446 | |
| 447 | // A & C - B are the iterators involved in an outer-product along A (the LHS). |
| 448 | llvm::SmallDenseSet<int64_t> ac = a; |
| 449 | llvm::set_intersect(S1&: ac, S2: c); |
| 450 | llvm::set_subtract(S1&: ac, S2: b); |
| 451 | // B & C - A are the iterators involved in an outer-product along B (the RHS). |
| 452 | llvm::SmallDenseSet<int64_t> bc = b; |
| 453 | llvm::set_intersect(S1&: bc, S2: c); |
| 454 | llvm::set_subtract(S1&: bc, S2: a); |
| 455 | // A & B & C are the "batch" dimensions. |
| 456 | llvm::SmallDenseSet<int64_t> batches = a; |
| 457 | llvm::set_intersect(S1&: batches, S2: b); |
| 458 | llvm::set_intersect(S1&: batches, S2: c); |
| 459 | |
| 460 | // A & B red are the reduction dimensions. |
| 461 | llvm::SmallDenseSet<int64_t> ra = |
| 462 | findPermutationsIndexingOperand(indexingMap: indexingMaps[0], iterators, iter: red); |
| 463 | llvm::SmallDenseSet<int64_t> rb = |
| 464 | findPermutationsIndexingOperand(indexingMap: indexingMaps[1], iterators, iter: red); |
| 465 | llvm::set_intersect(S1&: ra, S2: rb); |
| 466 | |
| 467 | // Return each set in sorted order. |
| 468 | ContractionDimensions dimensions{ |
| 469 | .batch: SmallVector<unsigned, 2>(batches.begin(), batches.end()), |
| 470 | .m: SmallVector<unsigned, 2>(ac.begin(), ac.end()), |
| 471 | .n: SmallVector<unsigned, 2>(bc.begin(), bc.end()), |
| 472 | .k: SmallVector<unsigned, 2>(ra.begin(), ra.end())}; |
| 473 | llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end()); |
| 474 | llvm::sort(Start: dimensions.m.begin(), End: dimensions.m.end()); |
| 475 | llvm::sort(Start: dimensions.n.begin(), End: dimensions.n.end()); |
| 476 | llvm::sort(Start: dimensions.k.begin(), End: dimensions.k.end()); |
| 477 | return dimensions; |
| 478 | } |
| 479 | |
| 480 | FailureOr<ContractionDimensions> |
| 481 | mlir::linalg::inferContractionDims(LinalgOp linalgOp) { |
| 482 | if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) |
| 483 | return failure(); |
| 484 | return inferContractionDimsImpl(indexingMaps: linalgOp.getIndexingMapsArray(), |
| 485 | iterators: linalgOp.getIteratorTypesArray()); |
| 486 | } |
| 487 | |
| 488 | FailureOr<ContractionDimensions> |
| 489 | mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) { |
| 490 | if (indexingMaps.size() != 3) |
| 491 | return failure(); |
| 492 | auto iterators = inferIteratorsFromOutMap(map: indexingMaps[2]); |
| 493 | if (failed(Result: iterators)) |
| 494 | return failure(); |
| 495 | return inferContractionDimsImpl(indexingMaps, iterators: iterators.value()); |
| 496 | } |
| 497 | |
namespace mlir::linalg::detail {
/// Result of structurally matching an operation against the contraction
/// interface requirements; `Success` means all checks passed, each other
/// value names the first check that failed (see
/// `getMatchContractionMessage` for the corresponding diagnostics).
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
| 508 | |
| 509 | mlir::linalg::detail::MatchContractionResult |
| 510 | mlir::linalg::detail::isContractionInterfaceImpl( |
| 511 | Operation *op, mlir::linalg::ContractionDimensions *dimensions) { |
| 512 | auto linalgOp = dyn_cast<linalg::LinalgOp>(Val: op); |
| 513 | if (!linalgOp) |
| 514 | return MatchContractionResult::NotLinalgOp; |
| 515 | if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) |
| 516 | return MatchContractionResult::WrongNumOperands; |
| 517 | auto mapRange = linalgOp.getIndexingMapsArray(); |
| 518 | if (linalgOp.getNumReductionLoops() == 0) |
| 519 | return MatchContractionResult::NoReduction; |
| 520 | if (llvm::any_of(Range&: mapRange, |
| 521 | P: [](AffineMap m) { return !m.isProjectedPermutation(); })) |
| 522 | return MatchContractionResult::NotProjectedPermutations; |
| 523 | // TODO: more fields than add/mul. |
| 524 | // clang-format off |
| 525 | if (!::isContractionBody< |
| 526 | arith::MulFOp, arith::AddFOp, |
| 527 | arith::MulIOp, arith::AddIOp, |
| 528 | complex::MulOp, complex::AddOp, |
| 529 | arith::AndIOp, arith::OrIOp>( |
| 530 | block&: *linalgOp.getBlock())) { |
| 531 | return MatchContractionResult::NotAddMul; |
| 532 | } |
| 533 | // clang-format on |
| 534 | |
| 535 | if (dimensions) { |
| 536 | FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp); |
| 537 | assert(succeeded(res) && "unexpected failure to infer contraction dims" ); |
| 538 | *dimensions = *res; |
| 539 | } |
| 540 | return MatchContractionResult::Success; |
| 541 | } |
| 542 | |
/// Converts a `MatchContractionResult` code into its user-facing diagnostic
/// text; returns the empty string for `Success`.
StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
  switch (res) {
  case MatchContractionResult::NotLinalgOp:
    return "expected a LinalgOp" ;
  case MatchContractionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output" ;
  case MatchContractionResult::NoReduction:
    return "expected at least 1 reduction" ;
  case MatchContractionResult::NotProjectedPermutations:
    return "expected indexing maps to be projected permutations" ;
  case MatchContractionResult::NotAddMul:
    return "expected add/mul op in the body" ;
  case MatchContractionResult::Success:
    return "" ;
  }
  llvm_unreachable("unhandled MatchContractionResult case" );
}
| 561 | |
| 562 | bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) { |
| 563 | if (!linalgOp) |
| 564 | return false; |
| 565 | Operation *op = linalgOp.getOperation(); |
| 566 | return isa<ContractionOpInterface>(Val: op) || |
| 567 | (mlir::linalg::detail::isContractionInterfaceImpl(op) == |
| 568 | mlir::linalg::detail::MatchContractionResult::Success); |
| 569 | } |
| 570 | |
| 571 | /// Verify that a LinalgOp `op` is a contraction. |
| 572 | /// A Linalg contraction is defined in general terms: |
| 573 | /// 1. Has 2 input and 1 output shapes. |
| 574 | /// 2. Has at least one reduction dimension. |
| 575 | /// 3. Has only projected permutation indexing maps. |
| 576 | /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field |
| 577 | /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary |
| 578 | /// operations that may change the type (e.g. for mixed-precision). |
| 579 | /// As a consequence, when vectorization of such an op occurs, the only special |
| 580 | /// behavior is that the (unique) MulOpType is vectorized into a |
| 581 | /// `vector.contract`. All other ops are handled in a generic fashion. |
| 582 | /// In the future, we may wish to allow more input arguments and elementwise and |
| 583 | /// constant operations that do not involve the reduction dimension(s). |
| 584 | LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) { |
| 585 | auto res = isContractionInterfaceImpl(op); |
| 586 | if (res != MatchContractionResult::Success) |
| 587 | return op->emitError(message: getMatchContractionMessage(res)); |
| 588 | return success(); |
| 589 | } |
| 590 | |
| 591 | //===----------------------------------------------------------------------===// |
| 592 | // ConvolutionOpInterface implementation |
| 593 | //===----------------------------------------------------------------------===// |
| 594 | |
| 595 | /// Of the given two expressions returns one that is of type T (`lhs` gets |
| 596 | /// preference over `rhs`) |
| 597 | template <typename T> |
| 598 | static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) { |
| 599 | return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr); |
| 600 | } |
| 601 | |
| 602 | namespace { |
| 603 | /// Walk the indexing expressions for input of a convolution operation to verify |
| 604 | /// its of the right form, either |
| 605 | /// - AffineDimExpr |
| 606 | /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))? |
| 607 | /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)* |
| 608 | /// |
| 609 | /// classifies the AffineDimExpr as convolved dimensions or unconvolved |
| 610 | /// dimensions and verifies each dimension occurs only once. |
| 611 | struct ConvAccessExprWalker |
| 612 | : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> { |
| 613 | // Stores dimensions used in expressions of the above form. |
| 614 | llvm::SmallDenseSet<int64_t> convolvedDims; |
| 615 | // Stores the dual mapping between LHS and RHS of convolution exprs. |
| 616 | llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping; |
| 617 | // Stores single use dimensions used by an AffineDimExpr. |
| 618 | llvm::SmallDenseSet<int64_t> unConvolvedDims; |
| 619 | // Stores a mapping from convolved dims to their coefficient. |
| 620 | llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping; |
| 621 | |
| 622 | // Removes dims with multiple uses in the source input map from dimension |
| 623 | // sets tracked by this walker. |
| 624 | void clearMultiUseDims(AffineMap map) { |
| 625 | for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) { |
| 626 | if (llvm::count_if(Range: map.getResults(), P: [dimPos](AffineExpr e) { |
| 627 | return e.isFunctionOfDim(position: dimPos); |
| 628 | }) > 1) { |
| 629 | convolvedDims.erase(V: dimPos); |
| 630 | unConvolvedDims.erase(V: dimPos); |
| 631 | // If a duplicate dim is marked as convolved, the pair of the duplicate |
| 632 | // dim must be removed from the map as well. |
| 633 | auto it = convolvedDimMapping.find(Val: dimPos); |
| 634 | if (it != convolvedDimMapping.end()) { |
| 635 | int64_t pairedDim = it->second; |
| 636 | convolvedDims.erase(V: pairedDim); |
| 637 | unConvolvedDims.erase(V: pairedDim); |
| 638 | strideAndDilationMapping.erase(Val: pairedDim); |
| 639 | convolvedDimMapping.erase(Val: dimPos); |
| 640 | convolvedDimMapping.erase(Val: pairedDim); |
| 641 | } |
| 642 | } |
| 643 | } |
| 644 | } |
| 645 | |
| 646 | LogicalResult visitDimExpr(AffineDimExpr dimExpr) { |
| 647 | unsigned position = dimExpr.getPosition(); |
| 648 | if (unConvolvedDims.count(V: position) || convolvedDims.count(V: position)) { |
| 649 | return failure(); |
| 650 | } |
| 651 | unConvolvedDims.insert(V: position); |
| 652 | return success(); |
| 653 | } |
| 654 | |
| 655 | LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); } |
| 656 | |
| 657 | LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); } |
| 658 | |
| 659 | LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) { |
| 660 | // In pre-order visit, top level op has to be an add op. |
| 661 | if (binaryExpr.getKind() != AffineExprKind::Add) |
| 662 | return failure(); |
| 663 | auto lhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getLHS()); |
| 664 | auto rhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getRHS()); |
| 665 | if (failed(Result: lhsDimPos) || failed(Result: rhsDimPos)) |
| 666 | return failure(); |
| 667 | convolvedDimMapping[*lhsDimPos] = *rhsDimPos; |
| 668 | convolvedDimMapping[*rhsDimPos] = *lhsDimPos; |
| 669 | return success(); |
| 670 | } |
| 671 | |
| 672 | FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) { |
| 673 | if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) { |
| 674 | int64_t dim = dimExpr.getPosition(); |
| 675 | if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim)) |
| 676 | return failure(); |
| 677 | // Stride/dilation for this dim is implicitly 1. |
| 678 | strideAndDilationMapping[dim] = |
| 679 | getAffineConstantExpr(constant: 1, context: expr.getContext()); |
| 680 | convolvedDims.insert(V: dim); |
| 681 | return dim; |
| 682 | } |
| 683 | if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) { |
| 684 | if (symbolMulExpr.getKind() != AffineExprKind::Mul) |
| 685 | return failure(); |
| 686 | auto lhsExpr = symbolMulExpr.getLHS(); |
| 687 | auto rhsExpr = symbolMulExpr.getRHS(); |
| 688 | // Check for symbol expression. |
| 689 | AffineExpr mulExpr = |
| 690 | getAffineExprOfType<AffineSymbolExpr>(lhs: lhsExpr, rhs: rhsExpr); |
| 691 | // If there was no symbol expr, check for constant expression. |
| 692 | if (!mulExpr) { |
| 693 | mulExpr = getAffineExprOfType<AffineConstantExpr>(lhs: lhsExpr, rhs: rhsExpr); |
| 694 | } |
| 695 | auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhs: lhsExpr, rhs: rhsExpr); |
| 696 | if (!mulExpr || !dimExpr) |
| 697 | return failure(); |
| 698 | int64_t dim = dimExpr.getPosition(); |
| 699 | if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim)) |
| 700 | return failure(); |
| 701 | strideAndDilationMapping[dim] = mulExpr; |
| 702 | convolvedDims.insert(V: dim); |
| 703 | return dim; |
| 704 | } |
| 705 | return failure(); |
| 706 | } |
| 707 | }; |
| 708 | } // namespace |
| 709 | |
/// Collects the loop-dim positions that appear as results of `map`, i.e. the
/// iteration-space dims a projected permutation preserves.
static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr result : map.getResults()) {
    auto dimExpr = cast<AffineDimExpr>(result);
    dims.insert(dimExpr.getPosition());
  }
  return dims;
}
| 718 | |
/// Converts a list of affine expressions — all of which must be
/// `AffineConstantExpr` — into their integer values. Asserts on any
/// non-constant expression.
///
/// Takes `ArrayRef` rather than `const SmallVector&`, per the LLVM coding
/// standard; existing SmallVector callers convert implicitly.
static SmallVector<int64_t, 2>
getConstantsFromExprList(ArrayRef<AffineExpr> exprs) {
  SmallVector<int64_t, 2> vals;
  vals.reserve(exprs.size());
  for (AffineExpr e : exprs) {
    auto constantExpr = dyn_cast<AffineConstantExpr>(e);
    assert(constantExpr && "Found non-constant stride/dilation");
    vals.push_back(constantExpr.getValue());
  }
  return vals;
}
| 729 | |
/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
/// least one convolved dimension pair (output image + filter loop).
/// Convolution dimensions are returned in sorted order; strides match the
/// order of the filter loop dimensions and dilations match the order of the
/// output image dimensions.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel
  // iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(filterMap,
                                      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      /*batch=*/SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      /*outputImage=*/SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      /*outputChannel=*/SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      /*filterLoop=*/SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      /*inputChannel=*/SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      /*depth=*/SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  // Use the range form of llvm::sort (STLExtras) rather than iterator pairs.
  llvm::sort(dimensions.batch);
  llvm::sort(dimensions.outputImage);
  llvm::sort(dimensions.outputChannel);
  llvm::sort(dimensions.filterLoop);
  llvm::sort(dimensions.inputChannel);
  llvm::sort(dimensions.depth);

  // Use the op carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}
| 824 | |
| 825 | /// Find at least 1 parallel (output_image) and reduction (filter_loop) |
| 826 | /// dimension candidates that form a convolution subcomputation within |
| 827 | /// `linalgOp`. The LHS is assumed to be the convolution input while the |
| 828 | /// RHS is assumed as the filter. |
| 829 | /// These dimensions are such that: |
| 830 | /// 1. Optional batch dimensions that appear in the input and filter. |
| 831 | /// 2. The output_image dimension is involved in a cross-correlation along LHS |
| 832 | /// (i.e. it is a permutation on RES and LHS and has an associated |
| 833 | /// filter_loop in RHS). |
| 834 | /// 3. Optional output_channel dimension is involved in an outer-product along |
| 835 | /// RHS (i.e. it is a permutation on RES and RHS and does not appear in |
| 836 | /// LHS). |
| 837 | /// 4. Optional input_channel dimension appears as a permutation on LHS and |
| 838 | /// RHS. |
| 839 | /// 5. The filter_loop dimension appears as a permutation on the RHS and |
| 840 | /// represents the shape of the kernel cross-correlated along a |
| 841 | /// corresponding output_image dim. |
| 842 | /// 6. The input_channel dimension appears as a permutation on LHS and RHS. |
| 843 | /// 7. All dimensions appear only once in any given indexing map. |
| 844 | /// This allows e.g. detecting that some convolution is embedded within |
| 845 | /// `linalgOp` with some orthogonal heuristic. |
/// When multiple dimension occurrences exist that match any classification,
/// the indices are returned in sorted order.
| 848 | /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty. |
| 849 | FailureOr<ConvolutionDimensions> |
| 850 | mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) { |
| 851 | if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) |
| 852 | return failure(); |
| 853 | |
| 854 | auto indexingMaps = linalgOp.getIndexingMapsArray(); |
| 855 | |
| 856 | // Check the input indexing map has the right form. |
| 857 | ConvAccessExprWalker inputExprWalker; |
| 858 | for (AffineExpr expr : indexingMaps[0].getResults()) |
| 859 | (void)inputExprWalker.visit(expr); |
| 860 | inputExprWalker.clearMultiUseDims(map: indexingMaps[0]); |
| 861 | |
| 862 | return inferConvolutionDimsImpl(linalgOp, inputExprWalker, |
| 863 | /*allowEmptyConvolvedDims=*/false); |
| 864 | } |
| 865 | |
namespace mlir::linalg::detail {
/// Result of matching an operation against the convolution interface.
/// `Success` means all structural checks passed; every other value names the
/// first check that failed (see `getMatchConvolutionMessage` for the
/// user-facing diagnostic text).
enum class MatchConvolutionResult {
  /// The operation matched the convolution structure.
  Success = 0,
  /// The operation does not implement the LinalgOp interface.
  NotLinalgOp,
  /// Expected at least 2 inputs and exactly 1 init operand.
  WrongNumOperands,
  /// The input (image) indexing map is not a valid convolution access map.
  WrongInputIndexingMap,
  /// The filter and/or output indexing maps are not projected permutations.
  NotProjectedPermutations,
  /// Some loop could not be classified as a convolution loop kind.
  NonConvolutionLoop,
  /// An iterator used to access the output is not parallel.
  OutputDimsNotParallel,
  /// An iterator not used to access the output is not a reduction.
  NonOutputDimNotReduction,
  /// No convolved (output image + filter loop) dims were found while
  /// `allowEmptyConvolvedDims` was false.
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
| 879 | |
/// Core structural matcher for the convolution interface. Checks that `op`
/// is a LinalgOp with at least 2 inputs and exactly 1 init, that the input
/// indexing map is a valid convolution access pattern, that the filter and
/// output maps are projected permutations, and that every loop can be
/// classified as exactly one convolution loop kind (see the classification
/// comment below). If `dimensions` is non-null, it is populated with the
/// inferred dimension classification on success.
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions,
    bool allowEmptyConvolvedDims) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(Val: op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(Range: indexingMaps[0].getResults(),
                   P: [&inputExprWalker](AffineExpr expr) {
                     return failed(Result: inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(map: indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(map: indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  // First pass: classify every dim appearing in the output map.
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(Val&: outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(V: outputDim) &&
        !filterDims.count(V: outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(V: outputDim) &&
        !filterDims.count(V: outputDim)) {
      // Output image Loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(V: outputDim) &&
        !inputExprWalker.unConvolvedDims.count(V: outputDim) &&
        filterDims.count(V: outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(V: outputDim) &&
        filterDims.count(V: outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // Second pass: classify every dim appearing in the filter map; dims shared
  // with the output were already recorded above.
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(Val&: filterExpr).getPosition();
    if (outputDims.count(V: filterDim) &&
        !inputExprWalker.unConvolvedDims.count(V: filterDim) &&
        !inputExprWalker.convolvedDims.count(V: filterDim)) {
      // Output channel dimension. This is already seen, continue;
      continue;
    }
    if (inputExprWalker.convolvedDims.count(V: filterDim) &&
        !outputDims.count(V: filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(V: filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(V: filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(V: filterDim) &&
        !outputDims.count(V: filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(V: filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(V: filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(V: filterDim) &&
        outputDims.count(V: filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
    return MatchConvolutionResult::EmptyConvolvedDims;

  if (dimensions) {
    // The structural checks above guarantee the inference cannot fail.
    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}
| 1013 | |
| 1014 | StringRef |
| 1015 | mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) { |
| 1016 | switch (res) { |
| 1017 | case MatchConvolutionResult::NotLinalgOp: |
| 1018 | return "expected a LinalgOp" ; |
| 1019 | case MatchConvolutionResult::WrongNumOperands: |
| 1020 | return "expected op with 2 inputs and 1 output" ; |
| 1021 | case MatchConvolutionResult::WrongInputIndexingMap: |
| 1022 | return "unexpected input index map for convolutions" ; |
| 1023 | case MatchConvolutionResult::NotProjectedPermutations: |
| 1024 | return "expected output/filter indexing maps to be projected permutations" ; |
| 1025 | case MatchConvolutionResult::NonConvolutionLoop: |
| 1026 | return "unexpected loop dimension for convolution op" ; |
| 1027 | case MatchConvolutionResult::OutputDimsNotParallel: |
| 1028 | return "expected all iterators used to access outputs to be parallel" ; |
| 1029 | case MatchConvolutionResult::NonOutputDimNotReduction: |
| 1030 | return "expected all iterators not used to access outputs to be reduction" ; |
| 1031 | case MatchConvolutionResult::EmptyConvolvedDims: |
| 1032 | return "expected convolved dim to be non-empty" ; |
| 1033 | case MatchConvolutionResult::Success: |
| 1034 | return "" ; |
| 1035 | } |
| 1036 | llvm_unreachable("unhandled MatchConvolutionResult case" ); |
| 1037 | } |
| 1038 | |
| 1039 | bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp, |
| 1040 | bool allowEmptyConvolvedDims) { |
| 1041 | return linalg::detail::isConvolutionInterfaceImpl( |
| 1042 | op: linalgOp.getOperation(), dimensions: nullptr, allowEmptyConvolvedDims) == |
| 1043 | linalg::detail::MatchConvolutionResult::Success; |
| 1044 | } |
| 1045 | |
| 1046 | LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { |
| 1047 | MatchConvolutionResult res = isConvolutionInterfaceImpl(op); |
| 1048 | if (res != MatchConvolutionResult::Success) |
| 1049 | return op->emitError(message: getMatchConvolutionMessage(res)); |
| 1050 | return success(); |
| 1051 | } |
| 1052 | |
| 1053 | //===----------------------------------------------------------------------===// |
| 1054 | // FillOpInterface implementation |
| 1055 | //===----------------------------------------------------------------------===// |
| 1056 | |
/// Result of matching an operation against the fill interface; anything
/// other than `Success` identifies the first check that failed.
enum class MatchFillResult {
  /// The operation matched the fill structure.
  Success = 0,
  /// The operation does not implement the LinalgOp interface.
  NotLinalgOp,
  /// Expected exactly 1 input and 1 init operand.
  WrongNumOperands,
  /// The single input operand is not a scalar.
  NotScalarInput
};
| 1063 | |
| 1064 | static MatchFillResult isFillInterfaceImpl(Operation *op) { |
| 1065 | auto linalgOp = dyn_cast<linalg::LinalgOp>(Val: op); |
| 1066 | if (!linalgOp) |
| 1067 | return MatchFillResult::NotLinalgOp; |
| 1068 | if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) |
| 1069 | return MatchFillResult::WrongNumOperands; |
| 1070 | |
| 1071 | OpOperand *value = linalgOp.getDpsInputOperand(i: 0); |
| 1072 | if (!linalgOp.isScalar(opOperand: value)) |
| 1073 | return MatchFillResult::NotScalarInput; |
| 1074 | |
| 1075 | return MatchFillResult::Success; |
| 1076 | } |
| 1077 | |
| 1078 | LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { |
| 1079 | auto res = isFillInterfaceImpl(op); |
| 1080 | if (res == MatchFillResult::NotLinalgOp) |
| 1081 | return op->emitError(message: "expected a LinalgOp" ); |
| 1082 | if (res == MatchFillResult::WrongNumOperands) |
| 1083 | return op->emitError(message: "expected op with 1 input and 1 output" ); |
| 1084 | if (res == MatchFillResult::NotScalarInput) |
| 1085 | return op->emitError(message: "expected op with scalar input" ); |
| 1086 | |
| 1087 | return success(); |
| 1088 | } |
| 1089 | |
| 1090 | //===----------------------------------------------------------------------===// |
| 1091 | // StructuredOpInterface implementation |
| 1092 | //===----------------------------------------------------------------------===// |
| 1093 | |
/// Concatenates, operand by operand, the (folded) size of every dimension of
/// every operand of this op.
SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> dimList;
  for (OpOperand &operand : getOperation()->getOpOperands()) {
    int64_t rank = getRank(&operand);
    for (int64_t dim = 0; dim < rank; ++dim)
      dimList.push_back(createFoldedDimOp(b, loc, operand.get(), dim));
  }
  return dimList;
}
| 1103 | |
/// Concatenates the static shapes of all operands; only valid when every
/// operand shape is fully static.
SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  SmallVector<int64_t, 4> staticDims;
  for (OpOperand &operand : getOperation()->getOpOperands())
    llvm::append_range(staticDims, getShape(&operand));
  return staticDims;
}
| 1111 | |
| 1112 | SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { |
| 1113 | AffineMap map = getLoopsToShapesMap(); |
| 1114 | unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); |
| 1115 | auto viewSizes = createFlatListOfOperandDims(b, loc); |
| 1116 | SmallVector<Range, 4> res(numDims); |
| 1117 | for (unsigned idx = 0; idx < numRes; ++idx) { |
| 1118 | auto result = map.getResult(idx); |
| 1119 | if (auto d = dyn_cast<AffineDimExpr>(Val&: result)) { |
| 1120 | if (res[d.getPosition()].offset) |
| 1121 | continue; |
| 1122 | res[d.getPosition()] = |
| 1123 | Range{.offset: b.getIndexAttr(value: 0), .size: viewSizes[idx], .stride: b.getIndexAttr(value: 1)}; |
| 1124 | } |
| 1125 | } |
| 1126 | return res; |
| 1127 | } |
| 1128 | |
/// Visitor that reports whether an AffineExpr references any dim position
/// from a given set (encoded as a bit vector).
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  // A dim expression matches iff its position is in the tracked set.
  bool visitDimExpr(AffineDimExpr expr) {
    return positions.test(expr.getPosition());
  }

  // A binary expression matches if either side does.
  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    return visit(expr.getLHS()) || visit(expr.getRHS());
  }

  // Constants and symbols never reference a dim.
  bool visitConstantExpr(AffineConstantExpr expr) { return false; }
  bool visitSymbolExpr(AffineSymbolExpr expr) { return false; }

private:
  llvm::SmallBitVector positions;
};
| 1151 | |
/// Returns the half-open range [start, end) of result positions in the
/// loops-to-shapes map that correspond to the shapes of the op's init
/// operands, i.e. everything after the flattened input operand shapes.
///
/// `op` is taken by value: MLIR op handles are lightweight value types and
/// passing by non-const reference needlessly rejects temporaries.
static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp op) {
  int64_t inputRankSum = 0;
  int64_t outputRankSum = 0;
  for (OpOperand *input : op.getDpsInputOperands())
    inputRankSum += op.getRank(input);
  for (OpOperand &output : op.getDpsInitsMutable())
    outputRankSum += op.getRank(&output);
  return {inputRankSum, inputRankSum + outputRankSum};
}
| 1162 | |
/// Reifies the result shapes of this structured op in terms of its operands:
/// static result dims become IntegerAttrs; dynamic result dims are expressed
/// through the shapes-to-loops map from the input operand sizes, falling back
/// to a dim op on the corresponding init operand when the shape expression
/// would depend on an output dimension.
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(op&: *this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      start: resultShapesSubMapPos.first,
      length: resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(map: getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(I: resultShapesSubMapPos.first, E: resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          b&: rewriter, loc, map: resultShapesFromInputShapesMap,
          operands: createFlatListOfOperandDims(b, loc));
  // `pos` walks the flattened (init operand, dim) positions in lock-step with
  // `shapeExprs` and `allResultDimValues`.
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(Begin: 0, End: getRank(opOperand: &opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(Val: opOperand.get().getType());
      if (!shapedType.isDynamicDim(idx: dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(Elt: b.getIndexAttr(value: shapedType.getDimSize(idx: dim)));
      } else {
        // Dynamic dim: Return Value. If the folded expression references an
        // output dim, fall back to a dim op on the init operand itself.
        OpFoldResult ofr = checkDimExpr.visit(expr: shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, val: opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(Elt: getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(Args: std::move(shapes));
  }
  return success();
}
| 1222 | |
/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  unsigned operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`.
  unsigned initsStart = dpsIface.getDpsInits().getBeginOperandIndex();
  return dpsIface.getNumDpsInputs() + operandNumber - initsStart;
}
| 1238 | |
/// Default verifier shared by all ops implementing StructuredOpInterface:
/// checks pure tensor/buffer semantics, indexing-map consistency with the
/// iteration space, region/block structure, and block-argument typing.
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(Val: op);
  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError(message: "expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(Result: linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // Delayed calling of IndexingMapOpInterface::verifyImpl.
  if (failed(Result: cast<IndexingMapOpInterface>(Val: op).verifyImpl()))
    return failure();

  // Check that the indexing-map domain of every operand is consistent with
  // the op's iteration space.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError(message: "expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";
  }
  // NOTE(review): redDims is computed but not used below — candidate for
  // removal.
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(res&: redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError(message: "expected the shape-to-loops map to be non-null");

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(C&: linalgOp->getRegion(index: 0)))
    return op->emitOpError(message: "expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(index: 0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError(message: "expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  // Each block argument must have the element type (for shaped operands) or
  // the self type (for scalars) of its matching operand.
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(Val: elementType))
      elementType = getElementTypeOrSelf(type: opOperand->get().getType());
    Type argType = block.getArgument(i: opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError(message: "expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}
| 1304 | |