1 | //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
14 | #include "mlir/Dialect/Complex/IR/Complex.h" |
15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
16 | #include "mlir/IR/AffineExpr.h" |
17 | #include "mlir/IR/AffineExprVisitor.h" |
18 | #include "mlir/IR/AffineMap.h" |
19 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
20 | #include "mlir/IR/MLIRContext.h" |
21 | #include "mlir/IR/TypeUtilities.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/ADT/SetOperations.h" |
24 | #include "llvm/ADT/SmallBitVector.h" |
25 | #include "llvm/ADT/SmallVector.h" |
26 | #include "llvm/Support/Casting.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | #include <algorithm> |
29 | #include <numeric> |
30 | #include <optional> |
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::linalg; |
34 | |
35 | /// Include the definitions of the copy operation interface. |
36 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // Interface utility functions |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | bool linalg::detail::canOpOperandsBeDroppedImpl( |
43 | linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) { |
44 | SmallVector<AffineMap> indexingMaps; |
45 | for (auto &opOperand : linalgOp->getOpOperands()) { |
46 | if (llvm::is_contained(droppedOperands, &opOperand)) |
47 | continue; |
48 | indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand)); |
49 | } |
50 | if (indexingMaps.empty()) { |
51 | // If there are no indexing maps, the operand can only be dropped |
52 | // if the op has no loops. |
53 | return linalgOp.getNumLoops() == 0; |
54 | } |
55 | return inversePermutation(concatAffineMaps( |
56 | indexingMaps, linalgOp.getContext())) != AffineMap(); |
57 | } |
58 | |
59 | //===----------------------------------------------------------------------===// |
60 | // CopyOpInterface implementation |
61 | //===----------------------------------------------------------------------===// |
62 | |
63 | bool linalg::isaCopyOpInterface(LinalgOp op) { |
64 | // Check all loops are parallel and linalgOp is single input and output. |
65 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput()) |
66 | return false; |
67 | |
68 | auto mapRange = op.getIndexingMapsArray(); |
69 | if (mapRange.size() != 2 || !mapRange.front().isIdentity() || |
70 | !mapRange.back().isIdentity()) { |
71 | return false; |
72 | } |
73 | // Region. |
74 | return llvm::hasSingleElement(op.getBlock()->getOperations()); |
75 | } |
76 | |
77 | //===----------------------------------------------------------------------===// |
78 | // FillOpInterface implementation |
79 | //===----------------------------------------------------------------------===// |
80 | std::optional<Value> linalg::isaFillOpInterface(GenericOp op) { |
81 | // Structural. |
82 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput() || |
83 | !op.isSingleYieldOp()) |
84 | return std::nullopt; |
85 | |
86 | // Input should be referenced and init should not. |
87 | if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) || |
88 | op.payloadUsesValueFromOperand(op.getDpsInitOperand(0))) |
89 | return std::nullopt; |
90 | |
91 | OpOperand *value = op.getDpsInputOperand(0); |
92 | if (!op.isScalar(value)) |
93 | return std::nullopt; |
94 | return value->get(); |
95 | } |
96 | |
97 | //===----------------------------------------------------------------------===// |
98 | // BroadcastOpInterface implementation |
99 | //===----------------------------------------------------------------------===// |
100 | std::optional<SmallVector<int64_t>> |
101 | linalg::isaBroadcastOpInterface(GenericOp op) { |
102 | // Structural. |
103 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput() || |
104 | !op.isSingleYieldOp()) |
105 | return std::nullopt; |
106 | |
107 | auto srcTy = op.getDpsInputOperand(0)->get().getType(); |
108 | auto dstTy = op.getDpsInitOperand(0)->get().getType(); |
109 | if (!isa<MemRefType, RankedTensorType>(srcTy) || |
110 | !isa<MemRefType, RankedTensorType>(dstTy)) |
111 | return std::nullopt; |
112 | |
113 | // Check output is identity map. Broadcast could additionally be |
114 | // employing permutation of indices and that would be expressible |
115 | // in linalg.generic but is not expressible for named broadcast op. |
116 | auto dstMap = op.getIndexingMapsArray()[1]; |
117 | if (!dstMap.isIdentity()) |
118 | return std::nullopt; |
119 | |
120 | SmallVector<int64_t> position; |
121 | auto srcMap = op.getIndexingMapsArray()[0]; |
122 | |
123 | if (srcMap.getResults().size() >= dstMap.getResults().size()) |
124 | return std::nullopt; |
125 | |
126 | // Check input map is monotonically increasing DimIds. |
127 | for (unsigned i = 0; i < srcMap.getNumResults(); ++i) { |
128 | auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]); |
129 | if (!expr) |
130 | return std::nullopt; |
131 | int64_t pos = expr.getPosition(); |
132 | if (i > 0 && pos <= position[i - 1]) |
133 | return std::nullopt; |
134 | position.push_back(Elt: expr.getPosition()); |
135 | } |
136 | |
137 | SmallVector<int64_t> broadcastedDims; |
138 | auto numDims = srcMap.getNumDims(); |
139 | // This is quadratic but number of items is generally small. |
140 | for (auto dim : llvm::seq<int64_t>(0, numDims)) { |
141 | if (!llvm::is_contained(position, dim)) |
142 | broadcastedDims.push_back(dim); |
143 | } |
144 | return broadcastedDims; |
145 | } |
146 | |
147 | //===----------------------------------------------------------------------===// |
148 | // TransposeOpInterface implementation |
149 | //===----------------------------------------------------------------------===// |
150 | std::optional<SmallVector<int64_t>> |
151 | linalg::isaTransposeOpInterface(GenericOp op) { |
152 | // To specialize as a transpose op, the genericOp must be |
153 | // all parallel loops, single input, single output, and its body |
154 | // should be just a yield op, yielding input as output as is (no compute). |
155 | if (!op.isAllParallelLoops() || !op.isSingleInputOutput() || |
156 | !op.isSingleYieldOp()) |
157 | return std::nullopt; |
158 | |
159 | auto mapRange = op.getIndexingMapsArray(); |
160 | if (mapRange.size() != 2) |
161 | return std::nullopt; |
162 | |
163 | auto mapOfInput = mapRange.front(); |
164 | auto mapOfResult = mapRange.back(); |
165 | |
166 | // linalg.transpose permutes the dimensions of input using this |
167 | // rule: dim(result, i) = dim(input, permutation[i]) |
168 | if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation()) |
169 | return std::nullopt; |
170 | |
171 | SmallVector<int64_t> permutation(mapOfInput.getNumDims()); |
172 | for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) { |
173 | auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]); |
174 | permutation[expr.getPosition()] = i; |
175 | } |
176 | return permutation; |
177 | } |
178 | |
179 | //===----------------------------------------------------------------------===// |
180 | // Elementwise Single Unary/Binary-OpInterface implementation |
181 | //===----------------------------------------------------------------------===// |
182 | static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, |
183 | unsigned arity) { |
184 | // Check all loops are parallel. |
185 | if (!op.isAllParallelLoops() || op.getNumLoops() < 1) |
186 | return false; |
187 | |
188 | // Check there are arity-inputs, 1-output and all are identity-maps. |
189 | if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 || |
190 | !llvm::all_of(op.getIndexingMapsArray(), |
191 | [](AffineMap map) { return map.isIdentity(); })) |
192 | return false; |
193 | |
194 | // Init should not be referenced for elementwise operations. |
195 | if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0))) |
196 | return false; |
197 | |
198 | // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such |
199 | // as resulting from producer-consumer fusion. Here, we restrict to two ops in |
200 | // the body, where the first is the elementwise single op and the second a |
201 | // yield. |
202 | Block *body = op.getBody(); |
203 | if (body->getOperations().size() != 2) |
204 | return false; |
205 | |
206 | Operation *oper = &body->front(); |
207 | if (oper->getNumOperands() != arity || oper->getNumResults() != 1) |
208 | return false; |
209 | |
210 | auto yieldOp = dyn_cast<linalg::YieldOp>(body->back()); |
211 | if (!yieldOp || yieldOp.getNumOperands() != 1 || |
212 | yieldOp->getOperand(0).getDefiningOp() != oper) |
213 | return false; |
214 | return true; |
215 | } |
216 | |
217 | bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) { |
218 | // All basic elemwise checks. |
219 | if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1)) |
220 | return false; |
221 | |
222 | // Check input is actully used. |
223 | if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0))) |
224 | return false; |
225 | return true; |
226 | } |
227 | |
228 | bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) { |
229 | if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2)) |
230 | return false; |
231 | |
232 | // Check both inputs are used (elementwise). |
233 | OpOperand *inputOpOperand0 = op.getDpsInputOperand(0); |
234 | OpOperand *inputOpOperand1 = op.getDpsInputOperand(1); |
235 | if (!op.payloadUsesValueFromOperand(inputOpOperand0) || |
236 | !op.payloadUsesValueFromOperand(inputOpOperand1)) |
237 | return false; |
238 | return true; |
239 | } |
240 | |
241 | //===----------------------------------------------------------------------===// |
242 | // ContractionOpInterface implementation |
243 | //===----------------------------------------------------------------------===// |
244 | |
245 | /// If the value is defined by a chain of unary side effect-free, go up the |
246 | /// use-def chain until the first value that isn't defined by such an op. |
247 | // TODO: relax to multi-operands with constants, which are technically unary ops |
248 | // as needed (e.g. add5). |
249 | static Value getSourceSkipUnary(Value value) { |
250 | Operation *op = value.getDefiningOp(); |
251 | while (op && op->getNumOperands() == 1) { |
252 | auto iface = dyn_cast<MemoryEffectOpInterface>(op); |
253 | if (!iface || !iface.hasNoEffect()) |
254 | break; |
255 | value = op->getOperand(idx: 0); |
256 | op = value.getDefiningOp(); |
257 | } |
258 | return value; |
259 | } |
260 | |
261 | bool mlir::linalg::detail::isContractionBody( |
262 | Block &block, function_ref<bool(Operation *, Operation *)> isaPair, |
263 | llvm::raw_ostream &errs) { |
264 | if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) { |
265 | errs << "no terminator in the block"; |
266 | return false; |
267 | } |
268 | |
269 | if (block.getNumArguments() != 3) { |
270 | errs << "expected block with 3 arguments"; |
271 | return false; |
272 | } |
273 | |
274 | Operation *terminator = block.getTerminator(); |
275 | if (terminator->getNumOperands() != 1) { |
276 | errs << "expected terminator with 1 operand"; |
277 | return false; |
278 | } |
279 | |
280 | Value yielded = getSourceSkipUnary(value: terminator->getOperand(idx: 0)); |
281 | Operation *reductionOp = yielded.getDefiningOp(); |
282 | if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { |
283 | errs << "expected reduction op to be binary"; |
284 | return false; |
285 | } |
286 | |
287 | Value reductionLHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 0)); |
288 | Value reductionRHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 1)); |
289 | |
290 | if (reductionLHS != block.getArgument(i: 2) && |
291 | reductionRHS != block.getArgument(i: 2)) { |
292 | errs << "expected reduction to take block argument #2 as one of the " |
293 | "operands (modulo unary casts)"; |
294 | return false; |
295 | } |
296 | |
297 | Value contributed = getSourceSkipUnary( |
298 | value: isa<BlockArgument>(Val: reductionLHS) ? reductionRHS : reductionLHS); |
299 | Operation *elementwiseOp = contributed.getDefiningOp(); |
300 | if (!elementwiseOp || elementwiseOp->getNumResults() != 1 || |
301 | elementwiseOp->getNumOperands() != 2) { |
302 | errs << "expected elementwise op to be binary"; |
303 | return false; |
304 | } |
305 | |
306 | if (!isaPair(elementwiseOp, reductionOp)) { |
307 | errs << "expected reduction/elementwise op kind not satisfied"; |
308 | return false; |
309 | } |
310 | |
311 | Value elementwiseLHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 0)); |
312 | Value elementwiseRHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 1)); |
313 | if ((elementwiseLHS == block.getArgument(i: 0) && |
314 | elementwiseRHS == block.getArgument(i: 1)) || |
315 | (elementwiseLHS == block.getArgument(i: 1) && |
316 | elementwiseRHS == block.getArgument(i: 0))) { |
317 | return true; |
318 | } |
319 | |
320 | errs << "expected elementwise op to apply to block arguments (modulo unary " |
321 | "casts)"; |
322 | return false; |
323 | } |
324 | |
325 | /// Returns true if the two operations are of the kinds specified by a pair of |
326 | /// consecutive template arguments. |
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  // Template arguments come in (AddOp, MulOp) pairs; an odd count is a bug at
  // the call site.
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  // Recurse into the remaining (AddOp, MulOp) pairs, if any; otherwise no
  // pair matched.
  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}
339 | |
340 | /// Returns true if the block is a body of a contraction with the kinds of |
341 | /// operations given pairwise by template arguments. |
342 | template <typename... Args> |
343 | static bool isContractionBody(Block &block) { |
344 | return linalg::detail::isContractionBody(block, isaPair: &isPairTemplateImpl<Args...>); |
345 | } |
346 | |
347 | /// Given an `indexingMap` and its corresponding `iterators`, returns |
348 | /// the positions of the iterators of type `iter` that are indexed by |
349 | /// the `indexingMap` as a permutation. This is useful to infer various |
350 | /// subcomputations on a `LinalgOp`. This is performed by looking up |
351 | /// each result in the `indexingMap` and determining whether: |
352 | /// - It is a single AffineDimExpr. |
353 | /// - It is the only result involving this AffineDimExpr. |
354 | static llvm::SmallDenseSet<int64_t> |
355 | findPermutationsIndexingOperand(AffineMap indexingMap, |
356 | ArrayRef<utils::IteratorType> iterators, |
357 | utils::IteratorType iter) { |
358 | assert(iterators.size() == indexingMap.getNumDims()); |
359 | llvm::SmallDenseSet<int64_t> res; |
360 | for (AffineExpr e : indexingMap.getResults()) { |
361 | if (auto d = dyn_cast<AffineDimExpr>(Val&: e)) { |
362 | if (iterators[d.getPosition()] == iter && |
363 | llvm::count_if(Range: indexingMap.getResults(), P: [d](AffineExpr e) { |
364 | return e.isFunctionOfDim(position: d.getPosition()); |
365 | }) == 1) |
366 | res.insert(V: d.getPosition()); |
367 | } |
368 | } |
369 | return res; |
370 | } |
371 | |
namespace {
// File-local shorthands for the two iterator kinds used by the inference
// helpers below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
376 | |
377 | /// Infer the iterator types from the init affine map. This looks at which dims |
378 | /// are present in the map results, and returns an iterator types array with |
379 | /// parallel types for dims that are present, and reduction types for dims that |
380 | /// are not present. |
381 | static FailureOr<SmallVector<utils::IteratorType>> |
382 | inferIteratorsFromOutMap(AffineMap map) { |
383 | if (!map.isProjectedPermutation()) |
384 | return failure(); |
385 | SmallVector<utils::IteratorType> iterators(map.getNumDims(), red); |
386 | for (auto expr : map.getResults()) |
387 | if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr)) |
388 | iterators[dim.getPosition()] = par; |
389 | return iterators; |
390 | } |
391 | |
392 | /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form |
393 | /// a matmul subcomputation within `linalgOp`. These dimensions are such that: |
394 | /// 1. The m dimension is involved in an outer-product along LHS |
395 | /// (i.e. it is a permutation on RES and LHS and does not appear in RHS). |
396 | /// 2. The n dimension is involved in an outer-product along RHS |
397 | /// (i.e. it is a permutation on RES and RHS and does not appear in LHS). |
398 | /// 3. The k dimension appears as a permutation on LHS and RHS. |
399 | /// 4. m, n and k appear only once in any given indexing. |
400 | /// 5. Optional batch dimensions that appear in all operands are captured. |
401 | /// This allows e.g. detecting that some contraction is embedded within |
402 | /// `linalgOp` with some orthogonal heuristic. |
403 | static FailureOr<ContractionDimensions> |
404 | inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, |
405 | ArrayRef<utils::IteratorType> iterators) { |
406 | llvm::SmallDenseSet<int64_t> a = |
407 | findPermutationsIndexingOperand(indexingMaps[0], iterators, par); |
408 | llvm::SmallDenseSet<int64_t> b = |
409 | findPermutationsIndexingOperand(indexingMaps[1], iterators, par); |
410 | llvm::SmallDenseSet<int64_t> c = |
411 | findPermutationsIndexingOperand(indexingMaps[2], iterators, par); |
412 | |
413 | // A & C - B are the iterators involved in an outer-product along A (the LHS). |
414 | llvm::SmallDenseSet<int64_t> ac = a; |
415 | llvm::set_intersect(S1&: ac, S2: c); |
416 | llvm::set_subtract(S1&: ac, S2: b); |
417 | // B & C - A are the iterators involved in an outer-product along B (the RHS). |
418 | llvm::SmallDenseSet<int64_t> bc = b; |
419 | llvm::set_intersect(S1&: bc, S2: c); |
420 | llvm::set_subtract(S1&: bc, S2: a); |
421 | // A & B & C are the "batch" dimensions. |
422 | llvm::SmallDenseSet<int64_t> batches = a; |
423 | llvm::set_intersect(S1&: batches, S2: b); |
424 | llvm::set_intersect(S1&: batches, S2: c); |
425 | |
426 | // A & B red are the reduction dimensions. |
427 | llvm::SmallDenseSet<int64_t> ra = |
428 | findPermutationsIndexingOperand(indexingMaps[0], iterators, red); |
429 | llvm::SmallDenseSet<int64_t> rb = |
430 | findPermutationsIndexingOperand(indexingMaps[1], iterators, red); |
431 | llvm::set_intersect(S1&: ra, S2: rb); |
432 | |
433 | // Return each set in sorted order. |
434 | ContractionDimensions dimensions{ |
435 | .batch: SmallVector<unsigned, 2>(batches.begin(), batches.end()), |
436 | .m: SmallVector<unsigned, 2>(ac.begin(), ac.end()), |
437 | .n: SmallVector<unsigned, 2>(bc.begin(), bc.end()), |
438 | .k: SmallVector<unsigned, 2>(ra.begin(), ra.end())}; |
439 | llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end()); |
440 | llvm::sort(Start: dimensions.m.begin(), End: dimensions.m.end()); |
441 | llvm::sort(Start: dimensions.n.begin(), End: dimensions.n.end()); |
442 | llvm::sort(Start: dimensions.k.begin(), End: dimensions.k.end()); |
443 | return dimensions; |
444 | } |
445 | |
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
  // Contractions are restricted to two inputs (LHS/RHS) and one init (RES).
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();
  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
                                  linalgOp.getIteratorTypesArray());
}
453 | |
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
  // Expect exactly the LHS, RHS and RES maps.
  if (indexingMaps.size() != 3)
    return failure();
  // Derive iterator types from the output (RES) map: dims present in the
  // result are parallel, the remaining dims are reductions.
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, iterators.value());
}
463 | |
namespace mlir::linalg::detail {
/// Outcome of matching an operation against the structural requirements of
/// the contraction interface; every non-`Success` value identifies the first
/// failed check.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
474 | |
/// Checks whether `op` structurally matches a contraction; on success and if
/// `dimensions` is non-null, also reports the inferred contraction dims.
mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
    Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  // Every indexing map must be a projected permutation (no compound exprs).
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // The body must be an accepted (mul, add)-style pair over the supported
  // element domains (float, int, complex, boolean and/or).
  // TODO: more fields than add/mul.
  // clang-format off
  if (!::isContractionBody<
        arith::MulFOp, arith::AddFOp,
        arith::MulIOp, arith::AddIOp,
        complex::MulOp, complex::AddOp,
        arith::AndIOp, arith::OrIOp>(
        *linalgOp.getBlock())) {
    return MatchContractionResult::NotAddMul;
  }
  // clang-format on

  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    // All structural checks above imply dim inference cannot fail.
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
  return MatchContractionResult::Success;
}
508 | |
/// Maps a `MatchContractionResult` to a human-readable diagnostic message
/// (empty string for `Success`).
StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
  switch (res) {
  case MatchContractionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchContractionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchContractionResult::NoReduction:
    return "expected at least 1 reduction";
  case MatchContractionResult::NotProjectedPermutations:
    return "expected indexing maps to be projected permutations";
  case MatchContractionResult::NotAddMul:
    return "expected add/mul op in the body";
  case MatchContractionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}
527 | |
528 | bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) { |
529 | if (!linalgOp) |
530 | return false; |
531 | Operation *op = linalgOp.getOperation(); |
532 | return isa<ContractionOpInterface>(op) || |
533 | (mlir::linalg::detail::isContractionInterfaceImpl(op) == |
534 | mlir::linalg::detail::MatchContractionResult::Success); |
535 | } |
536 | |
537 | /// Verify that a LinalgOp `op` is a contraction. |
538 | /// A Linalg contraction is defined in general terms: |
539 | /// 1. Has 2 input and 1 output shapes. |
540 | /// 2. Has at least one reduction dimension. |
541 | /// 3. Has only projected permutation indexing maps. |
542 | /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field |
543 | /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary |
544 | /// operations that may change the type (e.g. for mixed-precision). |
545 | /// As a consequence, when vectorization of such an op occurs, the only special |
546 | /// behavior is that the (unique) MulOpType is vectorized into a |
547 | /// `vector.contract`. All other ops are handled in a generic fashion. |
548 | /// In the future, we may wish to allow more input arguments and elementwise and |
549 | /// constant operations that do not involve the reduction dimension(s). |
550 | LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) { |
551 | auto res = isContractionInterfaceImpl(op); |
552 | if (res != MatchContractionResult::Success) |
553 | return op->emitError(message: getMatchContractionMessage(res)); |
554 | return success(); |
555 | } |
556 | |
557 | //===----------------------------------------------------------------------===// |
558 | // ConvolutionOpInterface implementation |
559 | //===----------------------------------------------------------------------===// |
560 | |
561 | /// Of the given two expressions returns one that is of type T (`lhs` gets |
562 | /// preference over `rhs`) |
563 | template <typename T> |
564 | static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) { |
565 | return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr); |
566 | } |
567 | |
568 | namespace { |
569 | /// Walk the indexing expressions for input of a convolution operation to verify |
570 | /// its of the right form, either |
571 | /// - AffineDimExpr |
572 | /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))? |
573 | /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)* |
574 | /// |
575 | /// classifies the AffineDimExpr as convolved dimensions or unconvolved |
576 | /// dimensions and verifies each dimension occurs only once. |
577 | struct ConvAccessExprWalker |
578 | : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> { |
579 | // Stores dimensions used in expressions of the above form. |
580 | llvm::SmallDenseSet<int64_t> convolvedDims; |
581 | // Stores the dual mapping between LHS and RHS of convolution exprs. |
582 | llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping; |
583 | // Stores single use dimensions used by an AffineDimExpr. |
584 | llvm::SmallDenseSet<int64_t> unConvolvedDims; |
585 | // Stores a mapping from convolved dims to their coefficient. |
586 | llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping; |
587 | |
588 | // Removes dims with multiple uses in the source input map from dimension |
589 | // sets tracked by this walker. |
590 | void clearMultiUseDims(AffineMap map) { |
591 | for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) { |
592 | if (llvm::count_if(Range: map.getResults(), P: [dimPos](AffineExpr e) { |
593 | return e.isFunctionOfDim(position: dimPos); |
594 | }) > 1) { |
595 | convolvedDims.erase(V: dimPos); |
596 | unConvolvedDims.erase(V: dimPos); |
597 | // If a duplicate dim is marked as convolved, the pair of the duplicate |
598 | // dim must be removed from the map as well. |
599 | auto it = convolvedDimMapping.find(Val: dimPos); |
600 | if (it != convolvedDimMapping.end()) { |
601 | int64_t pairedDim = it->second; |
602 | convolvedDims.erase(V: pairedDim); |
603 | unConvolvedDims.erase(V: pairedDim); |
604 | strideAndDilationMapping.erase(Val: pairedDim); |
605 | convolvedDimMapping.erase(Val: dimPos); |
606 | convolvedDimMapping.erase(Val: pairedDim); |
607 | } |
608 | } |
609 | } |
610 | } |
611 | |
612 | LogicalResult visitDimExpr(AffineDimExpr dimExpr) { |
613 | unsigned position = dimExpr.getPosition(); |
614 | if (unConvolvedDims.count(V: position) || convolvedDims.count(V: position)) { |
615 | return failure(); |
616 | } |
617 | unConvolvedDims.insert(V: position); |
618 | return success(); |
619 | } |
620 | |
621 | LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); } |
622 | |
623 | LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); } |
624 | |
625 | LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) { |
626 | // In pre-order visit, top level op has to be an add op. |
627 | if (binaryExpr.getKind() != AffineExprKind::Add) |
628 | return failure(); |
629 | auto lhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getLHS()); |
630 | auto rhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getRHS()); |
631 | if (failed(Result: lhsDimPos) || failed(Result: rhsDimPos)) |
632 | return failure(); |
633 | convolvedDimMapping[*lhsDimPos] = *rhsDimPos; |
634 | convolvedDimMapping[*rhsDimPos] = *lhsDimPos; |
635 | return success(); |
636 | } |
637 | |
638 | FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) { |
639 | if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) { |
640 | int64_t dim = dimExpr.getPosition(); |
641 | if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim)) |
642 | return failure(); |
643 | // Stride/dilation for this dim is implicitly 1. |
644 | strideAndDilationMapping[dim] = |
645 | getAffineConstantExpr(constant: 1, context: expr.getContext()); |
646 | convolvedDims.insert(V: dim); |
647 | return dim; |
648 | } |
649 | if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) { |
650 | if (symbolMulExpr.getKind() != AffineExprKind::Mul) |
651 | return failure(); |
652 | auto lhsExpr = symbolMulExpr.getLHS(); |
653 | auto rhsExpr = symbolMulExpr.getRHS(); |
654 | // Check for symbol expression. |
655 | AffineExpr mulExpr = |
656 | getAffineExprOfType<AffineSymbolExpr>(lhs: lhsExpr, rhs: rhsExpr); |
657 | // If there was no symbol expr, check for constant expression. |
658 | if (!mulExpr) { |
659 | mulExpr = getAffineExprOfType<AffineConstantExpr>(lhs: lhsExpr, rhs: rhsExpr); |
660 | } |
661 | auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhs: lhsExpr, rhs: rhsExpr); |
662 | if (!mulExpr || !dimExpr) |
663 | return failure(); |
664 | int64_t dim = dimExpr.getPosition(); |
665 | if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim)) |
666 | return failure(); |
667 | strideAndDilationMapping[dim] = mulExpr; |
668 | convolvedDims.insert(V: dim); |
669 | return dim; |
670 | } |
671 | return failure(); |
672 | } |
673 | }; |
674 | } // namespace |
675 | |
676 | static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) { |
677 | assert(map.isProjectedPermutation() && |
678 | "expected map to have projected permutations"); |
679 | llvm::SmallDenseSet<int64_t> preservedDims; |
680 | for (auto expr : map.getResults()) |
681 | preservedDims.insert(V: cast<AffineDimExpr>(Val&: expr).getPosition()); |
682 | return preservedDims; |
683 | } |
684 | |
685 | static SmallVector<int64_t, 2> |
686 | getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) { |
687 | SmallVector<int64_t, 2> vals; |
688 | for (auto e : exprs) { |
689 | auto constantExpr = dyn_cast<AffineConstantExpr>(Val&: e); |
690 | assert(constantExpr && "Found non-constant stride/dilation"); |
691 | vals.push_back(Elt: constantExpr.getValue()); |
692 | } |
693 | return vals; |
694 | } |
695 | |
/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set this will fail if there is not
/// at least one convolved dimension pair (output image + filter loop).
/// Convolution dimensions are specified in sorted order, and strides match
/// the order of the filter loop dimensions, while the dilations match the
/// order of the output image dimensions.
///
/// Classification is pure set algebra on three dim sets: the walker's
/// convolved/unconvolved input dims, the parallel dims of the filter map, and
/// the parallel dims of the output map (plus the filter's reduction dims).
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  // By convention input 1 is the filter and init 0 is the output.
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(S1&: batch, S2: outputDims);
  llvm::set_subtract(S1&: batch, S2: filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(S1&: oi, S2: outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(S1&: oc, S2: outputDims);
  llvm::set_subtract(S1&: oc, S2: inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(S1&: depth, S2: outputDims);
  llvm::set_intersect(S1&: depth, S2: inputExprWalker.unConvolvedDims);

  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(filterMap,
                                      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(S1&: fl, S2: filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(S1&: ic, S2: filterReducedDims);

  // No output-image dims means no convolved dimension pair at all.
  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      .batch: SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      .outputImage: SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      .outputChannel: SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      .filterLoop: SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      .inputChannel: SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      .depth: SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  // SmallDenseSet iteration order is unspecified; sort for determinism.
  llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end());
  llvm::sort(Start: dimensions.outputImage.begin(), End: dimensions.outputImage.end());
  llvm::sort(Start: dimensions.outputChannel.begin(), End: dimensions.outputChannel.end());
  llvm::sort(Start: dimensions.filterLoop.begin(), End: dimensions.filterLoop.end());
  llvm::sort(Start: dimensions.inputChannel.begin(), End: dimensions.inputChannel.end());
  llvm::sort(Start: dimensions.depth.begin(), End: dimensions.depth.end());

  // Use the op carried strides/dilations attribute if present; otherwise fall
  // back to the multiplier constants recorded by the walker.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(exprs: strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(exprs: dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}
790 | |
791 | /// Find at least 1 parallel (output_image) and reduction (filter_loop) |
792 | /// dimension candidates that form a convolution subcomputation within |
793 | /// `linalgOp`. The LHS is assumed to be the convolution input while the |
794 | /// RHS is assumed as the filter. |
795 | /// These dimensions are such that: |
796 | /// 1. Optional batch dimensions that appear in the input and filter. |
797 | /// 2. The output_image dimension is involved in a cross-correlation along LHS |
798 | /// (i.e. it is a permutation on RES and LHS and has an associated |
799 | /// filter_loop in RHS). |
800 | /// 3. Optional output_channel dimension is involved in an outer-product along |
801 | /// RHS (i.e. it is a permutation on RES and RHS and does not appear in |
802 | /// LHS). |
803 | /// 4. Optional input_channel dimension appears as a permutation on LHS and |
804 | /// RHS. |
805 | /// 5. The filter_loop dimension appears as a permutation on the RHS and |
806 | /// represents the shape of the kernel cross-correlated along a |
807 | /// corresponding output_image dim. |
808 | /// 6. The input_channel dimension appears as a permutation on LHS and RHS. |
809 | /// 7. All dimensions appear only once in any given indexing map. |
810 | /// This allows e.g. detecting that some convolution is embedded within |
811 | /// `linalgOp` with some orthogonal heuristic. |
812 | /// When multiple dimension occurrences exist that match any classification |
813 | /// indices are returned in sorted order. |
814 | /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty. |
815 | FailureOr<ConvolutionDimensions> |
816 | mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) { |
817 | if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) |
818 | return failure(); |
819 | |
820 | auto indexingMaps = linalgOp.getIndexingMapsArray(); |
821 | |
822 | // Check the input indexing map has the right form. |
823 | ConvAccessExprWalker inputExprWalker; |
824 | for (AffineExpr expr : indexingMaps[0].getResults()) |
825 | (void)inputExprWalker.visit(expr); |
826 | inputExprWalker.clearMultiUseDims(map: indexingMaps[0]); |
827 | |
828 | return inferConvolutionDimsImpl(linalgOp, inputExprWalker, |
829 | /*allowEmptyConvolvedDims=*/false); |
830 | } |
831 | |
namespace mlir::linalg::detail {
/// Outcome of the structural convolution match; each non-Success value maps
/// to a diagnostic string in getMatchConvolutionMessage().
enum class MatchConvolutionResult {
  // The op structurally matches a convolution.
  Success = 0,
  // The operation does not implement LinalgOp.
  NotLinalgOp,
  // Expected >= 2 DPS inputs and exactly 1 DPS init.
  WrongNumOperands,
  // The input (LHS) indexing map is not a valid convolution access pattern.
  WrongInputIndexingMap,
  // Filter or output indexing map is not a projected permutation.
  NotProjectedPermutations,
  // A loop dimension could not be classified as any convolution role.
  NonConvolutionLoop,
  // A dimension indexing the output is not a parallel iterator.
  OutputDimsNotParallel,
  // A non-output dimension is not a reduction iterator.
  NonOutputDimNotReduction,
  // No convolved (output image/filter loop) dims were found.
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
845 | |
/// Check whether `op` structurally matches a convolution: every loop dimension
/// must be classifiable as exactly one of batch, output image, output channel,
/// filter loop, input channel or depth. On success, optionally populates
/// `dimensions` with the inferred classification.
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions,
    bool allowEmptyConvolvedDims) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(Result: inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  // First pass: classify every dim appearing in the output map; all of them
  // must be parallel iterators.
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image Loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // Second pass: classify the dims appearing in the filter map; dims not also
  // in the output must be reduction iterators and must not be seen twice.
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This is already seen, continue;
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
    return MatchConvolutionResult::EmptyConvolvedDims;

  if (dimensions) {
    // The structural checks above guarantee inference cannot fail here.
    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}
979 | |
/// Translate a MatchConvolutionResult code into a human-readable diagnostic
/// message; Success maps to the empty string.
StringRef
mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
  switch (res) {
  case MatchConvolutionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchConvolutionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchConvolutionResult::WrongInputIndexingMap:
    return "unexpected input index map for convolutions";
  case MatchConvolutionResult::NotProjectedPermutations:
    return "expected output/filter indexing maps to be projected permutations";
  case MatchConvolutionResult::NonConvolutionLoop:
    return "unexpected loop dimension for convolution op";
  case MatchConvolutionResult::OutputDimsNotParallel:
    return "expected all iterators used to access outputs to be parallel";
  case MatchConvolutionResult::NonOutputDimNotReduction:
    return "expected all iterators not used to access outputs to be reduction";
  case MatchConvolutionResult::EmptyConvolvedDims:
    return "expected convolved dim to be non-empty";
  case MatchConvolutionResult::Success:
    return "";
  }
  // Fully covered switch above; reaching here means a new enumerator was
  // added without a message.
  llvm_unreachable("unhandled MatchConvolutionResult case");
}
1004 | |
1005 | bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp, |
1006 | bool allowEmptyConvolvedDims) { |
1007 | return linalg::detail::isConvolutionInterfaceImpl( |
1008 | op: linalgOp.getOperation(), dimensions: nullptr, allowEmptyConvolvedDims) == |
1009 | linalg::detail::MatchConvolutionResult::Success; |
1010 | } |
1011 | |
1012 | LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { |
1013 | MatchConvolutionResult res = isConvolutionInterfaceImpl(op); |
1014 | if (res != MatchConvolutionResult::Success) |
1015 | return op->emitError(message: getMatchConvolutionMessage(res)); |
1016 | return success(); |
1017 | } |
1018 | |
1019 | //===----------------------------------------------------------------------===// |
1020 | // FillOpInterface implementation |
1021 | //===----------------------------------------------------------------------===// |
1022 | |
/// Outcome of the structural fill match; each non-Success value maps to a
/// diagnostic emitted by verifyFillInterface().
enum class MatchFillResult {
  // The op structurally matches a fill.
  Success = 0,
  // The operation does not implement LinalgOp.
  NotLinalgOp,
  // Expected exactly 1 DPS input and 1 DPS init.
  WrongNumOperands,
  // The single input operand is not a scalar.
  NotScalarInput
};
1029 | |
1030 | static MatchFillResult isFillInterfaceImpl(Operation *op) { |
1031 | auto linalgOp = dyn_cast<linalg::LinalgOp>(op); |
1032 | if (!linalgOp) |
1033 | return MatchFillResult::NotLinalgOp; |
1034 | if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) |
1035 | return MatchFillResult::WrongNumOperands; |
1036 | |
1037 | OpOperand *value = linalgOp.getDpsInputOperand(0); |
1038 | if (!linalgOp.isScalar(value)) |
1039 | return MatchFillResult::NotScalarInput; |
1040 | |
1041 | return MatchFillResult::Success; |
1042 | } |
1043 | |
1044 | LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { |
1045 | auto res = isFillInterfaceImpl(op); |
1046 | if (res == MatchFillResult::NotLinalgOp) |
1047 | return op->emitError(message: "expected a LinalgOp"); |
1048 | if (res == MatchFillResult::WrongNumOperands) |
1049 | return op->emitError(message: "expected op with 1 input and 1 output"); |
1050 | if (res == MatchFillResult::NotScalarInput) |
1051 | return op->emitError(message: "expected op with scalar input"); |
1052 | |
1053 | return success(); |
1054 | } |
1055 | |
1056 | //===----------------------------------------------------------------------===// |
1057 | // StructuredOpInterface implementation |
1058 | //===----------------------------------------------------------------------===// |
1059 | |
1060 | SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b, |
1061 | Location loc) { |
1062 | SmallVector<OpFoldResult> res; |
1063 | for (OpOperand &opOperand : getOperation()->getOpOperands()) { |
1064 | for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i) |
1065 | res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i)); |
1066 | } |
1067 | return res; |
1068 | } |
1069 | |
1070 | SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() { |
1071 | SmallVector<int64_t, 4> res; |
1072 | assert(!hasDynamicShape() && "expected operands to have static shapes"); |
1073 | for (OpOperand &opOperand : getOperation()->getOpOperands()) |
1074 | llvm::append_range(res, getShape(&opOperand)); |
1075 | return res; |
1076 | } |
1077 | |
1078 | SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { |
1079 | AffineMap map = getLoopsToShapesMap(); |
1080 | unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); |
1081 | auto viewSizes = createFlatListOfOperandDims(b, loc); |
1082 | SmallVector<Range, 4> res(numDims); |
1083 | for (unsigned idx = 0; idx < numRes; ++idx) { |
1084 | auto result = map.getResult(idx); |
1085 | if (auto d = dyn_cast<AffineDimExpr>(result)) { |
1086 | if (res[d.getPosition()].offset) |
1087 | continue; |
1088 | res[d.getPosition()] = |
1089 | Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)}; |
1090 | } |
1091 | } |
1092 | return res; |
1093 | } |
1094 | |
/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  /// `positions` is a bit vector indexed by dim position; a set bit marks a
  /// dim position to look for.
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  // A binary expression uses a tracked dim if either side does.
  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(expr: binaryOpExpr.getLHS()) || visit(expr: binaryOpExpr.getRHS());
  }

  // A dim expression matches iff its position bit is set.
  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(Idx: dimExpr.getPosition());
  }

  // Constants and symbols never reference a dim position.
  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  // Bit vector of dim positions to detect.
  llvm::SmallBitVector positions;
};
1117 | |
1118 | static std::pair<int64_t, int64_t> |
1119 | getResultsPositionInLoopsToShapeMap(LinalgOp &op) { |
1120 | int64_t inputRankSum = 0; |
1121 | int64_t outputRankSum = 0; |
1122 | for (OpOperand *input : op.getDpsInputOperands()) |
1123 | inputRankSum += op.getRank(input); |
1124 | for (OpOperand &output : op.getDpsInitsMutable()) |
1125 | outputRankSum += op.getRank(&output); |
1126 | return {inputRankSum, inputRankSum + outputRankSum}; |
1127 | } |
1128 | |
/// Reify the shape of every result as a list of OpFoldResults, expressing
/// dynamic output sizes in terms of input sizes whenever the indexing maps
/// allow it, and falling back to dims of the init operand otherwise.
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs: those would make a result shape depend on itself.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  // `pos` walks the flattened list of output shape expressions/values.
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value. If the expression references an output
        // dim, fall back to a dim op on the init operand itself.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}
1188 | |
/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  // NOTE(review): non-DPS-input operands map directly to their operand number
  // here; DPS inputs are re-based below — this relies on the DPS operand
  // layout conventions of the op.
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`.
  return cast<DestinationStyleOpInterface>(*this->getOperation())
             .getNumDpsInputs() +
         operandNumber - start;
}
1204 | |
/// Verify the invariants shared by all structured (Linalg) ops: pure
/// tensor-or-buffer semantics, well-formed indexing maps, statically
/// consistent operand shapes, and a single region/block whose arguments match
/// the operands' element (or self) types.
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);
  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError(message: "expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError(message: "expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  // Per-operand indexing map invariants: no symbols, a domain matching the
  // loop count, and a result count matching the operand rank.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have "<< numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);

    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }
  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError(message: "expected the shape-to-loops map to be non-null");

  // Check if given shapes match to inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
  // Verify only static cases since we can't get exact dimension sizes and
  // loop ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    // Switch from exclusive range sizes to inclusive last-index values.
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimension or the case that the dimension size is 0
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // The first index or last index should be the maximum or the minimum in
        // the inferred index ranges since the range is increasing or
        // decreasing. The size of dimensions of input/output operands and the
        // maximum value + 1 in the inferred range should be the same. But, for
        // now we check if the inferred ranges are in boundary of input/output
        // operands' size or not in case that Affine Expressions are complicated
        // such as d0 * 3
        // + d1 since it is not easy to handle the issues.
        // Found the case that this solution can't check, for example, (d0, d1)
        // -> (d1 - d0)
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in "<< mapStr;
        }
        if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
          // Plain dim expression: the inferred size must match exactly.
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be "<< inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          // Composite expression: only require the accesses to stay in bounds.
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found "<< shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError(message: "expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError(message: "expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " ("<< argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}
1349 |
Definitions
- canOpOperandsBeDroppedImpl
- isaCopyOpInterface
- isaFillOpInterface
- isaBroadcastOpInterface
- isaTransposeOpInterface
- isaElemwiseSingleUnaryOrBinaryOpInterface
- isaElemwiseSingleUnaryOpInterface
- isaElemwiseSingleBinaryOpInterface
- getSourceSkipUnary
- isContractionBody
- isPairTemplateImpl
- isContractionBody
- findPermutationsIndexingOperand
- par
- red
- inferIteratorsFromOutMap
- inferContractionDimsImpl
- inferContractionDims
- inferContractionDims
- MatchContractionResult
- isContractionInterfaceImpl
- getMatchContractionMessage
- isaContractionOpInterface
- verifyContractionInterface
- getAffineExprOfType
- ConvAccessExprWalker
- clearMultiUseDims
- visitDimExpr
- visitSymbolExpr
- visitConstantExpr
- visitAffineBinaryOpExpr
- getDimExprOrMulExprDimPos
- getPreservedDims
- getConstantsFromExprList
- inferConvolutionDimsImpl
- inferConvolutionDims
- MatchConvolutionResult
- isConvolutionInterfaceImpl
- getMatchConvolutionMessage
- isaConvolutionOpInterface
- verifyConvolutionInterface
- MatchFillResult
- isFillInterfaceImpl
- verifyFillInterface
- HasAffineDimExprVisitor
- HasAffineDimExprVisitor
- visitAffineBinaryOpExpr
- visitDimExpr
- visitConstantExpr
- visitSymbolExpr
- getResultsPositionInLoopsToShapeMap
Improve your Profiling and Debugging skills
Find out more