//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//

bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if the op has no loops.
    return linalgOp.getNumLoops() == 0;
  }
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
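
// For intuition, a sketch of how the check above behaves: given a
// linalg.generic with indexing maps
//   #map0 = affine_map<(d0, d1) -> (d0, d1)>
//   #map1 = affine_map<(d0, d1) -> (d0)>
// dropping the operand indexed by #map0 leaves only #map1, whose concatenation
// has no inverse permutation (d1 no longer appears in any result), so the
// operand cannot be dropped.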

//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
  // Structural.
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return false;

  // Operands and maps.
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return false;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
    return false;
  }
  // Region.
  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}
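
// A generic op that this predicate classifies as a copy, sketched here under
// the assumption of tensor semantics:
//
//   %0 = linalg.generic {
//       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                        affine_map<(d0, d1) -> (d0, d1)>],
//       iterator_types = ["parallel", "parallel"]}
//       ins(%in : tensor<?x?xf32>) outs(%out : tensor<?x?xf32>) {
//     ^bb0(%a: f32, %b: f32):
//       linalg.yield %a : f32
//   } -> tensor<?x?xf32>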

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// If the value is defined by a chain of unary side effect-free operations,
/// go up the use-def chain until the first value that isn't defined by such
/// an op.
// TODO: relax to multi-operands with constants, which are technically unary ops
// as needed (e.g. add5).
static Value getSourceSkipUnary(Value value) {
  Operation *op = value.getDefiningOp();
  while (op && op->getNumOperands() == 1) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface || !iface.hasNoEffect())
      break;
    value = op->getOperand(0);
    op = value.getDefiningOp();
  }
  return value;
}
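
// For example, starting from the result of `%e = arith.extf %a : f16 to f32`,
// this returns `%a`: arith.extf is a single-operand op with no memory effects,
// so the walk steps through it.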

bool mlir::linalg::detail::isContractionBody(
    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
    llvm::raw_ostream &errs) {
  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
    errs << "no terminator in the block";
    return false;
  }

  if (block.getNumArguments() != 3) {
    errs << "expected block with 3 arguments";
    return false;
  }

  Operation *terminator = block.getTerminator();
  if (terminator->getNumOperands() != 1) {
    errs << "expected terminator with 1 operand";
    return false;
  }

  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
    errs << "expected reduction op to be binary";
    return false;
  }

  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));

  if (reductionLHS != block.getArgument(2) &&
      reductionRHS != block.getArgument(2)) {
    errs << "expected reduction to take block argument #2 as one of the "
            "operands (modulo unary casts)";
    return false;
  }

  Value contributed = getSourceSkipUnary(
      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
  Operation *elementwiseOp = contributed.getDefiningOp();
  if (elementwiseOp->getNumResults() != 1 ||
      elementwiseOp->getNumOperands() != 2) {
    errs << "expected elementwise op to be binary";
    return false;
  }

  if (!isaPair(elementwiseOp, reductionOp)) {
    errs << "expected reduction/elementwise op kind not satisfied";
    return false;
  }

  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
  if ((elementwiseLHS == block.getArgument(0) &&
       elementwiseRHS == block.getArgument(1)) ||
      (elementwiseLHS == block.getArgument(1) &&
       elementwiseRHS == block.getArgument(0))) {
    return true;
  }

  errs << "expected elementwise op to apply to block arguments (modulo unary "
          "casts)";
  return false;
}

/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}

/// Returns true if the block is a body of a contraction with the kinds of
/// operations given pairwise by template arguments.
template <typename... Args>
static bool isContractionBody(Block &block) {
  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
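
// For intuition, the canonical f32 matmul body that these checks accept
// (modulo unary casts around any operand):
//
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %c, %0 : f32
//     linalg.yield %1 : f32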

/// Given an `indexingMap` and its corresponding `iterators`, returns
/// the positions of the iterators of type `iter` that are indexed by
/// the `indexingMap` as a permutation. This is useful to infer various
/// subcomputations on a `LinalgOp`. This is performed by looking up
/// each result in the `indexingMap` and determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
                                ArrayRef<utils::IteratorType> iterators,
                                utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      if (iterators[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}
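
// Example: with indexingMap = (d0, d1, d2) -> (d0, d2) and iterators
// [parallel, parallel, reduction], querying `parallel` yields {0} (d1 is not
// indexed by the map at all) and querying `reduction` yields {2}.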

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

/// Infer the iterator types from the init affine map. This looks at which dims
/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
static FailureOr<SmallVector<utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();
  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = par;
  return iterators;
}
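
// Example: for the matmul-style init map (d0, d1, d2) -> (d0, d1), this
// returns [parallel, parallel, reduction], since d0 and d1 appear in the
// results and d2 does not.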

/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
///   1. The m dimension is involved in an outer-product along LHS
///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
///   2. The n dimension is involved in an outer-product along RHS
///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
///   3. The k dimension appears as a permutation on LHS and RHS.
///   4. m, n and k appear only once in any given indexing.
///   5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
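/// For instance, a plain matmul with indexing maps
///   LHS: (d0, d1, d2) -> (d0, d2)
///   RHS: (d0, d1, d2) -> (d2, d1)
///   RES: (d0, d1, d2) -> (d0, d1)
/// infers batch = {}, m = {0}, n = {1} and k = {2}.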
static FailureOr<ContractionDimensions>
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
                         ArrayRef<utils::IteratorType> iterators) {
  llvm::SmallDenseSet<int64_t> a =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
  llvm::SmallDenseSet<int64_t> b =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
  llvm::SmallDenseSet<int64_t> c =
      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);

  // A & C - B are the iterators involved in an outer-product along A (the LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // A & B red are the reduction dimensions.
  llvm::SmallDenseSet<int64_t> ra =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
  llvm::SmallDenseSet<int64_t> rb =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
  llvm::set_intersect(ra, rb);

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.m.begin(), dimensions.m.end());
  llvm::sort(dimensions.n.begin(), dimensions.n.end());
  llvm::sort(dimensions.k.begin(), dimensions.k.end());
  return dimensions;
}

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();
  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
                                  linalgOp.getIteratorTypesArray());
}

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
  if (indexingMaps.size() != 3)
    return failure();
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, iterators.value());
}

namespace mlir::linalg::detail {
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
    Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  // clang-format off
  if (!::isContractionBody<
        arith::MulFOp, arith::AddFOp,
        arith::MulIOp, arith::AddIOp,
        complex::MulOp, complex::AddOp,
        arith::AndIOp, arith::OrIOp>(
      *linalgOp.getBlock())) {
    return MatchContractionResult::NotAddMul;
  }
  // clang-format on

  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
  return MatchContractionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
  switch (res) {
  case MatchContractionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchContractionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchContractionResult::NoReduction:
    return "expected at least 1 reduction";
  case MatchContractionResult::NotProjectedPermutations:
    return "expected indexing maps to be projected permutations";
  case MatchContractionResult::NotAddMul:
    return "expected add/mul op in the body";
  case MatchContractionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
          mlir::linalg::detail::MatchContractionResult::Success);
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. Has 2 input and 1 output shapes.
///   2. Has at least one reduction dimension.
///   3. Has only projected permutation indexing maps.
///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
/// operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
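/// For instance, a mixed-precision f16 -> f32 matmul body matches this form
/// with u3 = u4 = arith.extf and u1, u2, u5 the identity (a sketch):
///   ^bb0(%a: f16, %b: f16, %c: f32):
///     %0 = arith.extf %a : f16 to f32
///     %1 = arith.extf %b : f16 to f32
///     %2 = arith.mulf %0, %1 : f32
///     %3 = arith.addf %c, %2 : f32
///     linalg.yield %3 : f32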
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res != MatchContractionResult::Success)
    return op->emitError(getMatchContractionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the given two expressions, returns the one that is of type T (`lhs` gets
/// preference over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///      (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// This classifies each AffineDimExpr as a convolved or an unconvolved
/// dimension and verifies that each dimension occurs only once.
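///
/// For example, walking the input map results of a strided and dilated 1-D
/// convolution such as (d0, d1, d2) -> (d0 * 2 + d2 * 3, d1) classifies d0 and
/// d2 as convolved (with coefficients 2 and 3 recorded in
/// `strideAndDilationMapping`) and d1 as unconvolved.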
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the duplicate
        // dim must be removed from the map as well.
        if (convolvedDimMapping.contains(dimPos)) {
          int64_t pairedDim = convolvedDimMapping[dimPos];
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      strideAndDilationMapping[dim] =
          getAffineConstantExpr(1, expr.getContext());
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
} // namespace

static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  llvm::SmallDenseSet<int64_t> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
  return preservedDims;
}

static SmallVector<int64_t, 2>
getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
  SmallVector<int64_t, 2> vals;
  for (auto e : exprs) {
    auto constantExpr = dyn_cast<AffineConstantExpr>(e);
    assert(constantExpr && "Found non-constant stride/dilation");
    vals.push_back(constantExpr.getValue());
  }
  return vals;
}

/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
/// least one convolved dimension pair (output image + filter loop). Convolution
/// dimensions are specified in sorted order; strides match the order of the
/// output image dimensions, while the dilations match the order of the filter
/// loop dimensions.
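///
/// Example (a sketch): for linalg.conv_1d_nwc_wcf with loops
/// (d0, d1, d2, d3, d4) = (N, W, F, KW, C) and maps
///   input:  (d0, d1 * s + d3 * dl, d4)
///   filter: (d3, d4, d2)
///   output: (d0, d1, d2)
/// this infers batch = {0}, outputImage = {1}, outputChannel = {2},
/// filterLoop = {3}, inputChannel = {4}, depth = {}, strides = [s] and
/// dilations = [dl].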
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(filterMap,
                                      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());

  // Use the op carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}

/// Find at least 1 parallel (output_image) and reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed to be the filter.
/// These dimensions are such that:
///   1. Optional batch dimensions that appear in the input and filter.
///   2. The output_image dimension is involved in a cross-correlation along LHS
///      (i.e. it is a permutation on RES and LHS and has an associated
///      filter_loop in RHS).
///   3. Optional output_channel dimension is involved in an outer-product along
///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
///      LHS).
///   4. Optional input_channel dimension appears as a permutation on LHS and
///      RHS.
///   5. The filter_loop dimension appears as a permutation on the RHS and
///      represents the shape of the kernel cross-correlated along a
///      corresponding output_image dim.
///   6. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  for (AffineExpr expr : indexingMaps[0].getResults())
    (void)inputExprWalker.visit(expr);
  inputExprWalker.clearMultiUseDims(indexingMaps[0]);

  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                  /*allowEmptyConvolvedDims=*/false);
}

namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. Already seen; continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (dimensions) {
    FailureOr<ConvolutionDimensions> res =
        inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                 /*allowEmptyConvolvedDims=*/true);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
  switch (res) {
  case MatchConvolutionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchConvolutionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchConvolutionResult::WrongInputIndexingMap:
    return "unexpected input index map for convolutions";
  case MatchConvolutionResult::NotProjectedPermutations:
    return "expected output/filter indexing maps to be projected permutations";
  case MatchConvolutionResult::NonConvolutionLoop:
    return "unexpected loop dimension for convolution op";
  case MatchConvolutionResult::OutputDimsNotParallel:
    return "expected all iterators used to access outputs to be parallel";
  case MatchConvolutionResult::NonOutputDimNotReduction:
    return "expected all iterators not used to access outputs to be reduction";
  case MatchConvolutionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchConvolutionResult case");
}

bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) {
  return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
         linalg::detail::MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
  if (res != MatchConvolutionResult::Success)
    return op->emitError(getMatchConvolutionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};

static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return MatchFillResult::WrongNumOperands;

  OpOperand *value = linalgOp.getDpsInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> res;
  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    llvm::append_range(res, getShape(&opOperand));
  return res;
}

SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result)) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] =
          Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
    }
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result))
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}
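
// Worked example (a sketch): for a static matmul with A: 4x8, B: 8x16 and
// C: 4x16, getLoopsToShapesMap() is (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1),
// the flat static dim list is [4, 8, 8, 16, 4, 16], so the computed loop
// sizes are [d0, d1, d2] = [4, 16, 8].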

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};

static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
  int64_t inputRankSum = 0;
  int64_t outputRankSum = 0;
  for (OpOperand *input : op.getDpsInputOperands())
    inputRankSum += op.getRank(input);
  for (OpOperand &output : op.getDpsInitsMutable())
    outputRankSum += op.getRank(&output);
  return {inputRankSum, inputRankSum + outputRankSum};
}
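
// Example: for a matmul with two rank-2 inputs and one rank-2 init,
// inputRankSum = 4 and the returned pair is {4, 6}: the init's dims occupy
// positions [4, 6) of the flattened loops-to-shapes map results.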

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`.
  return cast<DestinationStyleOpInterface>(*this->getOperation())
             .getNumDpsInputs() +
         operandNumber - start;
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);

  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError("expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by them are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Check if given shapes match to inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);

  // Verify only static cases since we can't get exact dimension sizes and loop
  // ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions and dimensions of size 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // The first or last index should be the minimum or the maximum of the
        // inferred index range, since the range is monotonically increasing
        // or decreasing. The size of each input/output operand dimension
        // should equal the maximum inferred value + 1. However, for
        // complicated affine expressions such as d0 * 3 + d1, we only check
        // that the inferred range stays within the bounds of the operand's
        // size, since handling such cases exactly is not easy. There are also
        // cases this check cannot cover, for example (d0, d1) -> (d1 - d0).
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}