1//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Utils/Utils.h"
14#include "mlir/Dialect/Complex/IR/Complex.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/IR/AffineExpr.h"
17#include "mlir/IR/AffineExprVisitor.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/BuiltinTypeInterfaces.h"
20#include "mlir/IR/MLIRContext.h"
21#include "mlir/IR/TypeUtilities.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SetOperations.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/Support/Casting.h"
27#include "llvm/Support/raw_ostream.h"
28#include <optional>
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33/// Include the definitions of the copy operation interface.
34#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Interface utility functions
38//===----------------------------------------------------------------------===//
39
40bool linalg::detail::canOpOperandsBeDroppedImpl(
41 linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
42 SmallVector<AffineMap> indexingMaps;
43 for (auto &opOperand : linalgOp->getOpOperands()) {
44 if (llvm::is_contained(Range&: droppedOperands, Element: &opOperand))
45 continue;
46 indexingMaps.push_back(Elt: linalgOp.getMatchingIndexingMap(opOperand: &opOperand));
47 }
48 if (indexingMaps.empty()) {
49 // If there are no indexing maps, the operand can only be dropped
50 // if the op has no loops.
51 return linalgOp.getNumLoops() == 0;
52 }
53 return inversePermutation(map: concatAffineMaps(
54 maps: indexingMaps, context: linalgOp.getContext())) != AffineMap();
55}
56
57//===----------------------------------------------------------------------===//
58// CopyOpInterface implementation
59//===----------------------------------------------------------------------===//
60
61bool linalg::isaCopyOpInterface(LinalgOp op) {
62 // Check all loops are parallel and linalgOp is single input and output.
63 if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64 return false;
65
66 auto mapRange = op.getIndexingMapsArray();
67 if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68 !mapRange.back().isIdentity()) {
69 return false;
70 }
71 // Region.
72 return llvm::hasSingleElement(C&: op.getBlock()->getOperations());
73}
74
75//===----------------------------------------------------------------------===//
76// FillOpInterface implementation
77//===----------------------------------------------------------------------===//
78/// Detects if a linalg.generic operation represents a fill with an inlined
79/// constant. If so, returns the constant value. Otherwise, returns
80/// std::nullopt.
81static std::optional<Value> isaInlinedFillOp(GenericOp op) {
82 if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
83 op.getNumDpsInputs() != 0)
84 return std::nullopt;
85
86 // Init should not be referenced.
87 if (op.payloadUsesValueFromOperand(opOperand: op.getDpsInitOperand(i: 0)))
88 return std::nullopt;
89
90 Block *body = op.getBody();
91 if (body->getOperations().size() != 1)
92 return std::nullopt;
93
94 auto yieldOp = dyn_cast<linalg::YieldOp>(Val&: body->back());
95 if (!yieldOp || yieldOp.getNumOperands() != 1)
96 return std::nullopt;
97
98 Value yieldOperand = yieldOp->getOperand(idx: 0);
99 if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
100 !yieldOperand.getDefiningOp<complex::ConstantOp>())
101 return std::nullopt;
102
103 return yieldOperand;
104}
105
106/// Detects if a linalg.generic operation represents an external scalar input.
107/// If so, returns the constant value. Otherwise, returns std::nullopt.
108static std::optional<Value> isaExternalFillOp(GenericOp op) {
109 // Structural.
110 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
111 !op.isSingleYieldOp())
112 return std::nullopt;
113
114 // Input should be referenced and init should not.
115 if (!op.payloadUsesValueFromOperand(opOperand: op.getDpsInputOperand(i: 0)) ||
116 op.payloadUsesValueFromOperand(opOperand: op.getDpsInitOperand(i: 0)))
117 return std::nullopt;
118
119 OpOperand *value = op.getDpsInputOperand(i: 0);
120 if (!op.isScalar(opOperand: value))
121 return std::nullopt;
122 return value->get();
123}
124
125std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
126 if (auto fillVal = isaInlinedFillOp(op))
127 return fillVal;
128 return isaExternalFillOp(op);
129}
130
131//===----------------------------------------------------------------------===//
132// BroadcastOpInterface implementation
133//===----------------------------------------------------------------------===//
134std::optional<SmallVector<int64_t>>
135linalg::isaBroadcastOpInterface(GenericOp op) {
136 // Structural.
137 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
138 !op.isSingleYieldOp())
139 return std::nullopt;
140
141 auto srcTy = op.getDpsInputOperand(i: 0)->get().getType();
142 auto dstTy = op.getDpsInitOperand(i: 0)->get().getType();
143 if (!isa<MemRefType, RankedTensorType>(Val: srcTy) ||
144 !isa<MemRefType, RankedTensorType>(Val: dstTy))
145 return std::nullopt;
146
147 // Check output is identity map. Broadcast could additionally be
148 // employing permutation of indices and that would be expressible
149 // in linalg.generic but is not expressible for named broadcast op.
150 auto dstMap = op.getIndexingMapsArray()[1];
151 if (!dstMap.isIdentity())
152 return std::nullopt;
153
154 SmallVector<int64_t> position;
155 auto srcMap = op.getIndexingMapsArray()[0];
156
157 if (srcMap.getResults().size() >= dstMap.getResults().size())
158 return std::nullopt;
159
160 // Check input map is monotonically increasing DimIds.
161 for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
162 auto expr = llvm::dyn_cast<AffineDimExpr>(Val: srcMap.getResults()[i]);
163 if (!expr)
164 return std::nullopt;
165 int64_t pos = expr.getPosition();
166 if (i > 0 && pos <= position[i - 1])
167 return std::nullopt;
168 position.push_back(Elt: expr.getPosition());
169 }
170
171 SmallVector<int64_t> broadcastedDims;
172 auto numDims = srcMap.getNumDims();
173 // This is quadratic but number of items is generally small.
174 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: numDims)) {
175 if (!llvm::is_contained(Range&: position, Element: dim))
176 broadcastedDims.push_back(Elt: dim);
177 }
178 return broadcastedDims;
179}
180
181//===----------------------------------------------------------------------===//
182// TransposeOpInterface implementation
183//===----------------------------------------------------------------------===//
184std::optional<SmallVector<int64_t>>
185linalg::isaTransposeOpInterface(GenericOp op) {
186 // To specialize as a transpose op, the genericOp must be
187 // all parallel loops, single input, single output, and its body
188 // should be just a yield op, yielding input as output as is (no compute).
189 if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
190 !op.isSingleYieldOp())
191 return std::nullopt;
192
193 auto mapRange = op.getIndexingMapsArray();
194 if (mapRange.size() != 2)
195 return std::nullopt;
196
197 auto mapOfInput = mapRange.front();
198 auto mapOfResult = mapRange.back();
199
200 // linalg.transpose permutes the dimensions of input using this
201 // rule: dim(result, i) = dim(input, permutation[i])
202 if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
203 return std::nullopt;
204
205 SmallVector<int64_t> permutation(mapOfInput.getNumDims());
206 for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
207 auto expr = llvm::cast<AffineDimExpr>(Val: mapOfInput.getResults()[i]);
208 permutation[expr.getPosition()] = i;
209 }
210 return permutation;
211}
212
213//===----------------------------------------------------------------------===//
214// Elementwise Single Unary/Binary-OpInterface implementation
215//===----------------------------------------------------------------------===//
216static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
217 unsigned arity) {
218 // Check all loops are parallel.
219 if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
220 return false;
221
222 // Check there are arity-inputs, 1-output and all are identity-maps.
223 if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
224 !llvm::all_of(Range: op.getIndexingMapsArray(),
225 P: [](AffineMap map) { return map.isIdentity(); }))
226 return false;
227
228 // Init should not be referenced for elementwise operations.
229 if (op.payloadUsesValueFromOperand(opOperand: op.getDpsInitOperand(i: 0)))
230 return false;
231
232 // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
233 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
234 // the body, where the first is the elementwise single op and the second a
235 // yield.
236 Block *body = op.getBody();
237 if (body->getOperations().size() != 2)
238 return false;
239
240 Operation *oper = &body->front();
241 if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
242 return false;
243
244 auto yieldOp = dyn_cast<linalg::YieldOp>(Val&: body->back());
245 if (!yieldOp || yieldOp.getNumOperands() != 1 ||
246 yieldOp->getOperand(idx: 0).getDefiningOp() != oper)
247 return false;
248 return true;
249}
250
251bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
252 // All basic elemwise checks.
253 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, arity: 1))
254 return false;
255
256 // Check input is actully used.
257 if (!op.payloadUsesValueFromOperand(opOperand: op.getDpsInputOperand(i: 0)))
258 return false;
259 return true;
260}
261
262bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
263 if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, arity: 2))
264 return false;
265
266 // Check both inputs are used (elementwise).
267 OpOperand *inputOpOperand0 = op.getDpsInputOperand(i: 0);
268 OpOperand *inputOpOperand1 = op.getDpsInputOperand(i: 1);
269 if (!op.payloadUsesValueFromOperand(opOperand: inputOpOperand0) ||
270 !op.payloadUsesValueFromOperand(opOperand: inputOpOperand1))
271 return false;
272 return true;
273}
274
275//===----------------------------------------------------------------------===//
276// ContractionOpInterface implementation
277//===----------------------------------------------------------------------===//
278
279/// If the value is defined by a chain of unary side effect-free, go up the
280/// use-def chain until the first value that isn't defined by such an op.
281// TODO: relax to multi-operands with constants, which are technically unary ops
282// as needed (e.g. add5).
283static Value getSourceSkipUnary(Value value) {
284 Operation *op = value.getDefiningOp();
285 while (op && op->getNumOperands() == 1) {
286 auto iface = dyn_cast<MemoryEffectOpInterface>(Val: op);
287 if (!iface || !iface.hasNoEffect())
288 break;
289 value = op->getOperand(idx: 0);
290 op = value.getDefiningOp();
291 }
292 return value;
293}
294
295bool mlir::linalg::detail::isContractionBody(
296 Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
297 llvm::raw_ostream &errs) {
298 if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
299 errs << "no terminator in the block";
300 return false;
301 }
302
303 if (block.getNumArguments() != 3) {
304 errs << "expected block with 3 arguments";
305 return false;
306 }
307
308 Operation *terminator = block.getTerminator();
309 if (terminator->getNumOperands() != 1) {
310 errs << "expected terminator with 1 operand";
311 return false;
312 }
313
314 Value yielded = getSourceSkipUnary(value: terminator->getOperand(idx: 0));
315 Operation *reductionOp = yielded.getDefiningOp();
316 if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
317 errs << "expected reduction op to be binary";
318 return false;
319 }
320
321 Value reductionLHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 0));
322 Value reductionRHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 1));
323
324 if (reductionLHS != block.getArgument(i: 2) &&
325 reductionRHS != block.getArgument(i: 2)) {
326 errs << "expected reduction to take block argument #2 as one of the "
327 "operands (modulo unary casts)";
328 return false;
329 }
330
331 Value contributed = getSourceSkipUnary(
332 value: isa<BlockArgument>(Val: reductionLHS) ? reductionRHS : reductionLHS);
333 Operation *elementwiseOp = contributed.getDefiningOp();
334 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
335 elementwiseOp->getNumOperands() != 2) {
336 errs << "expected elementwise op to be binary";
337 return false;
338 }
339
340 if (!isaPair(elementwiseOp, reductionOp)) {
341 errs << "expected reduction/elementwise op kind not satisfied";
342 return false;
343 }
344
345 Value elementwiseLHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 0));
346 Value elementwiseRHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 1));
347 if ((elementwiseLHS == block.getArgument(i: 0) &&
348 elementwiseRHS == block.getArgument(i: 1)) ||
349 (elementwiseLHS == block.getArgument(i: 1) &&
350 elementwiseRHS == block.getArgument(i: 0))) {
351 return true;
352 }
353
354 errs << "expected elementwise op to apply to block arguments (modulo unary "
355 "casts)";
356 return false;
357}
358
359/// Returns true if the two operations are of the kinds specified by a pair of
360/// consecutive template arguments.
361template <typename AddOpTy, typename MulOpTy, typename... Args>
362static bool isPairTemplateImpl(Operation *add, Operation *mul) {
363 static_assert(sizeof...(Args) % 2 == 0,
364 "expected an even number of template arguments");
365 if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
366 return true;
367
368 if constexpr (sizeof...(Args) > 0)
369 return isPairTemplateImpl<Args...>(add, mul);
370 else
371 return false;
372}
373
374/// Returns true if the block is a body of a contraction with the kinds of
375/// operations given pairwise by template arguments.
376template <typename... Args>
377static bool isContractionBody(Block &block) {
378 return linalg::detail::isContractionBody(block, isaPair: &isPairTemplateImpl<Args...>);
379}
380
381/// Given an `indexingMap` and its corresponding `iterators`, returns
382/// the positions of the iterators of type `iter` that are indexed by
383/// the `indexingMap` as a permutation. This is useful to infer various
384/// subcomputations on a `LinalgOp`. This is performed by looking up
385/// each result in the `indexingMap` and determining whether:
386/// - It is a single AffineDimExpr.
387/// - It is the only result involving this AffineDimExpr.
388static llvm::SmallDenseSet<int64_t>
389findPermutationsIndexingOperand(AffineMap indexingMap,
390 ArrayRef<utils::IteratorType> iterators,
391 utils::IteratorType iter) {
392 assert(iterators.size() == indexingMap.getNumDims());
393 llvm::SmallDenseSet<int64_t> res;
394 for (AffineExpr e : indexingMap.getResults()) {
395 if (auto d = dyn_cast<AffineDimExpr>(Val&: e)) {
396 if (iterators[d.getPosition()] == iter &&
397 llvm::count_if(Range: indexingMap.getResults(), P: [d](AffineExpr e) {
398 return e.isFunctionOfDim(position: d.getPosition());
399 }) == 1)
400 res.insert(V: d.getPosition());
401 }
402 }
403 return res;
404}
405
namespace {
// Short file-local aliases for the two iterator kinds, used by the
// contraction/convolution dimension-inference helpers below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
410
411/// Infer the iterator types from the init affine map. This looks at which dims
412/// are present in the map results, and returns an iterator types array with
413/// parallel types for dims that are present, and reduction types for dims that
414/// are not present.
415static FailureOr<SmallVector<utils::IteratorType>>
416inferIteratorsFromOutMap(AffineMap map) {
417 if (!map.isProjectedPermutation())
418 return failure();
419 SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
420 for (auto expr : map.getResults())
421 if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr))
422 iterators[dim.getPosition()] = par;
423 return iterators;
424}
425
426/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
427/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
428/// 1. The m dimension is involved in an outer-product along LHS
429/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
430/// 2. The n dimension is involved in an outer-product along RHS
431/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
432/// 3. The k dimension appears as a permutation on LHS and RHS.
433/// 4. m, n and k appear only once in any given indexing.
434/// 5. Optional batch dimensions that appear in all operands are captured.
435/// This allows e.g. detecting that some contraction is embedded within
436/// `linalgOp` with some orthogonal heuristic.
437static FailureOr<ContractionDimensions>
438inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
439 ArrayRef<utils::IteratorType> iterators) {
440 llvm::SmallDenseSet<int64_t> a =
441 findPermutationsIndexingOperand(indexingMap: indexingMaps[0], iterators, iter: par);
442 llvm::SmallDenseSet<int64_t> b =
443 findPermutationsIndexingOperand(indexingMap: indexingMaps[1], iterators, iter: par);
444 llvm::SmallDenseSet<int64_t> c =
445 findPermutationsIndexingOperand(indexingMap: indexingMaps[2], iterators, iter: par);
446
447 // A & C - B are the iterators involved in an outer-product along A (the LHS).
448 llvm::SmallDenseSet<int64_t> ac = a;
449 llvm::set_intersect(S1&: ac, S2: c);
450 llvm::set_subtract(S1&: ac, S2: b);
451 // B & C - A are the iterators involved in an outer-product along B (the RHS).
452 llvm::SmallDenseSet<int64_t> bc = b;
453 llvm::set_intersect(S1&: bc, S2: c);
454 llvm::set_subtract(S1&: bc, S2: a);
455 // A & B & C are the "batch" dimensions.
456 llvm::SmallDenseSet<int64_t> batches = a;
457 llvm::set_intersect(S1&: batches, S2: b);
458 llvm::set_intersect(S1&: batches, S2: c);
459
460 // A & B red are the reduction dimensions.
461 llvm::SmallDenseSet<int64_t> ra =
462 findPermutationsIndexingOperand(indexingMap: indexingMaps[0], iterators, iter: red);
463 llvm::SmallDenseSet<int64_t> rb =
464 findPermutationsIndexingOperand(indexingMap: indexingMaps[1], iterators, iter: red);
465 llvm::set_intersect(S1&: ra, S2: rb);
466
467 // Return each set in sorted order.
468 ContractionDimensions dimensions{
469 .batch: SmallVector<unsigned, 2>(batches.begin(), batches.end()),
470 .m: SmallVector<unsigned, 2>(ac.begin(), ac.end()),
471 .n: SmallVector<unsigned, 2>(bc.begin(), bc.end()),
472 .k: SmallVector<unsigned, 2>(ra.begin(), ra.end())};
473 llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end());
474 llvm::sort(Start: dimensions.m.begin(), End: dimensions.m.end());
475 llvm::sort(Start: dimensions.n.begin(), End: dimensions.n.end());
476 llvm::sort(Start: dimensions.k.begin(), End: dimensions.k.end());
477 return dimensions;
478}
479
480FailureOr<ContractionDimensions>
481mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
482 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
483 return failure();
484 return inferContractionDimsImpl(indexingMaps: linalgOp.getIndexingMapsArray(),
485 iterators: linalgOp.getIteratorTypesArray());
486}
487
488FailureOr<ContractionDimensions>
489mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
490 if (indexingMaps.size() != 3)
491 return failure();
492 auto iterators = inferIteratorsFromOutMap(map: indexingMaps[2]);
493 if (failed(Result: iterators))
494 return failure();
495 return inferContractionDimsImpl(indexingMaps, iterators: iterators.value());
496}
497
namespace mlir::linalg::detail {
/// Positive match (Success) or the first structural check that failed while
/// matching an op as a contraction in `isContractionInterfaceImpl`.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
508
509mlir::linalg::detail::MatchContractionResult
510mlir::linalg::detail::isContractionInterfaceImpl(
511 Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
512 auto linalgOp = dyn_cast<linalg::LinalgOp>(Val: op);
513 if (!linalgOp)
514 return MatchContractionResult::NotLinalgOp;
515 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
516 return MatchContractionResult::WrongNumOperands;
517 auto mapRange = linalgOp.getIndexingMapsArray();
518 if (linalgOp.getNumReductionLoops() == 0)
519 return MatchContractionResult::NoReduction;
520 if (llvm::any_of(Range&: mapRange,
521 P: [](AffineMap m) { return !m.isProjectedPermutation(); }))
522 return MatchContractionResult::NotProjectedPermutations;
523 // TODO: more fields than add/mul.
524 // clang-format off
525 if (!::isContractionBody<
526 arith::MulFOp, arith::AddFOp,
527 arith::MulIOp, arith::AddIOp,
528 complex::MulOp, complex::AddOp,
529 arith::AndIOp, arith::OrIOp>(
530 block&: *linalgOp.getBlock())) {
531 return MatchContractionResult::NotAddMul;
532 }
533 // clang-format on
534
535 if (dimensions) {
536 FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
537 assert(succeeded(res) && "unexpected failure to infer contraction dims");
538 *dimensions = *res;
539 }
540 return MatchContractionResult::Success;
541}
542
543StringRef
544mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
545 switch (res) {
546 case MatchContractionResult::NotLinalgOp:
547 return "expected a LinalgOp";
548 case MatchContractionResult::WrongNumOperands:
549 return "expected op with 2 inputs and 1 output";
550 case MatchContractionResult::NoReduction:
551 return "expected at least 1 reduction";
552 case MatchContractionResult::NotProjectedPermutations:
553 return "expected indexing maps to be projected permutations";
554 case MatchContractionResult::NotAddMul:
555 return "expected add/mul op in the body";
556 case MatchContractionResult::Success:
557 return "";
558 }
559 llvm_unreachable("unhandled MatchContractionResult case");
560}
561
562bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
563 if (!linalgOp)
564 return false;
565 Operation *op = linalgOp.getOperation();
566 return isa<ContractionOpInterface>(Val: op) ||
567 (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
568 mlir::linalg::detail::MatchContractionResult::Success);
569}
570
571/// Verify that a LinalgOp `op` is a contraction.
572/// A Linalg contraction is defined in general terms:
573/// 1. Has 2 input and 1 output shapes.
574/// 2. Has at least one reduction dimension.
575/// 3. Has only projected permutation indexing maps.
576/// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
577/// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
578/// operations that may change the type (e.g. for mixed-precision).
579/// As a consequence, when vectorization of such an op occurs, the only special
580/// behavior is that the (unique) MulOpType is vectorized into a
581/// `vector.contract`. All other ops are handled in a generic fashion.
582/// In the future, we may wish to allow more input arguments and elementwise and
583/// constant operations that do not involve the reduction dimension(s).
584LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
585 auto res = isContractionInterfaceImpl(op);
586 if (res != MatchContractionResult::Success)
587 return op->emitError(message: getMatchContractionMessage(res));
588 return success();
589}
590
591//===----------------------------------------------------------------------===//
592// ConvolutionOpInterface implementation
593//===----------------------------------------------------------------------===//
594
595/// Of the given two expressions returns one that is of type T (`lhs` gets
596/// preference over `rhs`)
597template <typename T>
598static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
599 return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
600}
601
602namespace {
603/// Walk the indexing expressions for input of a convolution operation to verify
604/// its of the right form, either
605/// - AffineDimExpr
606/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
607/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
608///
609/// classifies the AffineDimExpr as convolved dimensions or unconvolved
610/// dimensions and verifies each dimension occurs only once.
611struct ConvAccessExprWalker
612 : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
613 // Stores dimensions used in expressions of the above form.
614 llvm::SmallDenseSet<int64_t> convolvedDims;
615 // Stores the dual mapping between LHS and RHS of convolution exprs.
616 llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
617 // Stores single use dimensions used by an AffineDimExpr.
618 llvm::SmallDenseSet<int64_t> unConvolvedDims;
619 // Stores a mapping from convolved dims to their coefficient.
620 llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
621
622 // Removes dims with multiple uses in the source input map from dimension
623 // sets tracked by this walker.
624 void clearMultiUseDims(AffineMap map) {
625 for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
626 if (llvm::count_if(Range: map.getResults(), P: [dimPos](AffineExpr e) {
627 return e.isFunctionOfDim(position: dimPos);
628 }) > 1) {
629 convolvedDims.erase(V: dimPos);
630 unConvolvedDims.erase(V: dimPos);
631 // If a duplicate dim is marked as convolved, the pair of the duplicate
632 // dim must be removed from the map as well.
633 auto it = convolvedDimMapping.find(Val: dimPos);
634 if (it != convolvedDimMapping.end()) {
635 int64_t pairedDim = it->second;
636 convolvedDims.erase(V: pairedDim);
637 unConvolvedDims.erase(V: pairedDim);
638 strideAndDilationMapping.erase(Val: pairedDim);
639 convolvedDimMapping.erase(Val: dimPos);
640 convolvedDimMapping.erase(Val: pairedDim);
641 }
642 }
643 }
644 }
645
646 LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
647 unsigned position = dimExpr.getPosition();
648 if (unConvolvedDims.count(V: position) || convolvedDims.count(V: position)) {
649 return failure();
650 }
651 unConvolvedDims.insert(V: position);
652 return success();
653 }
654
655 LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
656
657 LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
658
659 LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
660 // In pre-order visit, top level op has to be an add op.
661 if (binaryExpr.getKind() != AffineExprKind::Add)
662 return failure();
663 auto lhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getLHS());
664 auto rhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getRHS());
665 if (failed(Result: lhsDimPos) || failed(Result: rhsDimPos))
666 return failure();
667 convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
668 convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
669 return success();
670 }
671
672 FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
673 if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
674 int64_t dim = dimExpr.getPosition();
675 if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim))
676 return failure();
677 // Stride/dilation for this dim is implicitly 1.
678 strideAndDilationMapping[dim] =
679 getAffineConstantExpr(constant: 1, context: expr.getContext());
680 convolvedDims.insert(V: dim);
681 return dim;
682 }
683 if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) {
684 if (symbolMulExpr.getKind() != AffineExprKind::Mul)
685 return failure();
686 auto lhsExpr = symbolMulExpr.getLHS();
687 auto rhsExpr = symbolMulExpr.getRHS();
688 // Check for symbol expression.
689 AffineExpr mulExpr =
690 getAffineExprOfType<AffineSymbolExpr>(lhs: lhsExpr, rhs: rhsExpr);
691 // If there was no symbol expr, check for constant expression.
692 if (!mulExpr) {
693 mulExpr = getAffineExprOfType<AffineConstantExpr>(lhs: lhsExpr, rhs: rhsExpr);
694 }
695 auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhs: lhsExpr, rhs: rhsExpr);
696 if (!mulExpr || !dimExpr)
697 return failure();
698 int64_t dim = dimExpr.getPosition();
699 if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim))
700 return failure();
701 strideAndDilationMapping[dim] = mulExpr;
702 convolvedDims.insert(V: dim);
703 return dim;
704 }
705 return failure();
706 }
707};
708} // namespace
709
710static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
711 assert(map.isProjectedPermutation() &&
712 "expected map to have projected permutations");
713 llvm::SmallDenseSet<int64_t> preservedDims;
714 for (auto expr : map.getResults())
715 preservedDims.insert(V: cast<AffineDimExpr>(Val&: expr).getPosition());
716 return preservedDims;
717}
718
719static SmallVector<int64_t, 2>
720getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
721 SmallVector<int64_t, 2> vals;
722 for (auto e : exprs) {
723 auto constantExpr = dyn_cast<AffineConstantExpr>(Val&: e);
724 assert(constantExpr && "Found non-constant stride/dilation");
725 vals.push_back(Elt: constantExpr.getValue());
726 }
727 return vals;
728}
729
730/// Classifies dimensions in the `linalgOp` used by a convolution
731/// subcomputation, as captured by `inputExprWalker`. If
732/// `allowEmptyConvolvedDims` is not set this this will fail if there is not
733/// at least convolved dimension pair (output image + filter loop). Convolution
734/// dimensions are specified in sorted order, and strides match the order of
735/// the filter loop dimensions, while the dilations match the order of the
736/// output image dimensions.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  // Convention: DPS input #1 is the filter and DPS init #0 is the output;
  // `inputExprWalker` has already classified the dims of the input operand
  // into convolved / unconvolved sets.
  auto filterMap =
      linalgOp.getMatchingIndexingMap(opOperand: linalgOp.getDpsInputOperand(i: 1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(opOperand: linalgOp.getDpsInitOperand(i: 0));
  // Parallel loop dims accessed as plain permutations by the filter/output
  // indexing maps.
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      indexingMap: filterMap, iterators: linalgOp.getIteratorTypesArray(), iter: par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      indexingMap: outputMap, iterators: linalgOp.getIteratorTypesArray(), iter: par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(S1&: batch, S2: outputDims);
  llvm::set_subtract(S1&: batch, S2: filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(S1&: oi, S2: outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(S1&: oc, S2: outputDims);
  llvm::set_subtract(S1&: oc, S2: inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(S1&: depth, S2: outputDims);
  llvm::set_intersect(S1&: depth, S2: inputExprWalker.unConvolvedDims);

  // Reduction loop dims accessed as plain permutations by the filter map.
  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(indexingMap: filterMap,
                                      iterators: linalgOp.getIteratorTypesArray(), iter: red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(S1&: fl, S2: filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(S1&: ic, S2: filterReducedDims);

  // Without at least one output-image dim there is no convolution, unless the
  // caller explicitly allows empty convolved dims.
  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      .batch: SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      .outputImage: SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      .outputChannel: SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      .filterLoop: SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      .inputChannel: SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      .depth: SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  // SmallDenseSet iteration order is unspecified; sort for determinism.
  llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end());
  llvm::sort(Start: dimensions.outputImage.begin(), End: dimensions.outputImage.end());
  llvm::sort(Start: dimensions.outputChannel.begin(), End: dimensions.outputChannel.end());
  llvm::sort(Start: dimensions.filterLoop.begin(), End: dimensions.filterLoop.end());
  llvm::sort(Start: dimensions.inputChannel.begin(), End: dimensions.inputChannel.end());
  llvm::sort(Start: dimensions.depth.begin(), End: dimensions.depth.end());

  // Use the op carried strides/dilations attribute if present. Otherwise
  // recover them from the affine exprs gathered while walking the input map:
  // strides are keyed by output image dims, dilations by filter loop dims.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>(name: "strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(exprs: strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(Range: nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>(name: "dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(exprs: dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(Range: nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}
824
825/// Find at least 1 parallel (output_image) and reduction (filter_loop)
826/// dimension candidates that form a convolution subcomputation within
827/// `linalgOp`. The LHS is assumed to be the convolution input while the
828/// RHS is assumed as the filter.
829/// These dimensions are such that:
830/// 1. Optional batch dimensions that appear in the input and filter.
831/// 2. The output_image dimension is involved in a cross-correlation along LHS
832/// (i.e. it is a permutation on RES and LHS and has an associated
833/// filter_loop in RHS).
834/// 3. Optional output_channel dimension is involved in an outer-product along
835/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
836/// LHS).
837/// 4. Optional input_channel dimension appears as a permutation on LHS and
838/// RHS.
839/// 5. The filter_loop dimension appears as a permutation on the RHS and
840/// represents the shape of the kernel cross-correlated along a
841/// corresponding output_image dim.
842/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
843/// 7. All dimensions appear only once in any given indexing map.
844/// This allows e.g. detecting that some convolution is embedded within
845/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
848/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
849FailureOr<ConvolutionDimensions>
850mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
851 if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
852 return failure();
853
854 auto indexingMaps = linalgOp.getIndexingMapsArray();
855
856 // Check the input indexing map has the right form.
857 ConvAccessExprWalker inputExprWalker;
858 for (AffineExpr expr : indexingMaps[0].getResults())
859 (void)inputExprWalker.visit(expr);
860 inputExprWalker.clearMultiUseDims(map: indexingMaps[0]);
861
862 return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
863 /*allowEmptyConvolvedDims=*/false);
864}
865
namespace mlir::linalg::detail {
/// Outcome of structurally matching an op against the convolution interface.
/// Every value other than `Success` names the first check that failed; see
/// `getMatchConvolutionMessage` for the corresponding diagnostics.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
879
// Structurally match `op` as a convolution: check operand counts, classify
// every input-map access as convolved/unconvolved, require filter/output maps
// to be projected permutations, and verify that every loop dim plays exactly
// one convolution role with the expected iterator type. If `dimensions` is
// non-null, the detailed dim classification is also reported.
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions,
    bool allowEmptyConvolvedDims) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(Val: op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(Range: indexingMaps[0].getResults(),
                   P: [&inputExprWalker](AffineExpr expr) {
                     return failed(Result: inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  // Loop dims that actually appear in the output (resp. filter) map.
  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(map: indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(map: indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    // Safe cast: the output map was checked to be a projected permutation.
    int64_t outputDim = cast<AffineDimExpr>(Val&: outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(V: outputDim) &&
        !filterDims.count(V: outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(V: outputDim) &&
        !filterDims.count(V: outputDim)) {
      // Output image Loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(V: outputDim) &&
        !inputExprWalker.unConvolvedDims.count(V: outputDim) &&
        filterDims.count(V: outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(V: outputDim) &&
        filterDims.count(V: outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(V: outputDim);
      continue;
    }
    // No convolution role fits this output dim.
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(Val&: filterExpr).getPosition();
    if (outputDims.count(V: filterDim) &&
        !inputExprWalker.unConvolvedDims.count(V: filterDim) &&
        !inputExprWalker.convolvedDims.count(V: filterDim)) {
      // Output channel dimension. Already validated by the output loop above.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(V: filterDim) &&
        !outputDims.count(V: filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      // A dim already claimed by the output classification cannot also be a
      // filter loop.
      if (allLoopDims.count(V: filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(V: filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(V: filterDim) &&
        !outputDims.count(V: filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(V: filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(V: filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(V: filterDim) &&
        outputDims.count(V: filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
    return MatchConvolutionResult::EmptyConvolvedDims;

  // On success, optionally report the detailed dimension classification. The
  // structural checks above mirror the inference preconditions, so inference
  // cannot fail here.
  if (dimensions) {
    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}
1013
1014StringRef
1015mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
1016 switch (res) {
1017 case MatchConvolutionResult::NotLinalgOp:
1018 return "expected a LinalgOp";
1019 case MatchConvolutionResult::WrongNumOperands:
1020 return "expected op with 2 inputs and 1 output";
1021 case MatchConvolutionResult::WrongInputIndexingMap:
1022 return "unexpected input index map for convolutions";
1023 case MatchConvolutionResult::NotProjectedPermutations:
1024 return "expected output/filter indexing maps to be projected permutations";
1025 case MatchConvolutionResult::NonConvolutionLoop:
1026 return "unexpected loop dimension for convolution op";
1027 case MatchConvolutionResult::OutputDimsNotParallel:
1028 return "expected all iterators used to access outputs to be parallel";
1029 case MatchConvolutionResult::NonOutputDimNotReduction:
1030 return "expected all iterators not used to access outputs to be reduction";
1031 case MatchConvolutionResult::EmptyConvolvedDims:
1032 return "expected convolved dim to be non-empty";
1033 case MatchConvolutionResult::Success:
1034 return "";
1035 }
1036 llvm_unreachable("unhandled MatchConvolutionResult case");
1037}
1038
1039bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
1040 bool allowEmptyConvolvedDims) {
1041 return linalg::detail::isConvolutionInterfaceImpl(
1042 op: linalgOp.getOperation(), dimensions: nullptr, allowEmptyConvolvedDims) ==
1043 linalg::detail::MatchConvolutionResult::Success;
1044}
1045
1046LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
1047 MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
1048 if (res != MatchConvolutionResult::Success)
1049 return op->emitError(message: getMatchConvolutionMessage(res));
1050 return success();
1051}
1052
1053//===----------------------------------------------------------------------===//
1054// FillOpInterface implementation
1055//===----------------------------------------------------------------------===//
1056
/// Outcome of structurally matching an op against the fill interface; every
/// value other than `Success` names the check that failed.
enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,      // The op does not implement LinalgOp.
  WrongNumOperands, // Fill requires exactly 1 input and 1 init.
  NotScalarInput    // The single input must be a scalar operand.
};
1063
1064static MatchFillResult isFillInterfaceImpl(Operation *op) {
1065 auto linalgOp = dyn_cast<linalg::LinalgOp>(Val: op);
1066 if (!linalgOp)
1067 return MatchFillResult::NotLinalgOp;
1068 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1069 return MatchFillResult::WrongNumOperands;
1070
1071 OpOperand *value = linalgOp.getDpsInputOperand(i: 0);
1072 if (!linalgOp.isScalar(opOperand: value))
1073 return MatchFillResult::NotScalarInput;
1074
1075 return MatchFillResult::Success;
1076}
1077
1078LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
1079 auto res = isFillInterfaceImpl(op);
1080 if (res == MatchFillResult::NotLinalgOp)
1081 return op->emitError(message: "expected a LinalgOp");
1082 if (res == MatchFillResult::WrongNumOperands)
1083 return op->emitError(message: "expected op with 1 input and 1 output");
1084 if (res == MatchFillResult::NotScalarInput)
1085 return op->emitError(message: "expected op with scalar input");
1086
1087 return success();
1088}
1089
1090//===----------------------------------------------------------------------===//
1091// StructuredOpInterface implementation
1092//===----------------------------------------------------------------------===//
1093
1094SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1095 Location loc) {
1096 SmallVector<OpFoldResult> res;
1097 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1098 for (int64_t i = 0, e = getRank(opOperand: &opOperand); i < e; ++i)
1099 res.push_back(Elt: createFoldedDimOp(b, loc, val: opOperand.get(), dim: i));
1100 }
1101 return res;
1102}
1103
1104SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1105 SmallVector<int64_t, 4> res;
1106 assert(!hasDynamicShape() && "expected operands to have static shapes");
1107 for (OpOperand &opOperand : getOperation()->getOpOperands())
1108 llvm::append_range(C&: res, R: getShape(opOperand: &opOperand));
1109 return res;
1110}
1111
/// Create one [0, size, 1) range per loop dimension, where `size` is taken
/// from the first operand dimension that the loop dim indexes directly.
SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  // One size per operand dimension, in the same order as the map's results.
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(Val&: result)) {
      // A loop dim may index several operand dims; keep the first range found.
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] =
          Range{.offset: b.getIndexAttr(value: 0), .size: viewSizes[idx], .stride: b.getIndexAttr(value: 1)};
    }
  }
  return res;
}
1128
/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  /// `positions` is a bitvector indexed by dim position; a set bit marks a
  /// dim whose presence in the visited expression should be reported.
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  /// A binary expression uses a tracked dim if either side does.
  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(expr: binaryOpExpr.getLHS()) || visit(expr: binaryOpExpr.getRHS());
  }

  /// A dim expression matches iff its position is tracked.
  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(Idx: dimExpr.getPosition());
  }

  /// Constants reference no dims.
  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  /// Symbols are distinct from dims and never match.
  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};
1151
1152static std::pair<int64_t, int64_t>
1153getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
1154 int64_t inputRankSum = 0;
1155 int64_t outputRankSum = 0;
1156 for (OpOperand *input : op.getDpsInputOperands())
1157 inputRankSum += op.getRank(opOperand: input);
1158 for (OpOperand &output : op.getDpsInitsMutable())
1159 outputRankSum += op.getRank(opOperand: &output);
1160 return {inputRankSum, inputRankSum + outputRankSum};
1161}
1162
/// Reify the shapes of this op's results in terms of its operands: each
/// result dim is either a static IntegerAttr, a dim op on the init operand
/// (when the shape expression depends on an output dim), or an affine-apply
/// over the input operand dims.
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(op&: *this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      start: resultShapesSubMapPos.first,
      length: resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(map: getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(I: resultShapesSubMapPos.first, E: resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  // Pre-compute every result dim as a folded affine apply over operand dims.
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          b&: rewriter, loc, map: resultShapesFromInputShapesMap,
          operands: createFlatListOfOperandDims(b, loc));
  // `pos` walks the flattened (init, dim) result positions in lockstep with
  // `shapeExprs` / `allResultDimValues`.
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(Begin: 0, End: getRank(opOperand: &opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(Val: opOperand.get().getType());
      if (!shapedType.isDynamicDim(idx: dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(Elt: b.getIndexAttr(value: shapedType.getDimSize(idx: dim)));
      } else {
        // Dynamic dim: Return Value. When the shape expression references an
        // output dim, fall back to a dim op on the init operand itself.
        OpFoldResult ofr = checkDimExpr.visit(expr: shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, val: opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(Elt: getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(Args: std::move(shapes));
  }
  return success();
}
1222
/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(Val&: *this->getOperation());
  // Operands that are not DPS inputs use their operand number directly.
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`: shift the operand number by the gap between the number of
  // DPS inputs and the first init operand index.
  return cast<DestinationStyleOpInterface>(Val&: *this->getOperation())
             .getNumDpsInputs() +
         operandNumber - start;
}
1238
/// Verifier shared by all structured ops: checks operand semantics, indexing
/// map consistency, region shape, and block-argument/operand type agreement.
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(Val: op);
  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError(message: "expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(Result: linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // Delayed calling of IndexingMapOpInterface::verifyImpl.
  if (failed(Result: cast<IndexingMapOpInterface>(Val: op).verifyImpl()))
    return failure();

  // Each indexing map's domain must match the op's iteration space.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError(message: "expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";
  }
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(res&: redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError(message: "expected the shape-to-loops map to be non-null");

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(C&: linalgOp->getRegion(index: 0)))
    return op->emitOpError(message: "expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(index: 0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError(message: "expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  // Each bbarg's type must match the element type (for shaped operands) or
  // the self type (for scalars) of its corresponding operand.
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(Val: elementType))
      elementType = getElementTypeOrSelf(type: opOperand->get().getType());
    Type argType = block.getArgument(i: opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError(message: "expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}
1304

source code of mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp