1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/Complex/IR/Complex.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/IR/AffineExpr.h"
17#include "mlir/IR/AffineExprVisitor.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/BuiltinTypeInterfaces.h"
20#include "mlir/IR/MLIRContext.h"
21#include "mlir/IR/TypeUtilities.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28#include <algorithm>
29#include <numeric>
30#include <optional>
31
32using namespace mlir;
33using namespace mlir::linalg;
34
35/// Include the definitions of the copy operation interface.
36#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
37
38//===----------------------------------------------------------------------===//
39// Interface utility functions
40//===----------------------------------------------------------------------===//
41
42bool linalg::detail::canOpOperandsBeDroppedImpl(
43 linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
44 SmallVector<AffineMap> indexingMaps;
45 for (auto &opOperand : linalgOp->getOpOperands()) {
46 if (llvm::is_contained(droppedOperands, &opOperand))
47 continue;
48 indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
49 }
50 if (indexingMaps.empty()) {
51 // If there are no indexing maps, the operand can only be dropped
52 // if the op has no loops.
53 return linalgOp.getNumLoops() == 0;
54 }
55 return inversePermutation(concatAffineMaps(
56 indexingMaps, linalgOp.getContext())) != AffineMap();
57}
58
59//===----------------------------------------------------------------------===//
60// CopyOpInterface implementation
61//===----------------------------------------------------------------------===//
62
63bool linalg::isaCopyOpInterface(LinalgOp op) {
64 // Check all loops are parallel and linalgOp is single input and output.
65 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
66 return false;
67
68 auto mapRange = op.getIndexingMapsArray();
69 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
70 !mapRange.back().isIdentity()) {
71 return false;
72 }
73 // Region.
74 return llvm::hasSingleElement(op.getBlock()->getOperations());
75}
76
77//===----------------------------------------------------------------------===//
78// FillOpInterface implementation
79//===----------------------------------------------------------------------===//
80std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
81 // Structural.
82 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
83 !op.isSingleYieldOp())
84 return std::nullopt;
85
86 // Input should be referenced and init should not.
87 if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
88 op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
89 return std::nullopt;
90
91 OpOperand *value = op.getDpsInputOperand(0);
92 if (!op.isScalar(value))
93 return std::nullopt;
94 return value->get();
95}
96
97//===----------------------------------------------------------------------===//
98// BroadcastOpInterface implementation
99//===----------------------------------------------------------------------===//
100std::optional<SmallVector<int64_t>>
101linalg::isaBroadcastOpInterface(GenericOp op) {
102 // Structural.
103 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
104 !op.isSingleYieldOp())
105 return std::nullopt;
106
107 auto srcTy = op.getDpsInputOperand(0)->get().getType();
108 auto dstTy = op.getDpsInitOperand(0)->get().getType();
109 if (!isa<MemRefType, RankedTensorType>(srcTy) ||
110 !isa<MemRefType, RankedTensorType>(dstTy))
111 return std::nullopt;
112
113 // Check output is identity map. Broadcast could additionally be
114 // employing permutation of indices and that would be expressible
115 // in linalg.generic but is not expressible for named broadcast op.
116 auto dstMap = op.getIndexingMapsArray()[1];
117 if (!dstMap.isIdentity())
118 return std::nullopt;
119
120 SmallVector<int64_t> position;
121 auto srcMap = op.getIndexingMapsArray()[0];
122
123 if (srcMap.getResults().size() >= dstMap.getResults().size())
124 return std::nullopt;
125
126 // Check input map is monotonically increasing DimIds.
127 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
128 auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
129 if (!expr)
130 return std::nullopt;
131 int64_t pos = expr.getPosition();
132 if (i > 0 && pos <= position[i - 1])
133 return std::nullopt;
134 position.push_back(Elt: expr.getPosition());
135 }
136
137 SmallVector<int64_t> broadcastedDims;
138 auto numDims = srcMap.getNumDims();
139 // This is quadratic but number of items is generally small.
140 for (auto dim : llvm::seq<int64_t>(0, numDims)) {
141 if (!llvm::is_contained(position, dim))
142 broadcastedDims.push_back(dim);
143 }
144 return broadcastedDims;
145}
146
147//===----------------------------------------------------------------------===//
148// TransposeOpInterface implementation
149//===----------------------------------------------------------------------===//
150std::optional<SmallVector<int64_t>>
151linalg::isaTransposeOpInterface(GenericOp op) {
152 // To specialize as a transpose op, the genericOp must be
153 // all parallel loops, single input, single output, and its body
154 // should be just a yield op, yielding input as output as is (no compute).
155 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
156 !op.isSingleYieldOp())
157 return std::nullopt;
158
159 auto mapRange = op.getIndexingMapsArray();
160 if (mapRange.size() != 2)
161 return std::nullopt;
162
163 auto mapOfInput = mapRange.front();
164 auto mapOfResult = mapRange.back();
165
166 // linalg.transpose permutes the dimensions of input using this
167 // rule: dim(result, i) = dim(input, permutation[i])
168 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
169 return std::nullopt;
170
171 SmallVector<int64_t> permutation(mapOfInput.getNumDims());
172 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
173 auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
174 permutation[expr.getPosition()] = i;
175 }
176 return permutation;
177}
178
179//===----------------------------------------------------------------------===//
180// Elementwise Single Unary/Binary-OpInterface implementation
181//===----------------------------------------------------------------------===//
/// Returns true if `op` is an elementwise generic with exactly `arity` inputs
/// whose body consists of a single `arity`-operand compute op followed by a
/// yield of that op's result.
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
                                                      unsigned arity) {
  // Check all loops are parallel.
  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
    return false;

  // Check there are arity-inputs, 1-output and all are identity-maps.
  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
      !llvm::all_of(op.getIndexingMapsArray(),
                    [](AffineMap map) { return map.isIdentity(); }))
    return false;

  // Init should not be referenced for elementwise operations.
  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
    return false;

  // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
  // as resulting from producer-consumer fusion. Here, we restrict to two ops in
  // the body, where the first is the elementwise single op and the second a
  // yield.
  Block *body = op.getBody();
  if (body->getOperations().size() != 2)
    return false;

  // The single compute op must consume exactly `arity` operands and produce
  // one result.
  Operation *oper = &body->front();
  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
    return false;

  // The terminator must yield exactly that result (no reordering, no other
  // values).
  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0).getDefiningOp() != oper)
    return false;
  return true;
}
216
/// Returns true if `op` is a single-op elementwise unary generic (one input,
/// one output, single compute op + yield) whose input actually feeds the
/// payload.
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
  // All basic elemwise checks.
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
    return false;

  // Check input is actually used.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
    return false;
  return true;
}
227
228bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
229 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
230 return false;
231
232 // Check both inputs are used (elementwise).
233 OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
234 OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
235 if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
236 !op.payloadUsesValueFromOperand(inputOpOperand1))
237 return false;
238 return true;
239}
240
241//===----------------------------------------------------------------------===//
242// ContractionOpInterface implementation
243//===----------------------------------------------------------------------===//
244
245/// If the value is defined by a chain of unary side effect-free, go up the
246/// use-def chain until the first value that isn't defined by such an op.
247// TODO: relax to multi-operands with constants, which are technically unary ops
248// as needed (e.g. add5).
249static Value getSourceSkipUnary(Value value) {
250 Operation *op = value.getDefiningOp();
251 while (op && op->getNumOperands() == 1) {
252 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
253 if (!iface || !iface.hasNoEffect())
254 break;
255 value = op->getOperand(idx: 0);
256 op = value.getDefiningOp();
257 }
258 return value;
259}
260
261bool mlir::linalg::detail::isContractionBody(
262 Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
263 llvm::raw_ostream &errs) {
264 if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
265 errs << "no terminator in the block";
266 return false;
267 }
268
269 if (block.getNumArguments() != 3) {
270 errs << "expected block with 3 arguments";
271 return false;
272 }
273
274 Operation *terminator = block.getTerminator();
275 if (terminator->getNumOperands() != 1) {
276 errs << "expected terminator with 1 operand";
277 return false;
278 }
279
280 Value yielded = getSourceSkipUnary(value: terminator->getOperand(idx: 0));
281 Operation *reductionOp = yielded.getDefiningOp();
282 if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
283 errs << "expected reduction op to be binary";
284 return false;
285 }
286
287 Value reductionLHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 0));
288 Value reductionRHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 1));
289
290 if (reductionLHS != block.getArgument(i: 2) &&
291 reductionRHS != block.getArgument(i: 2)) {
292 errs << "expected reduction to take block argument #2 as one of the "
293 "operands (modulo unary casts)";
294 return false;
295 }
296
297 Value contributed = getSourceSkipUnary(
298 value: isa<BlockArgument>(Val: reductionLHS) ? reductionRHS : reductionLHS);
299 Operation *elementwiseOp = contributed.getDefiningOp();
300 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
301 elementwiseOp->getNumOperands() != 2) {
302 errs << "expected elementwise op to be binary";
303 return false;
304 }
305
306 if (!isaPair(elementwiseOp, reductionOp)) {
307 errs << "expected reduction/elementwise op kind not satisfied";
308 return false;
309 }
310
311 Value elementwiseLHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 0));
312 Value elementwiseRHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 1));
313 if ((elementwiseLHS == block.getArgument(i: 0) &&
314 elementwiseRHS == block.getArgument(i: 1)) ||
315 (elementwiseLHS == block.getArgument(i: 1) &&
316 elementwiseRHS == block.getArgument(i: 0))) {
317 return true;
318 }
319
320 errs << "expected elementwise op to apply to block arguments (modulo unary "
321 "casts)";
322 return false;
323}
324
325/// Returns true if the two operations are of the kinds specified by a pair of
326/// consecutive template arguments.
327template <typename AddOpTy, typename MulOpTy, typename... Args>
/// Returns true if (`add`, `mul`) matches any consecutive (AddOpTy, MulOpTy)
/// pair from the template argument list; recurses through the remaining pairs
/// at compile time.
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  // Try the next (add, mul) pair, if any remain.
  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}
339
340/// Returns true if the block is a body of a contraction with the kinds of
341/// operations given pairwise by template arguments.
342template <typename... Args>
343static bool isContractionBody(Block &block) {
344 return linalg::detail::isContractionBody(block, isaPair: &isPairTemplateImpl<Args...>);
345}
346
347/// Given an `indexingMap` and its corresponding `iterators`, returns
348/// the positions of the iterators of type `iter` that are indexed by
349/// the `indexingMap` as a permutation. This is useful to infer various
350/// subcomputations on a `LinalgOp`. This is performed by looking up
351/// each result in the `indexingMap` and determining whether:
352/// - It is a single AffineDimExpr.
353/// - It is the only result involving this AffineDimExpr.
354static llvm::SmallDenseSet<int64_t>
355findPermutationsIndexingOperand(AffineMap indexingMap,
356 ArrayRef<utils::IteratorType> iterators,
357 utils::IteratorType iter) {
358 assert(iterators.size() == indexingMap.getNumDims());
359 llvm::SmallDenseSet<int64_t> res;
360 for (AffineExpr e : indexingMap.getResults()) {
361 if (auto d = dyn_cast<AffineDimExpr>(Val&: e)) {
362 if (iterators[d.getPosition()] == iter &&
363 llvm::count_if(Range: indexingMap.getResults(), P: [d](AffineExpr e) {
364 return e.isFunctionOfDim(position: d.getPosition());
365 }) == 1)
366 res.insert(V: d.getPosition());
367 }
368 }
369 return res;
370}
371
namespace {
// File-local shorthands for the iterator kinds used by the inference helpers
// below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
376
377/// Infer the iterator types from the init affine map. This looks at which dims
378/// are present in the map results, and returns an iterator types array with
379/// parallel types for dims that are present, and reduction types for dims that
380/// are not present.
381static FailureOr<SmallVector<utils::IteratorType>>
382inferIteratorsFromOutMap(AffineMap map) {
383 if (!map.isProjectedPermutation())
384 return failure();
385 SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
386 for (auto expr : map.getResults())
387 if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr))
388 iterators[dim.getPosition()] = par;
389 return iterators;
390}
391
392/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
393/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
394/// 1. The m dimension is involved in an outer-product along LHS
395/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
396/// 2. The n dimension is involved in an outer-product along RHS
397/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
398/// 3. The k dimension appears as a permutation on LHS and RHS.
399/// 4. m, n and k appear only once in any given indexing.
400/// 5. Optional batch dimensions that appear in all operands are captured.
401/// This allows e.g. detecting that some contraction is embedded within
402/// `linalgOp` with some orthogonal heuristic.
403static FailureOr<ContractionDimensions>
404inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
405 ArrayRef<utils::IteratorType> iterators) {
406 llvm::SmallDenseSet<int64_t> a =
407 findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
408 llvm::SmallDenseSet<int64_t> b =
409 findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
410 llvm::SmallDenseSet<int64_t> c =
411 findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
412
413 // A & C - B are the iterators involved in an outer-product along A (the LHS).
414 llvm::SmallDenseSet<int64_t> ac = a;
415 llvm::set_intersect(S1&: ac, S2: c);
416 llvm::set_subtract(S1&: ac, S2: b);
417 // B & C - A are the iterators involved in an outer-product along B (the RHS).
418 llvm::SmallDenseSet<int64_t> bc = b;
419 llvm::set_intersect(S1&: bc, S2: c);
420 llvm::set_subtract(S1&: bc, S2: a);
421 // A & B & C are the "batch" dimensions.
422 llvm::SmallDenseSet<int64_t> batches = a;
423 llvm::set_intersect(S1&: batches, S2: b);
424 llvm::set_intersect(S1&: batches, S2: c);
425
426 // A & B red are the reduction dimensions.
427 llvm::SmallDenseSet<int64_t> ra =
428 findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
429 llvm::SmallDenseSet<int64_t> rb =
430 findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
431 llvm::set_intersect(S1&: ra, S2: rb);
432
433 // Return each set in sorted order.
434 ContractionDimensions dimensions{
435 .batch: SmallVector<unsigned, 2>(batches.begin(), batches.end()),
436 .m: SmallVector<unsigned, 2>(ac.begin(), ac.end()),
437 .n: SmallVector<unsigned, 2>(bc.begin(), bc.end()),
438 .k: SmallVector<unsigned, 2>(ra.begin(), ra.end())};
439 llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end());
440 llvm::sort(Start: dimensions.m.begin(), End: dimensions.m.end());
441 llvm::sort(Start: dimensions.n.begin(), End: dimensions.n.end());
442 llvm::sort(Start: dimensions.k.begin(), End: dimensions.k.end());
443 return dimensions;
444}
445
446FailureOr<ContractionDimensions>
447mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
448 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
449 return failure();
450 return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
451 linalgOp.getIteratorTypesArray());
452}
453
454FailureOr<ContractionDimensions>
455mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
456 if (indexingMaps.size() != 3)
457 return failure();
458 auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
459 if (failed(iterators))
460 return failure();
461 return inferContractionDimsImpl(indexingMaps, iterators.value());
462}
463
namespace mlir::linalg::detail {
/// Result of structurally matching an op against the contraction interface.
/// Every non-Success value names the first check that failed (see
/// getMatchContractionMessage for the user-facing text).
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
474
/// Structurally matches `op` against the contraction interface, returning the
/// first failed check (or Success). When `dimensions` is non-null and the
/// match succeeds, it is filled with the inferred contraction dimensions.
mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
    Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // The body must reduce with one of the supported (mul, add) field pairs.
  // TODO: more fields than add/mul.
  // clang-format off
  if (!::isContractionBody<
        arith::MulFOp, arith::AddFOp,
        arith::MulIOp, arith::AddIOp,
        complex::MulOp, complex::AddOp,
        arith::AndIOp, arith::OrIOp>(
      *linalgOp.getBlock())) {
    return MatchContractionResult::NotAddMul;
  }
  // clang-format on

  if (dimensions) {
    // Inference cannot fail once the structural checks above have passed.
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
  return MatchContractionResult::Success;
}
508
509StringRef
510mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
511 switch (res) {
512 case MatchContractionResult::NotLinalgOp:
513 return "expected a LinalgOp";
514 case MatchContractionResult::WrongNumOperands:
515 return "expected op with 2 inputs and 1 output";
516 case MatchContractionResult::NoReduction:
517 return "expected at least 1 reduction";
518 case MatchContractionResult::NotProjectedPermutations:
519 return "expected indexing maps to be projected permutations";
520 case MatchContractionResult::NotAddMul:
521 return "expected add/mul op in the body";
522 case MatchContractionResult::Success:
523 return "";
524 }
525 llvm_unreachable("unhandled MatchContractionResult case");
526}
527
528bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
529 if (!linalgOp)
530 return false;
531 Operation *op = linalgOp.getOperation();
532 return isa<ContractionOpInterface>(op) ||
533 (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
534 mlir::linalg::detail::MatchContractionResult::Success);
535}
536
537/// Verify that a LinalgOp `op` is a contraction.
538/// A Linalg contraction is defined in general terms:
539/// 1. Has 2 input and 1 output shapes.
540/// 2. Has at least one reduction dimension.
541/// 3. Has only projected permutation indexing maps.
542/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
543/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
544/// operations that may change the type (e.g. for mixed-precision).
545/// As a consequence, when vectorization of such an op occurs, the only special
546/// behavior is that the (unique) MulOpType is vectorized into a
547/// `vector.contract`. All other ops are handled in a generic fashion.
548/// In the future, we may wish to allow more input arguments and elementwise and
549/// constant operations that do not involve the reduction dimension(s).
550LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
551 auto res = isContractionInterfaceImpl(op);
552 if (res != MatchContractionResult::Success)
553 return op->emitError(message: getMatchContractionMessage(res));
554 return success();
555}
556
557//===----------------------------------------------------------------------===//
558// ConvolutionOpInterface implementation
559//===----------------------------------------------------------------------===//
560
561/// Of the given two expressions returns one that is of type T (`lhs` gets
562/// preference over `rhs`)
563template <typename T>
564static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
565 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
566}
567
568namespace {
569/// Walk the indexing expressions for input of a convolution operation to verify
570/// its of the right form, either
571/// - AffineDimExpr
572/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
573/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
574///
575/// classifies the AffineDimExpr as convolved dimensions or unconvolved
576/// dimensions and verifies each dimension occurs only once.
577struct ConvAccessExprWalker
578 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
579 // Stores dimensions used in expressions of the above form.
580 llvm::SmallDenseSet<int64_t> convolvedDims;
581 // Stores the dual mapping between LHS and RHS of convolution exprs.
582 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
583 // Stores single use dimensions used by an AffineDimExpr.
584 llvm::SmallDenseSet<int64_t> unConvolvedDims;
585 // Stores a mapping from convolved dims to their coefficient.
586 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
587
588 // Removes dims with multiple uses in the source input map from dimension
589 // sets tracked by this walker.
590 void clearMultiUseDims(AffineMap map) {
591 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
592 if (llvm::count_if(Range: map.getResults(), P: [dimPos](AffineExpr e) {
593 return e.isFunctionOfDim(position: dimPos);
594 }) > 1) {
595 convolvedDims.erase(V: dimPos);
596 unConvolvedDims.erase(V: dimPos);
597 // If a duplicate dim is marked as convolved, the pair of the duplicate
598 // dim must be removed from the map as well.
599 auto it = convolvedDimMapping.find(Val: dimPos);
600 if (it != convolvedDimMapping.end()) {
601 int64_t pairedDim = it->second;
602 convolvedDims.erase(V: pairedDim);
603 unConvolvedDims.erase(V: pairedDim);
604 strideAndDilationMapping.erase(Val: pairedDim);
605 convolvedDimMapping.erase(Val: dimPos);
606 convolvedDimMapping.erase(Val: pairedDim);
607 }
608 }
609 }
610 }
611
612 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
613 unsigned position = dimExpr.getPosition();
614 if (unConvolvedDims.count(V: position) || convolvedDims.count(V: position)) {
615 return failure();
616 }
617 unConvolvedDims.insert(V: position);
618 return success();
619 }
620
621 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
622
623 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
624
625 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
626 // In pre-order visit, top level op has to be an add op.
627 if (binaryExpr.getKind() != AffineExprKind::Add)
628 return failure();
629 auto lhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getLHS());
630 auto rhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getRHS());
631 if (failed(Result: lhsDimPos) || failed(Result: rhsDimPos))
632 return failure();
633 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
634 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
635 return success();
636 }
637
638 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
639 if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
640 int64_t dim = dimExpr.getPosition();
641 if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim))
642 return failure();
643 // Stride/dilation for this dim is implicitly 1.
644 strideAndDilationMapping[dim] =
645 getAffineConstantExpr(constant: 1, context: expr.getContext());
646 convolvedDims.insert(V: dim);
647 return dim;
648 }
649 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) {
650 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
651 return failure();
652 auto lhsExpr = symbolMulExpr.getLHS();
653 auto rhsExpr = symbolMulExpr.getRHS();
654 // Check for symbol expression.
655 AffineExpr mulExpr =
656 getAffineExprOfType<AffineSymbolExpr>(lhs: lhsExpr, rhs: rhsExpr);
657 // If there was no symbol expr, check for constant expression.
658 if (!mulExpr) {
659 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhs: lhsExpr, rhs: rhsExpr);
660 }
661 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhs: lhsExpr, rhs: rhsExpr);
662 if (!mulExpr || !dimExpr)
663 return failure();
664 int64_t dim = dimExpr.getPosition();
665 if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim))
666 return failure();
667 strideAndDilationMapping[dim] = mulExpr;
668 convolvedDims.insert(V: dim);
669 return dim;
670 }
671 return failure();
672 }
673};
674} // namespace
675
676static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
677 assert(map.isProjectedPermutation() &&
678 "expected map to have projected permutations");
679 llvm::SmallDenseSet<int64_t> preservedDims;
680 for (auto expr : map.getResults())
681 preservedDims.insert(V: cast<AffineDimExpr>(Val&: expr).getPosition());
682 return preservedDims;
683}
684
685static SmallVector<int64_t, 2>
686getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
687 SmallVector<int64_t, 2> vals;
688 for (auto e : exprs) {
689 auto constantExpr = dyn_cast<AffineConstantExpr>(Val&: e);
690 assert(constantExpr && "Found non-constant stride/dilation");
691 vals.push_back(Elt: constantExpr.getValue());
692 }
693 return vals;
694}
695
696/// Classifies dimensions in the `linalgOp` used by a convolution
697/// subcomputation, as captured by `inputExprWalker`. If
698/// `allowEmptyConvolvedDims` is not set this this will fail if there is not
699/// at least convolved dimension pair (output image + filter loop). Convolution
700/// dimensions are specified in sorted order, and strides match the order of
701/// the filter loop dimensions, while the dilations match the order of the
702/// output image dimensions.
703static FailureOr<ConvolutionDimensions>
704inferConvolutionDimsImpl(LinalgOp linalgOp,
705 ConvAccessExprWalker &inputExprWalker,
706 bool allowEmptyConvolvedDims) {
707 auto filterMap =
708 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
709 auto outputMap =
710 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
711 llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
712 filterMap, linalgOp.getIteratorTypesArray(), par);
713 llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
714 outputMap, linalgOp.getIteratorTypesArray(), par);
715
716 // unConvolvedDims & outputDims - filterDims are the batch iterators.
717 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
718 llvm::set_intersect(S1&: batch, S2: outputDims);
719 llvm::set_subtract(S1&: batch, S2: filterDims);
720
721 // convolvedDims & outputDims are the output image iterators.
722 llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
723 llvm::set_intersect(S1&: oi, S2: outputDims);
724
725 // filterDims & outputDims - unConvolvedDims are the output channel iterators.
726 llvm::SmallDenseSet<int64_t> oc = filterDims;
727 llvm::set_intersect(S1&: oc, S2: outputDims);
728 llvm::set_subtract(S1&: oc, S2: inputExprWalker.unConvolvedDims);
729
730 // filterDims & outputDims & unConvolvedDims are the depth iterators.
731 llvm::SmallDenseSet<int64_t> depth = filterDims;
732 llvm::set_intersect(S1&: depth, S2: outputDims);
733 llvm::set_intersect(S1&: depth, S2: inputExprWalker.unConvolvedDims);
734
735 llvm::SmallDenseSet<int64_t> filterReducedDims =
736 findPermutationsIndexingOperand(filterMap,
737 linalgOp.getIteratorTypesArray(), red);
738
739 // convolvedDims & filterReducedDims are the filter loop iterators.
740 llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
741 llvm::set_intersect(S1&: fl, S2: filterReducedDims);
742
743 // unConvolvedDims & filterReducedDims are the input channel iterators.
744 llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
745 llvm::set_intersect(S1&: ic, S2: filterReducedDims);
746
747 if (oi.empty() && !allowEmptyConvolvedDims)
748 return failure();
749
750 // Return each set in sorted order.
751 ConvolutionDimensions dimensions{
752 .batch: SmallVector<unsigned, 2>(batch.begin(), batch.end()),
753 .outputImage: SmallVector<unsigned, 2>(oi.begin(), oi.end()),
754 .outputChannel: SmallVector<unsigned, 2>(oc.begin(), oc.end()),
755 .filterLoop: SmallVector<unsigned, 2>(fl.begin(), fl.end()),
756 .inputChannel: SmallVector<unsigned, 2>(ic.begin(), ic.end()),
757 .depth: SmallVector<unsigned, 2>(depth.begin(), depth.end()),
758 /*strides=*/SmallVector<int64_t, 2>{},
759 /*dilations=*/SmallVector<int64_t, 2>{}};
760 llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end());
761 llvm::sort(Start: dimensions.outputImage.begin(), End: dimensions.outputImage.end());
762 llvm::sort(Start: dimensions.outputChannel.begin(), End: dimensions.outputChannel.end());
763 llvm::sort(Start: dimensions.filterLoop.begin(), End: dimensions.filterLoop.end());
764 llvm::sort(Start: dimensions.inputChannel.begin(), End: dimensions.inputChannel.end());
765 llvm::sort(Start: dimensions.depth.begin(), End: dimensions.depth.end());
766
767 // Use the op carried strides/dilations attribute if present.
768 auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
769 if (!nativeStrides) {
770 SmallVector<AffineExpr, 2> strideExprs;
771 for (unsigned oiDim : dimensions.outputImage)
772 strideExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[oiDim]);
773 dimensions.strides = getConstantsFromExprList(exprs: strideExprs);
774 } else {
775 dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
776 }
777 auto nativeDilations =
778 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
779 if (!nativeDilations) {
780 SmallVector<AffineExpr, 2> dilationExprs;
781 for (unsigned flDim : dimensions.filterLoop)
782 dilationExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[flDim]);
783 dimensions.dilations = getConstantsFromExprList(exprs: dilationExprs);
784 } else {
785 dimensions.dilations =
786 llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
787 }
788 return dimensions;
789}
790
791/// Find at least 1 parallel (output_image) and reduction (filter_loop)
792/// dimension candidates that form a convolution subcomputation within
793/// `linalgOp`. The LHS is assumed to be the convolution input while the
794/// RHS is assumed as the filter.
795/// These dimensions are such that:
796/// 1. Optional batch dimensions that appear in the input and filter.
797/// 2. The output_image dimension is involved in a cross-correlation along LHS
798/// (i.e. it is a permutation on RES and LHS and has an associated
799/// filter_loop in RHS).
800/// 3. Optional output_channel dimension is involved in an outer-product along
801/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
802/// LHS).
803/// 4. Optional input_channel dimension appears as a permutation on LHS and
804/// RHS.
805/// 5. The filter_loop dimension appears as a permutation on the RHS and
806/// represents the shape of the kernel cross-correlated along a
807/// corresponding output_image dim.
808/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
809/// 7. All dimensions appear only once in any given indexing map.
810/// This allows e.g. detecting that some convolution is embedded within
811/// `linalgOp` with some orthogonal heuristic.
812/// When multiple dimension occurrences exist that match any classification
813/// indices are returned in sorted order.
814/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
815FailureOr<ConvolutionDimensions>
816mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
817 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
818 return failure();
819
820 auto indexingMaps = linalgOp.getIndexingMapsArray();
821
822 // Check the input indexing map has the right form.
823 ConvAccessExprWalker inputExprWalker;
824 for (AffineExpr expr : indexingMaps[0].getResults())
825 (void)inputExprWalker.visit(expr);
826 inputExprWalker.clearMultiUseDims(map: indexingMaps[0]);
827
828 return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
829 /*allowEmptyConvolvedDims=*/false);
830}
831
namespace mlir::linalg::detail {
/// Possible outcomes when matching an operation against the convolution
/// interface requirements (see `isConvolutionInterfaceImpl`). Each non-zero
/// value identifies the first requirement that failed; a human-readable
/// message for each value is produced by `getMatchConvolutionMessage`.
enum class MatchConvolutionResult {
  Success = 0,
  // The operation does not implement LinalgOp.
  NotLinalgOp,
  // The op does not have >= 2 inputs and exactly 1 output.
  WrongNumOperands,
  // The input indexing map is not a valid convolution access pattern.
  WrongInputIndexingMap,
  // Filter/output indexing maps are not projected permutations.
  NotProjectedPermutations,
  // A loop dimension could not be classified as any convolution role.
  NonConvolutionLoop,
  // A dimension used to access the output is not a parallel iterator.
  OutputDimsNotParallel,
  // A dimension not used by the output is not a reduction iterator.
  NonOutputDimNotReduction,
  // No convolved (strided/dilated) dims were found and they are required.
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
845
846mlir::linalg::detail::MatchConvolutionResult
847mlir::linalg::detail::isConvolutionInterfaceImpl(
848 Operation *op, ConvolutionDimensions *dimensions,
849 bool allowEmptyConvolvedDims) {
850 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
851 if (!linalgOp)
852 return MatchConvolutionResult::NotLinalgOp;
853 if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
854 return MatchConvolutionResult::WrongNumOperands;
855
856 auto indexingMaps = linalgOp.getIndexingMapsArray();
857
858 // Check the input indexing map has the right form.
859 ConvAccessExprWalker inputExprWalker;
860 if (llvm::any_of(indexingMaps[0].getResults(),
861 [&inputExprWalker](AffineExpr expr) {
862 return failed(Result: inputExprWalker.visit(expr));
863 })) {
864 return MatchConvolutionResult::WrongInputIndexingMap;
865 }
866
867 // Filter and output maps must be projected permutation.
868 if (!indexingMaps[1].isProjectedPermutation() ||
869 !indexingMaps.back().isProjectedPermutation())
870 return MatchConvolutionResult::NotProjectedPermutations;
871
872 auto iteratorTypes = linalgOp.getIteratorTypesArray();
873
874 llvm::SmallDenseSet<int64_t> outputDims =
875 getPreservedDims(indexingMaps.back());
876 llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
877 // Make sure all loops are characterized as one of:
878 // - Batch loop : present in output, as non-convolved in input, not present in
879 // filter.
880 // - Output image dimension : present in output, convolved dims in input, not
881 // present in filter.
882 // - Output channel dimension : present in output, not present in input,
883 // present in filter.
884 // - Filter loop dimension : present in filter, convolved in input, not
885 // present in output.
886 // - Input channel dimension : unconvolved in input, not present in output,
887 // present in filter.
888 // - Depth multiplier : unconvolved in input, present in output, present in
889 // filter.
890 llvm::SmallDenseSet<int64_t> allLoopDims;
891 for (auto outputExpr : indexingMaps.back().getResults()) {
892 int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
893 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
894 !filterDims.count(outputDim)) {
895 // Batch dimension.
896 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
897 return MatchConvolutionResult::OutputDimsNotParallel;
898 allLoopDims.insert(outputDim);
899 continue;
900 }
901 if (inputExprWalker.convolvedDims.count(outputDim) &&
902 !filterDims.count(outputDim)) {
903 // Output image Loop dimension.
904 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
905 return MatchConvolutionResult::OutputDimsNotParallel;
906 allLoopDims.insert(outputDim);
907 continue;
908 }
909 if (!inputExprWalker.convolvedDims.count(outputDim) &&
910 !inputExprWalker.unConvolvedDims.count(outputDim) &&
911 filterDims.count(outputDim)) {
912 // Output channel dimension.
913 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
914 return MatchConvolutionResult::OutputDimsNotParallel;
915 allLoopDims.insert(outputDim);
916 continue;
917 }
918 if (inputExprWalker.unConvolvedDims.count(outputDim) &&
919 filterDims.count(outputDim)) {
920 // Depth multiplier.
921 if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
922 return MatchConvolutionResult::OutputDimsNotParallel;
923 allLoopDims.insert(outputDim);
924 continue;
925 }
926 return MatchConvolutionResult::NonConvolutionLoop;
927 }
928 for (auto filterExpr : indexingMaps[1].getResults()) {
929 int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
930 if (outputDims.count(filterDim) &&
931 !inputExprWalker.unConvolvedDims.count(filterDim) &&
932 !inputExprWalker.convolvedDims.count(filterDim)) {
933 // Output channel dimension. This is already seen, continue;
934 continue;
935 }
936 if (inputExprWalker.convolvedDims.count(filterDim) &&
937 !outputDims.count(filterDim)) {
938 // Filter loop dimension.
939 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
940 return MatchConvolutionResult::NonOutputDimNotReduction;
941 if (allLoopDims.count(filterDim))
942 return MatchConvolutionResult::NonConvolutionLoop;
943 allLoopDims.insert(filterDim);
944 continue;
945 }
946 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
947 !outputDims.count(filterDim)) {
948 // Input channel dimension.
949 if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
950 return MatchConvolutionResult::NonOutputDimNotReduction;
951 if (allLoopDims.count(filterDim))
952 return MatchConvolutionResult::NonConvolutionLoop;
953 allLoopDims.insert(filterDim);
954 continue;
955 }
956 if (inputExprWalker.unConvolvedDims.count(filterDim) &&
957 outputDims.count(filterDim)) {
958 // Depthwise loop. Already seen.
959 continue;
960 }
961 return MatchConvolutionResult::NonConvolutionLoop;
962 }
963 // All loops must be covered now.
964 if (allLoopDims.size() != linalgOp.getNumLoops())
965 return MatchConvolutionResult::NonConvolutionLoop;
966
967 if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
968 return MatchConvolutionResult::EmptyConvolvedDims;
969
970 if (dimensions) {
971 FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
972 linalgOp, inputExprWalker, allowEmptyConvolvedDims);
973 assert(succeeded(res) && "unexpected failure to infer convolution dims");
974 *dimensions = *res;
975 }
976
977 return MatchConvolutionResult::Success;
978}
979
980StringRef
981mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
982 switch (res) {
983 case MatchConvolutionResult::NotLinalgOp:
984 return "expected a LinalgOp";
985 case MatchConvolutionResult::WrongNumOperands:
986 return "expected op with 2 inputs and 1 output";
987 case MatchConvolutionResult::WrongInputIndexingMap:
988 return "unexpected input index map for convolutions";
989 case MatchConvolutionResult::NotProjectedPermutations:
990 return "expected output/filter indexing maps to be projected permutations";
991 case MatchConvolutionResult::NonConvolutionLoop:
992 return "unexpected loop dimension for convolution op";
993 case MatchConvolutionResult::OutputDimsNotParallel:
994 return "expected all iterators used to access outputs to be parallel";
995 case MatchConvolutionResult::NonOutputDimNotReduction:
996 return "expected all iterators not used to access outputs to be reduction";
997 case MatchConvolutionResult::EmptyConvolvedDims:
998 return "expected convolved dim to be non-empty";
999 case MatchConvolutionResult::Success:
1000 return "";
1001 }
1002 llvm_unreachable("unhandled MatchConvolutionResult case");
1003}
1004
1005bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
1006 bool allowEmptyConvolvedDims) {
1007 return linalg::detail::isConvolutionInterfaceImpl(
1008 op: linalgOp.getOperation(), dimensions: nullptr, allowEmptyConvolvedDims) ==
1009 linalg::detail::MatchConvolutionResult::Success;
1010}
1011
1012LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
1013 MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
1014 if (res != MatchConvolutionResult::Success)
1015 return op->emitError(message: getMatchConvolutionMessage(res));
1016 return success();
1017}
1018
1019//===----------------------------------------------------------------------===//
1020// FillOpInterface implementation
1021//===----------------------------------------------------------------------===//
1022
/// Possible outcomes when matching an operation against the fill interface
/// (see `isFillInterfaceImpl`).
enum class MatchFillResult {
  Success = 0,
  // The operation does not implement LinalgOp.
  NotLinalgOp,
  // The op does not have exactly 1 input and 1 output.
  WrongNumOperands,
  // The single input is not a scalar value.
  NotScalarInput
};
1029
1030static MatchFillResult isFillInterfaceImpl(Operation *op) {
1031 auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1032 if (!linalgOp)
1033 return MatchFillResult::NotLinalgOp;
1034 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1035 return MatchFillResult::WrongNumOperands;
1036
1037 OpOperand *value = linalgOp.getDpsInputOperand(0);
1038 if (!linalgOp.isScalar(value))
1039 return MatchFillResult::NotScalarInput;
1040
1041 return MatchFillResult::Success;
1042}
1043
1044LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
1045 auto res = isFillInterfaceImpl(op);
1046 if (res == MatchFillResult::NotLinalgOp)
1047 return op->emitError(message: "expected a LinalgOp");
1048 if (res == MatchFillResult::WrongNumOperands)
1049 return op->emitError(message: "expected op with 1 input and 1 output");
1050 if (res == MatchFillResult::NotScalarInput)
1051 return op->emitError(message: "expected op with scalar input");
1052
1053 return success();
1054}
1055
1056//===----------------------------------------------------------------------===//
1057// StructuredOpInterface implementation
1058//===----------------------------------------------------------------------===//
1059
1060SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1061 Location loc) {
1062 SmallVector<OpFoldResult> res;
1063 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1064 for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1065 res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1066 }
1067 return res;
1068}
1069
1070SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1071 SmallVector<int64_t, 4> res;
1072 assert(!hasDynamicShape() && "expected operands to have static shapes");
1073 for (OpOperand &opOperand : getOperation()->getOpOperands())
1074 llvm::append_range(res, getShape(&opOperand));
1075 return res;
1076}
1077
1078SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1079 AffineMap map = getLoopsToShapesMap();
1080 unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1081 auto viewSizes = createFlatListOfOperandDims(b, loc);
1082 SmallVector<Range, 4> res(numDims);
1083 for (unsigned idx = 0; idx < numRes; ++idx) {
1084 auto result = map.getResult(idx);
1085 if (auto d = dyn_cast<AffineDimExpr>(result)) {
1086 if (res[d.getPosition()].offset)
1087 continue;
1088 res[d.getPosition()] =
1089 Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1090 }
1091 }
1092 return res;
1093}
1094
1095/// Visitor to check if any of the given set of positions from AffineDimExprs
1096/// are used within an AffineExpr.
1097struct HasAffineDimExprVisitor
1098 : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1099 HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1100 : positions(std::move(positions)) {}
1101
1102 bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
1103 return visit(expr: binaryOpExpr.getLHS()) || visit(expr: binaryOpExpr.getRHS());
1104 }
1105
1106 bool visitDimExpr(AffineDimExpr dimExpr) {
1107 return positions.test(Idx: dimExpr.getPosition());
1108 }
1109
1110 bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1111
1112 bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1113
1114private:
1115 llvm::SmallBitVector positions;
1116};
1117
1118static std::pair<int64_t, int64_t>
1119getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
1120 int64_t inputRankSum = 0;
1121 int64_t outputRankSum = 0;
1122 for (OpOperand *input : op.getDpsInputOperands())
1123 inputRankSum += op.getRank(input);
1124 for (OpOperand &output : op.getDpsInitsMutable())
1125 outputRankSum += op.getRank(&output);
1126 return {inputRankSum, inputRankSum + outputRankSum};
1127}
1128
/// Reify the shapes of this op's results in terms of the shapes of its
/// operands, producing one OpFoldResult per result dimension.
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  // Apply the composed map to all operand dim sizes at once; this yields one
  // folded value per result dimension of every output.
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  // `pos` walks the flattened (output, dim) results in lockstep with
  // `shapeExprs` and `allResultDimValues`.
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value. If the inferred expression still refers
        // to an output dimension, fall back to a dim op on the init operand.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}
1188
1189/// Return the index in the indexingMaps vector that corresponds to this
1190/// `opOperand`.
1191int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1192 auto operandNumber = opOperand->getOperandNumber();
1193 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1194 if (!dpsIface.isDpsInput(opOperand))
1195 return operandNumber;
1196 unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1197 assert(!dpsIface.isDpsInit(opOperand));
1198 // Account for potential inputs that are not DPS and may not appear in
1199 // `indexingMaps`.
1200 return cast<DestinationStyleOpInterface>(*this->getOperation())
1201 .getNumDpsInputs() +
1202 operandNumber - start;
1203}
1204
1205LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
1206 LinalgOp linalgOp = cast<LinalgOp>(op);
1207 // Mixed tensor/buffer operands are not allowed.
1208 if (!linalgOp.hasPureTensorSemantics() &&
1209 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1210 return op->emitOpError(message: "expected to have pure tensor or buffer semantics");
1211
1212 // Before checking indexing maps, we need to make sure the attributes
1213 // referenced by it are valid.
1214 if (linalgOp.hasDynamicIndexingMaps())
1215 if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1216 return failure();
1217
1218 // All input/output operands must be indexed.
1219 if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
1220 linalgOp->getNumOperands())
1221 return op->emitOpError(message: "expected the number of indexing_map (")
1222 << linalgOp.getIndexingMapsArray().size()
1223 << ") to be equal to the number of input/output operands ("
1224 << linalgOp->getNumOperands() << ")";
1225
1226 // Set this flag if this op has user defined maps. This is required to guard
1227 // the below error condition which assume default indexing maps.
1228 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1229 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1230
1231 // Symbols disallowed.
1232 if (indexingMap.getNumSymbols() != 0)
1233 return op->emitOpError("unexpected symbols in indexing_map #")
1234 << opOperand.getOperandNumber();
1235
1236 // Domain must be consistent.
1237 unsigned numLoops = linalgOp.getNumLoops();
1238 if (indexingMap.getNumDims() != numLoops)
1239 return op->emitOpError("expected indexing_map #")
1240 << opOperand.getOperandNumber() << " to have " << numLoops
1241 << " dim(s) to match the number of loops";
1242
1243 int64_t rank = linalgOp.getRank(&opOperand);
1244
1245 if (indexingMap.getNumResults() != rank)
1246 return op->emitOpError("expected operand rank (")
1247 << rank << ") to match the result rank of indexing_map #"
1248 << opOperand.getOperandNumber() << " ("
1249 << indexingMap.getNumResults() << ")";
1250 }
1251 SmallVector<unsigned> redDims;
1252 linalgOp.getReductionDims(redDims);
1253
1254 if (!linalgOp.getShapesToLoopsMap())
1255 return op->emitOpError(message: "expected the shape-to-loops map to be non-null");
1256
1257 // Check if given shapes match to inferred shapes.
1258 SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
1259 SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
1260 // Verify only static cases since we can't get exact dimension sizes and
1261 // loop ranges for dynamic cases in this stage.
1262 if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1263 for (int64_t &range : endLoopRangeValues)
1264 range -= 1;
1265 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1266 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1267 SmallVector<int64_t, 4> startIndices =
1268 indexingMap.compose(startLoopRangeValues);
1269 SmallVector<int64_t, 4> endIndices =
1270 indexingMap.compose(endLoopRangeValues);
1271 ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
1272 for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
1273 // Ignore dynamic dimension or the case that the dimension size is 0
1274 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1275 continue;
1276
1277 // The first index or last index should be the maximum or the minimum in
1278 // the inferred index ranges since the range is increasing or
1279 // decreasing. The size of dimensions of input/output operands and the
1280 // maximum value + 1 in the inferred range should be the same. But, for
1281 // now we check if the inferred ranges are in boundary of input/output
1282 // operands' size or not in case that Affine Expressions are complicated
1283 // such as d0 * 3
1284 // + d1 since it is not easy to handle the issues.
1285 // Found the case that this solution can't check, for example, (d0, d1)
1286 // -> (d1 - d0)
1287 int64_t inferredDimSize =
1288 std::max(startIndices[dim], endIndices[dim]) + 1;
1289 if (std::min(startIndices[dim], endIndices[dim]) < 0) {
1290 std::string mapStr;
1291 {
1292 llvm::raw_string_ostream os(mapStr);
1293 os << indexingMap;
1294 }
1295 return op->emitOpError(
1296 "unexpected result less than 0 at expression #")
1297 << dim << " in " << mapStr;
1298 }
1299 if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
1300 if (inferredDimSize != shape[dim]) {
1301 return op->emitOpError("inferred input/output operand #")
1302 << opOperand.getOperandNumber() << " has shape's dimension #"
1303 << dim << " to be " << inferredDimSize << ", but found "
1304 << shape[dim];
1305 }
1306 } else {
1307 if (inferredDimSize > shape[dim]) {
1308 return op->emitOpError("inferred input/output operand #")
1309 << opOperand.getOperandNumber() << " has shape's dimension #"
1310 << dim << " to be greater than or equal to "
1311 << inferredDimSize << ", but found " << shape[dim];
1312 }
1313 }
1314 }
1315 }
1316 }
1317
1318 // Check the region has exactly one block.
1319 if (linalgOp->getNumRegions() != 1 ||
1320 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1321 return op->emitOpError(message: "expects to have 1 region with 1 block");
1322
1323 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1324 // types.
1325 // TODO: once ranked shape types are plugged in, we may want to drop the
1326 // corresponding bbargs, that can never be read from. This will be subject to
1327 // consistency discussions (i.e. what to do with output tensors whose bbarg is
1328 // not used).
1329 Block &block = linalgOp->getRegion(0).front();
1330
1331 if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1332 return op->emitOpError(message: "expected as many non-induction variable region "
1333 "arguments as the number of input/output operands");
1334
1335 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1336 Type elementType = opOperand->get().getType();
1337 if (isa<MemRefType, RankedTensorType>(elementType))
1338 elementType = getElementTypeOrSelf(opOperand->get().getType());
1339 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1340 if (elementType != argType)
1341 return op->emitOpError("expected type of bb argument #")
1342 << opOperand->getOperandNumber() << " (" << argType << ")"
1343 << " to match element or self type of the corresponding operand ("
1344 << elementType << ")";
1345 }
1346
1347 return success();
1348}
1349
// NOTE: Trailing website-navigation text ("Provided by KDAB", privacy-policy
// links) from the code-browsing site this excerpt was copied from has been
// converted to this comment; it was not part of the original
// mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp and is not valid C++.