1 | //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
14 | #include "mlir/Dialect/Complex/IR/Complex.h" |
15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
17 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
18 | #include "mlir/IR/AffineExprVisitor.h" |
19 | #include "mlir/IR/AffineMap.h" |
20 | #include "mlir/IR/TypeUtilities.h" |
21 | #include "llvm/ADT/SetOperations.h" |
22 | #include "llvm/ADT/SmallBitVector.h" |
23 | #include "llvm/ADT/SmallVector.h" |
24 | #include <algorithm> |
25 | |
26 | using namespace mlir; |
27 | using namespace mlir::linalg; |
28 | |
29 | /// Include the definitions of the copy operation interface. |
30 | #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // Interface utility functions |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | bool linalg::detail::canOpOperandsBeDroppedImpl( |
37 | linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) { |
38 | SmallVector<AffineMap> indexingMaps; |
39 | for (auto &opOperand : linalgOp->getOpOperands()) { |
40 | if (llvm::is_contained(droppedOperands, &opOperand)) |
41 | continue; |
42 | indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand)); |
43 | } |
44 | if (indexingMaps.empty()) { |
45 | // If there are no indexing maps, the operand can only be dropped |
46 | // if the op has no loops. |
47 | return linalgOp.getNumLoops() == 0; |
48 | } |
49 | return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); |
50 | } |
51 | |
52 | //===----------------------------------------------------------------------===// |
53 | // CopyOpInterface implementation |
54 | //===----------------------------------------------------------------------===// |
55 | |
56 | bool linalg::isaCopyOpInterface(LinalgOp linalgOp) { |
57 | // Structural. |
58 | if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) |
59 | return false; |
60 | |
61 | // Operands and maps. |
62 | if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) |
63 | return false; |
64 | auto mapRange = linalgOp.getIndexingMapsArray(); |
65 | if (mapRange.size() != 2 || !mapRange.front().isIdentity() || |
66 | !mapRange.back().isIdentity()) { |
67 | return false; |
68 | } |
69 | // Region. |
70 | return llvm::hasSingleElement(linalgOp.getBlock()->getOperations()); |
71 | } |
72 | |
73 | //===----------------------------------------------------------------------===// |
74 | // ContractionOpInterface implementation |
75 | //===----------------------------------------------------------------------===// |
76 | |
77 | /// If the value is defined by a chain of unary side effect-free, go up the |
78 | /// use-def chain until the first value that isn't defined by such an op. |
79 | // TODO: relax to multi-operands with constants, which are technically unary ops |
80 | // as needed (e.g. add5). |
81 | static Value getSourceSkipUnary(Value value) { |
82 | Operation *op = value.getDefiningOp(); |
83 | while (op && op->getNumOperands() == 1) { |
84 | auto iface = dyn_cast<MemoryEffectOpInterface>(op); |
85 | if (!iface || !iface.hasNoEffect()) |
86 | break; |
87 | value = op->getOperand(idx: 0); |
88 | op = value.getDefiningOp(); |
89 | } |
90 | return value; |
91 | } |
92 | |
93 | bool mlir::linalg::detail::isContractionBody( |
94 | Block &block, function_ref<bool(Operation *, Operation *)> isaPair, |
95 | llvm::raw_ostream &errs) { |
96 | if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) { |
97 | errs << "no terminator in the block" ; |
98 | return false; |
99 | } |
100 | |
101 | if (block.getNumArguments() != 3) { |
102 | errs << "expected block with 3 arguments" ; |
103 | return false; |
104 | } |
105 | |
106 | Operation *terminator = block.getTerminator(); |
107 | if (terminator->getNumOperands() != 1) { |
108 | errs << "expected terminator with 1 operand" ; |
109 | return false; |
110 | } |
111 | |
112 | Value yielded = getSourceSkipUnary(value: terminator->getOperand(idx: 0)); |
113 | Operation *reductionOp = yielded.getDefiningOp(); |
114 | if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { |
115 | errs << "expected reduction op to be binary" ; |
116 | return false; |
117 | } |
118 | |
119 | Value reductionLHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 0)); |
120 | Value reductionRHS = getSourceSkipUnary(value: reductionOp->getOperand(idx: 1)); |
121 | |
122 | if (reductionLHS != block.getArgument(i: 2) && |
123 | reductionRHS != block.getArgument(i: 2)) { |
124 | errs << "expected reduction to take block argument #2 as one of the " |
125 | "operands (modulo unary casts)" ; |
126 | return false; |
127 | } |
128 | |
129 | Value contributed = getSourceSkipUnary( |
130 | value: isa<BlockArgument>(Val: reductionLHS) ? reductionRHS : reductionLHS); |
131 | Operation *elementwiseOp = contributed.getDefiningOp(); |
132 | if (elementwiseOp->getNumResults() != 1 || |
133 | elementwiseOp->getNumOperands() != 2) { |
134 | errs << "expected elementwise op to be binary" ; |
135 | return false; |
136 | } |
137 | |
138 | if (!isaPair(elementwiseOp, reductionOp)) { |
139 | errs << "expected reduction/elementwise op kind not satisfied" ; |
140 | return false; |
141 | } |
142 | |
143 | Value elementwiseLHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 0)); |
144 | Value elementwiseRHS = getSourceSkipUnary(value: elementwiseOp->getOperand(idx: 1)); |
145 | if ((elementwiseLHS == block.getArgument(i: 0) && |
146 | elementwiseRHS == block.getArgument(i: 1)) || |
147 | (elementwiseLHS == block.getArgument(i: 1) && |
148 | elementwiseRHS == block.getArgument(i: 0))) { |
149 | return true; |
150 | } |
151 | |
152 | errs << "expected elementwise op to apply to block arguments (modulo unary " |
153 | "casts)" ; |
154 | return false; |
155 | } |
156 | |
157 | /// Returns true if the two operations are of the kinds specified by a pair of |
158 | /// consecutive template arguments. |
159 | template <typename AddOpTy, typename MulOpTy, typename... Args> |
160 | static bool isPairTemplateImpl(Operation *add, Operation *mul) { |
161 | static_assert(sizeof...(Args) % 2 == 0, |
162 | "expected an even number of template arguments" ); |
163 | if (isa<AddOpTy>(add) && isa<MulOpTy>(mul)) |
164 | return true; |
165 | |
166 | if constexpr (sizeof...(Args) > 0) |
167 | return isPairTemplateImpl<Args...>(add, mul); |
168 | else |
169 | return false; |
170 | } |
171 | |
172 | /// Returns true if the block is a body of a contraction with the kinds of |
173 | /// operations given pairwise by template arguments. |
174 | template <typename... Args> |
175 | static bool isContractionBody(Block &block) { |
176 | return linalg::detail::isContractionBody(block, isaPair: &isPairTemplateImpl<Args...>); |
177 | } |
178 | |
179 | /// Given an `indexingMap` and its corresponding `iterators`, returns |
180 | /// the positions of the iterators of type `iter` that are indexed by |
181 | /// the `indexingMap` as a permutation. This is useful to infer various |
182 | /// subcomputations on a `LinalgOp`. This is performed by looking up |
183 | /// each result in the `indexingMap` and determining whether: |
184 | /// - It is a single AffineDimExpr. |
185 | /// - It is the only result involving this AffineDimExpr. |
186 | static llvm::SmallDenseSet<int64_t> |
187 | findPermutationsIndexingOperand(AffineMap indexingMap, |
188 | ArrayRef<utils::IteratorType> iterators, |
189 | utils::IteratorType iter) { |
190 | assert(iterators.size() == indexingMap.getNumDims()); |
191 | llvm::SmallDenseSet<int64_t> res; |
192 | for (AffineExpr e : indexingMap.getResults()) { |
193 | if (auto d = dyn_cast<AffineDimExpr>(Val&: e)) { |
194 | if (iterators[d.getPosition()] == iter && |
195 | llvm::count_if(Range: indexingMap.getResults(), P: [d](AffineExpr e) { |
196 | return e.isFunctionOfDim(position: d.getPosition()); |
197 | }) == 1) |
198 | res.insert(V: d.getPosition()); |
199 | } |
200 | } |
201 | return res; |
202 | } |
203 | |
namespace {
// File-local shorthands for the two iterator kinds used by the dimension
// inference helpers below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
208 | |
209 | /// Infer the iterator types from the init affine map. This looks at which dims |
210 | /// are present in the map results, and returns an iterator types array with |
211 | /// parallel types for dims that are present, and reduction types for dims that |
212 | /// are not present. |
213 | static FailureOr<SmallVector<utils::IteratorType>> |
214 | inferIteratorsFromOutMap(AffineMap map) { |
215 | if (!map.isProjectedPermutation()) |
216 | return failure(); |
217 | SmallVector<utils::IteratorType> iterators(map.getNumDims(), red); |
218 | for (auto expr : map.getResults()) |
219 | if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr)) |
220 | iterators[dim.getPosition()] = par; |
221 | return iterators; |
222 | } |
223 | |
224 | /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form |
225 | /// a matmul subcomputation within `linalgOp`. These dimensions are such that: |
226 | /// 1. The m dimension is involved in an outer-product along LHS |
227 | /// (i.e. it is a permutation on RES and LHS and does not appear in RHS). |
228 | /// 2. The n dimension is involved in an outer-product along RHS |
229 | /// (i.e. it is a permutation on RES and RHS and does not appear in LHS). |
230 | /// 3. The k dimension appears as a permutation on LHS and RHS. |
231 | /// 4. m, n and k appear only once in any given indexing. |
232 | /// 5. Optional batch dimensions that appear in all operands are captured. |
233 | /// This allows e.g. detecting that some contraction is embedded within |
234 | /// `linalgOp` with some orthogonal heuristic. |
235 | static FailureOr<ContractionDimensions> |
236 | inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, |
237 | ArrayRef<utils::IteratorType> iterators) { |
238 | llvm::SmallDenseSet<int64_t> a = |
239 | findPermutationsIndexingOperand(indexingMaps[0], iterators, par); |
240 | llvm::SmallDenseSet<int64_t> b = |
241 | findPermutationsIndexingOperand(indexingMaps[1], iterators, par); |
242 | llvm::SmallDenseSet<int64_t> c = |
243 | findPermutationsIndexingOperand(indexingMaps[2], iterators, par); |
244 | |
245 | // A & C - B are the iterators involved in an outer-product along A (the LHS). |
246 | llvm::SmallDenseSet<int64_t> ac = a; |
247 | llvm::set_intersect(S1&: ac, S2: c); |
248 | llvm::set_subtract(S1&: ac, S2: b); |
249 | // B & C - A are the iterators involved in an outer-product along B (the RHS). |
250 | llvm::SmallDenseSet<int64_t> bc = b; |
251 | llvm::set_intersect(S1&: bc, S2: c); |
252 | llvm::set_subtract(S1&: bc, S2: a); |
253 | // A & B & C are the "batch" dimensions. |
254 | llvm::SmallDenseSet<int64_t> batches = a; |
255 | llvm::set_intersect(S1&: batches, S2: b); |
256 | llvm::set_intersect(S1&: batches, S2: c); |
257 | |
258 | // A & B red are the reduction dimensions. |
259 | llvm::SmallDenseSet<int64_t> ra = |
260 | findPermutationsIndexingOperand(indexingMaps[0], iterators, red); |
261 | llvm::SmallDenseSet<int64_t> rb = |
262 | findPermutationsIndexingOperand(indexingMaps[1], iterators, red); |
263 | llvm::set_intersect(S1&: ra, S2: rb); |
264 | |
265 | // Return each set in sorted order. |
266 | ContractionDimensions dimensions{ |
267 | .batch: SmallVector<unsigned, 2>(batches.begin(), batches.end()), |
268 | .m: SmallVector<unsigned, 2>(ac.begin(), ac.end()), |
269 | .n: SmallVector<unsigned, 2>(bc.begin(), bc.end()), |
270 | .k: SmallVector<unsigned, 2>(ra.begin(), ra.end())}; |
271 | llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end()); |
272 | llvm::sort(Start: dimensions.m.begin(), End: dimensions.m.end()); |
273 | llvm::sort(Start: dimensions.n.begin(), End: dimensions.n.end()); |
274 | llvm::sort(Start: dimensions.k.begin(), End: dimensions.k.end()); |
275 | return dimensions; |
276 | } |
277 | |
278 | FailureOr<ContractionDimensions> |
279 | mlir::linalg::inferContractionDims(LinalgOp linalgOp) { |
280 | if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) |
281 | return failure(); |
282 | return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(), |
283 | linalgOp.getIteratorTypesArray()); |
284 | } |
285 | |
286 | FailureOr<ContractionDimensions> |
287 | mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) { |
288 | if (indexingMaps.size() != 3) |
289 | return failure(); |
290 | auto iterators = inferIteratorsFromOutMap(indexingMaps[2]); |
291 | if (failed(iterators)) |
292 | return failure(); |
293 | return inferContractionDimsImpl(indexingMaps, iterators.value()); |
294 | } |
295 | |
namespace mlir::linalg::detail {
/// Outcome of matching an operation against the contraction interface.
/// Anything other than `Success` names the first structural check that failed
/// in `isContractionInterfaceImpl`; `getMatchContractionMessage` converts each
/// code to its user-facing diagnostic.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,              // op does not implement LinalgOp
  WrongNumOperands,         // not exactly 2 inputs and 1 init
  NoReduction,              // no reduction loop present
  NotProjectedPermutations, // an indexing map is not a projected permutation
  NotAddMul                 // body is not a recognized add/mul-style reduction
};
} // namespace mlir::linalg::detail
306 | |
307 | mlir::linalg::detail::MatchContractionResult |
308 | mlir::linalg::detail::isContractionInterfaceImpl( |
309 | Operation *op, mlir::linalg::ContractionDimensions *dimensions) { |
310 | auto linalgOp = dyn_cast<linalg::LinalgOp>(op); |
311 | if (!linalgOp) |
312 | return MatchContractionResult::NotLinalgOp; |
313 | if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) |
314 | return MatchContractionResult::WrongNumOperands; |
315 | auto mapRange = linalgOp.getIndexingMapsArray(); |
316 | if (linalgOp.getNumReductionLoops() == 0) |
317 | return MatchContractionResult::NoReduction; |
318 | if (llvm::any_of(mapRange, |
319 | [](AffineMap m) { return !m.isProjectedPermutation(); })) |
320 | return MatchContractionResult::NotProjectedPermutations; |
321 | // TODO: more fields than add/mul. |
322 | // clang-format off |
323 | if (!::isContractionBody< |
324 | arith::MulFOp, arith::AddFOp, |
325 | arith::MulIOp, arith::AddIOp, |
326 | complex::MulOp, complex::AddOp, |
327 | arith::AndIOp, arith::OrIOp>( |
328 | *linalgOp.getBlock())) { |
329 | return MatchContractionResult::NotAddMul; |
330 | } |
331 | // clang-format on |
332 | |
333 | if (dimensions) { |
334 | FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp); |
335 | assert(succeeded(res) && "unexpected failure to infer contraction dims" ); |
336 | *dimensions = *res; |
337 | } |
338 | return MatchContractionResult::Success; |
339 | } |
340 | |
341 | StringRef |
342 | mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) { |
343 | switch (res) { |
344 | case MatchContractionResult::NotLinalgOp: |
345 | return "expected a LinalgOp" ; |
346 | case MatchContractionResult::WrongNumOperands: |
347 | return "expected op with 2 inputs and 1 output" ; |
348 | case MatchContractionResult::NoReduction: |
349 | return "expected at least 1 reduction" ; |
350 | case MatchContractionResult::NotProjectedPermutations: |
351 | return "expected indexing maps to be projected permutations" ; |
352 | case MatchContractionResult::NotAddMul: |
353 | return "expected add/mul op in the body" ; |
354 | case MatchContractionResult::Success: |
355 | return "" ; |
356 | } |
357 | llvm_unreachable("unhandled MatchContractionResult case" ); |
358 | } |
359 | |
360 | bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) { |
361 | if (!linalgOp) |
362 | return false; |
363 | Operation *op = linalgOp.getOperation(); |
364 | return isa<ContractionOpInterface>(op) || |
365 | (mlir::linalg::detail::isContractionInterfaceImpl(op) == |
366 | mlir::linalg::detail::MatchContractionResult::Success); |
367 | } |
368 | |
369 | /// Verify that a LinalgOp `op` is a contraction. |
370 | /// A Linalg contraction is defined in general terms: |
371 | /// 1. Has 2 input and 1 output shapes. |
372 | /// 2. Has at least one reduction dimension. |
373 | /// 3. Has only projected permutation indexing maps. |
374 | /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field |
375 | /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary |
376 | /// operations that may change the type (e.g. for mixed-precision). |
377 | /// As a consequence, when vectorization of such an op occurs, the only special |
378 | /// behavior is that the (unique) MulOpType is vectorized into a |
379 | /// `vector.contract`. All other ops are handled in a generic fashion. |
380 | /// In the future, we may wish to allow more input arguments and elementwise and |
381 | /// constant operations that do not involve the reduction dimension(s). |
382 | LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) { |
383 | auto res = isContractionInterfaceImpl(op); |
384 | if (res != MatchContractionResult::Success) |
385 | return op->emitError(message: getMatchContractionMessage(res)); |
386 | return success(); |
387 | } |
388 | |
389 | //===----------------------------------------------------------------------===// |
390 | // ConvolutionOpInterface implementation |
391 | //===----------------------------------------------------------------------===// |
392 | |
393 | /// Of the given two expressions returns one that is of type T (`lhs` gets |
394 | /// preference over `rhs`) |
395 | template <typename T> |
396 | static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) { |
397 | return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr); |
398 | } |
399 | |
400 | namespace { |
401 | /// Walk the indexing expressions for input of a convolution operation to verify |
402 | /// its of the right form, either |
403 | /// - AffineDimExpr |
404 | /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))? |
405 | /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)* |
406 | /// |
407 | /// classifies the AffineDimExpr as convolved dimensions or unconvolved |
408 | /// dimensions and verifies each dimension occurs only once. |
409 | struct ConvAccessExprWalker |
410 | : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> { |
411 | // Stores dimensions used in expressions of the above form. |
412 | llvm::SmallDenseSet<int64_t> convolvedDims; |
413 | // Stores the dual mapping between LHS and RHS of convolution exprs. |
414 | llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping; |
415 | // Stores single use dimensions used by an AffineDimExpr. |
416 | llvm::SmallDenseSet<int64_t> unConvolvedDims; |
417 | // Stores a mapping from convolved dims to their coefficient. |
418 | llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping; |
419 | |
420 | // Removes dims with multiple uses in the source input map from dimension |
421 | // sets tracked by this walker. |
422 | void clearMultiUseDims(AffineMap map) { |
423 | for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) { |
424 | if (llvm::count_if(Range: map.getResults(), P: [dimPos](AffineExpr e) { |
425 | return e.isFunctionOfDim(position: dimPos); |
426 | }) > 1) { |
427 | convolvedDims.erase(V: dimPos); |
428 | unConvolvedDims.erase(V: dimPos); |
429 | // If a duplicate dim is marked as convolved, the pair of the duplicate |
430 | // dim must be removed from the map as well. |
431 | if (convolvedDimMapping.contains(Val: dimPos)) { |
432 | int64_t pairedDim = convolvedDimMapping[dimPos]; |
433 | convolvedDims.erase(V: pairedDim); |
434 | unConvolvedDims.erase(V: pairedDim); |
435 | strideAndDilationMapping.erase(Val: pairedDim); |
436 | convolvedDimMapping.erase(Val: dimPos); |
437 | convolvedDimMapping.erase(Val: pairedDim); |
438 | } |
439 | } |
440 | } |
441 | } |
442 | |
443 | LogicalResult visitDimExpr(AffineDimExpr dimExpr) { |
444 | unsigned position = dimExpr.getPosition(); |
445 | if (unConvolvedDims.count(V: position) || convolvedDims.count(V: position)) { |
446 | return failure(); |
447 | } |
448 | unConvolvedDims.insert(V: position); |
449 | return success(); |
450 | } |
451 | |
452 | LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); } |
453 | |
454 | LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); } |
455 | |
456 | LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) { |
457 | // In pre-order visit, top level op has to be an add op. |
458 | if (binaryExpr.getKind() != AffineExprKind::Add) |
459 | return failure(); |
460 | auto lhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getLHS()); |
461 | auto rhsDimPos = getDimExprOrMulExprDimPos(expr: binaryExpr.getRHS()); |
462 | if (failed(result: lhsDimPos) || failed(result: rhsDimPos)) |
463 | return failure(); |
464 | convolvedDimMapping[*lhsDimPos] = *rhsDimPos; |
465 | convolvedDimMapping[*rhsDimPos] = *lhsDimPos; |
466 | return success(); |
467 | } |
468 | |
469 | FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) { |
470 | if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) { |
471 | int64_t dim = dimExpr.getPosition(); |
472 | if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim)) |
473 | return failure(); |
474 | // Stride/dilation for this dim is implicitly 1. |
475 | strideAndDilationMapping[dim] = |
476 | getAffineConstantExpr(constant: 1, context: expr.getContext()); |
477 | convolvedDims.insert(V: dim); |
478 | return dim; |
479 | } |
480 | if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) { |
481 | if (symbolMulExpr.getKind() != AffineExprKind::Mul) |
482 | return failure(); |
483 | auto lhsExpr = symbolMulExpr.getLHS(); |
484 | auto rhsExpr = symbolMulExpr.getRHS(); |
485 | // Check for symbol expression. |
486 | AffineExpr mulExpr = |
487 | getAffineExprOfType<AffineSymbolExpr>(lhs: lhsExpr, rhs: rhsExpr); |
488 | // If there was no symbol expr, check for constant expression. |
489 | if (!mulExpr) { |
490 | mulExpr = getAffineExprOfType<AffineConstantExpr>(lhs: lhsExpr, rhs: rhsExpr); |
491 | } |
492 | auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhs: lhsExpr, rhs: rhsExpr); |
493 | if (!mulExpr || !dimExpr) |
494 | return failure(); |
495 | int64_t dim = dimExpr.getPosition(); |
496 | if (convolvedDims.count(V: dim) || unConvolvedDims.count(V: dim)) |
497 | return failure(); |
498 | strideAndDilationMapping[dim] = mulExpr; |
499 | convolvedDims.insert(V: dim); |
500 | return dim; |
501 | } |
502 | return failure(); |
503 | } |
504 | }; |
505 | } // namespace |
506 | |
507 | static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) { |
508 | assert(map.isProjectedPermutation() && |
509 | "expected map to have projected permutations" ); |
510 | llvm::SmallDenseSet<int64_t> preservedDims; |
511 | for (auto expr : map.getResults()) |
512 | preservedDims.insert(V: cast<AffineDimExpr>(Val&: expr).getPosition()); |
513 | return preservedDims; |
514 | } |
515 | |
516 | static SmallVector<int64_t, 2> |
517 | getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) { |
518 | SmallVector<int64_t, 2> vals; |
519 | for (auto e : exprs) { |
520 | auto constantExpr = dyn_cast<AffineConstantExpr>(Val&: e); |
521 | assert(constantExpr && "Found non-constant stride/dilation" ); |
522 | vals.push_back(Elt: constantExpr.getValue()); |
523 | } |
524 | return vals; |
525 | } |
526 | |
527 | /// Classifies dimensions in the `linalgOp` used by a convolution |
528 | /// subcomputation, as captured by `inputExprWalker`. If |
529 | /// `allowEmptyConvolvedDims` is not set this this will fail if there is not |
530 | /// at least convolved dimension pair (output image + filter loop). Convolution |
531 | /// dimensions are specified in sorted order, and strides match the order of |
532 | /// the filter loop dimensions, while the dilations match the order of the |
533 | /// output image dimensions. |
534 | static FailureOr<ConvolutionDimensions> |
535 | inferConvolutionDimsImpl(LinalgOp linalgOp, |
536 | ConvAccessExprWalker &inputExprWalker, |
537 | bool allowEmptyConvolvedDims) { |
538 | auto filterMap = |
539 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1)); |
540 | auto outputMap = |
541 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0)); |
542 | llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand( |
543 | filterMap, linalgOp.getIteratorTypesArray(), par); |
544 | llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand( |
545 | outputMap, linalgOp.getIteratorTypesArray(), par); |
546 | |
547 | // unConvolvedDims & outputDims - filterDims are the batch iterators. |
548 | llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims; |
549 | llvm::set_intersect(S1&: batch, S2: outputDims); |
550 | llvm::set_subtract(S1&: batch, S2: filterDims); |
551 | |
552 | // convolvedDims & outputDims are the output image iterators. |
553 | llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims; |
554 | llvm::set_intersect(S1&: oi, S2: outputDims); |
555 | |
556 | // filterDims & outputDims - unConvolvedDims are the output channel iterators. |
557 | llvm::SmallDenseSet<int64_t> oc = filterDims; |
558 | llvm::set_intersect(S1&: oc, S2: outputDims); |
559 | llvm::set_subtract(S1&: oc, S2: inputExprWalker.unConvolvedDims); |
560 | |
561 | // filterDims & outputDims & unConvolvedDims are the depth iterators. |
562 | llvm::SmallDenseSet<int64_t> depth = filterDims; |
563 | llvm::set_intersect(S1&: depth, S2: outputDims); |
564 | llvm::set_intersect(S1&: depth, S2: inputExprWalker.unConvolvedDims); |
565 | |
566 | llvm::SmallDenseSet<int64_t> filterReducedDims = |
567 | findPermutationsIndexingOperand(filterMap, |
568 | linalgOp.getIteratorTypesArray(), red); |
569 | |
570 | // convolvedDims & filterReducedDims are the filter loop iterators. |
571 | llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims; |
572 | llvm::set_intersect(S1&: fl, S2: filterReducedDims); |
573 | |
574 | // unConvolvedDims & filterReducedDims are the input channel iterators. |
575 | llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims; |
576 | llvm::set_intersect(S1&: ic, S2: filterReducedDims); |
577 | |
578 | if (oi.empty() && !allowEmptyConvolvedDims) |
579 | return failure(); |
580 | |
581 | // Return each set in sorted order. |
582 | ConvolutionDimensions dimensions{ |
583 | .batch: SmallVector<unsigned, 2>(batch.begin(), batch.end()), |
584 | .outputImage: SmallVector<unsigned, 2>(oi.begin(), oi.end()), |
585 | .outputChannel: SmallVector<unsigned, 2>(oc.begin(), oc.end()), |
586 | .filterLoop: SmallVector<unsigned, 2>(fl.begin(), fl.end()), |
587 | .inputChannel: SmallVector<unsigned, 2>(ic.begin(), ic.end()), |
588 | .depth: SmallVector<unsigned, 2>(depth.begin(), depth.end()), |
589 | /*strides=*/SmallVector<int64_t, 2>{}, |
590 | /*dilations=*/SmallVector<int64_t, 2>{}}; |
591 | llvm::sort(Start: dimensions.batch.begin(), End: dimensions.batch.end()); |
592 | llvm::sort(Start: dimensions.outputImage.begin(), End: dimensions.outputImage.end()); |
593 | llvm::sort(Start: dimensions.outputChannel.begin(), End: dimensions.outputChannel.end()); |
594 | llvm::sort(Start: dimensions.filterLoop.begin(), End: dimensions.filterLoop.end()); |
595 | llvm::sort(Start: dimensions.inputChannel.begin(), End: dimensions.inputChannel.end()); |
596 | llvm::sort(Start: dimensions.depth.begin(), End: dimensions.depth.end()); |
597 | |
598 | // Use the op carried strides/dilations attribute if present. |
599 | auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides" ); |
600 | if (!nativeStrides) { |
601 | SmallVector<AffineExpr, 2> strideExprs; |
602 | for (unsigned oiDim : dimensions.outputImage) |
603 | strideExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[oiDim]); |
604 | dimensions.strides = getConstantsFromExprList(exprs: strideExprs); |
605 | } else { |
606 | dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>()); |
607 | } |
608 | auto nativeDilations = |
609 | linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations" ); |
610 | if (!nativeDilations) { |
611 | SmallVector<AffineExpr, 2> dilationExprs; |
612 | for (unsigned flDim : dimensions.filterLoop) |
613 | dilationExprs.push_back(Elt: inputExprWalker.strideAndDilationMapping[flDim]); |
614 | dimensions.dilations = getConstantsFromExprList(exprs: dilationExprs); |
615 | } else { |
616 | dimensions.dilations = |
617 | llvm::to_vector<2>(nativeDilations.getValues<int64_t>()); |
618 | } |
619 | return dimensions; |
620 | } |
621 | |
622 | /// Find at least 1 parallel (output_image) and reduction (filter_loop) |
623 | /// dimension candidates that form a convolution subcomputation within |
624 | /// `linalgOp`. The LHS is assumed to be the convolution input while the |
625 | /// RHS is assumed as the filter. |
626 | /// These dimensions are such that: |
627 | /// 1. Optional batch dimensions that appear in the input and filter. |
628 | /// 2. The output_image dimension is involved in a cross-correlation along LHS |
629 | /// (i.e. it is a permutation on RES and LHS and has an associated |
630 | /// filter_loop in RHS). |
631 | /// 3. Optional output_channel dimension is involved in an outer-product along |
632 | /// RHS (i.e. it is a permutation on RES and RHS and does not appear in |
633 | /// LHS). |
634 | /// 4. Optional input_channel dimension appears as a permutation on LHS and |
635 | /// RHS. |
636 | /// 5. The filter_loop dimension appears as a permutation on the RHS and |
637 | /// represents the shape of the kernel cross-correlated along a |
638 | /// corresponding output_image dim. |
639 | /// 6. The input_channel dimension appears as a permutation on LHS and RHS. |
640 | /// 7. All dimensions appear only once in any given indexing map. |
641 | /// This allows e.g. detecting that some convolution is embedded within |
642 | /// `linalgOp` with some orthogonal heuristic. |
643 | /// When multiple dimension occurrences exist that match any classification |
644 | /// indices are returned in sorted order. |
645 | /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty. |
646 | FailureOr<ConvolutionDimensions> |
647 | mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) { |
648 | if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2) |
649 | return failure(); |
650 | |
651 | auto indexingMaps = linalgOp.getIndexingMapsArray(); |
652 | |
653 | // Check the input indexing map has the right form. |
654 | ConvAccessExprWalker inputExprWalker; |
655 | for (AffineExpr expr : indexingMaps[0].getResults()) |
656 | (void)inputExprWalker.visit(expr); |
657 | inputExprWalker.clearMultiUseDims(map: indexingMaps[0]); |
658 | |
659 | return inferConvolutionDimsImpl(linalgOp, inputExprWalker, |
660 | /*allowEmptyConvolvedDims=*/false); |
661 | } |
662 | |
namespace mlir::linalg::detail {
/// Outcome of matching an operation against the convolution interface.
/// Anything other than `Success` names the first structural check that failed
/// in `isConvolutionInterfaceImpl`.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,           // op does not implement LinalgOp
  WrongNumOperands,      // fewer than 2 inputs, or not exactly 1 init
  WrongInputIndexingMap, // input map is not a valid convolution access pattern
  NotProjectedPermutations, // filter/output maps not projected permutations
  // NOTE(review): the remaining codes come from loop-classification checks
  // later in isConvolutionInterfaceImpl (not fully visible here) — confirm
  // against the implementation.
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};
} // namespace mlir::linalg::detail
675 | |
676 | mlir::linalg::detail::MatchConvolutionResult |
677 | mlir::linalg::detail::isConvolutionInterfaceImpl( |
678 | Operation *op, ConvolutionDimensions *dimensions) { |
679 | auto linalgOp = dyn_cast<linalg::LinalgOp>(op); |
680 | if (!linalgOp) |
681 | return MatchConvolutionResult::NotLinalgOp; |
682 | if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1) |
683 | return MatchConvolutionResult::WrongNumOperands; |
684 | |
685 | auto indexingMaps = linalgOp.getIndexingMapsArray(); |
686 | |
687 | // Check the input indexing map has the right form. |
688 | ConvAccessExprWalker inputExprWalker; |
689 | if (llvm::any_of(indexingMaps[0].getResults(), |
690 | [&inputExprWalker](AffineExpr expr) { |
691 | return failed(result: inputExprWalker.visit(expr)); |
692 | })) { |
693 | return MatchConvolutionResult::WrongInputIndexingMap; |
694 | } |
695 | |
696 | // Filter and output maps must be projected permutation. |
697 | if (!indexingMaps[1].isProjectedPermutation() || |
698 | !indexingMaps.back().isProjectedPermutation()) |
699 | return MatchConvolutionResult::NotProjectedPermutations; |
700 | |
701 | auto iteratorTypes = linalgOp.getIteratorTypesArray(); |
702 | |
703 | llvm::SmallDenseSet<int64_t> outputDims = |
704 | getPreservedDims(indexingMaps.back()); |
705 | llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]); |
706 | // Make sure all loops are characterized as one of: |
707 | // - Batch loop : present in output, as non-convolved in input, not present in |
708 | // filter. |
709 | // - Output image dimension : present in output, convolved dims in input, not |
710 | // present in filter. |
711 | // - Output channel dimension : present in output, not present in input, |
712 | // present in filter. |
713 | // - Filter loop dimension : present in filter, convolved in input, not |
714 | // present in output. |
715 | // - Input channel dimension : unconvolved in input, not present in output, |
716 | // present in filter. |
717 | // - Depth multiplier : unconvolved in input, present in output, present in |
718 | // filter. |
719 | llvm::SmallDenseSet<int64_t> allLoopDims; |
720 | for (auto outputExpr : indexingMaps.back().getResults()) { |
721 | int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition(); |
722 | if (inputExprWalker.unConvolvedDims.count(outputDim) && |
723 | !filterDims.count(outputDim)) { |
724 | // Batch dimension. |
725 | if (iteratorTypes[outputDim] != utils::IteratorType::parallel) |
726 | return MatchConvolutionResult::OutputDimsNotParallel; |
727 | allLoopDims.insert(outputDim); |
728 | continue; |
729 | } |
730 | if (inputExprWalker.convolvedDims.count(outputDim) && |
731 | !filterDims.count(outputDim)) { |
732 | // Output image Loop dimension. |
733 | if (iteratorTypes[outputDim] != utils::IteratorType::parallel) |
734 | return MatchConvolutionResult::OutputDimsNotParallel; |
735 | allLoopDims.insert(outputDim); |
736 | continue; |
737 | } |
738 | if (!inputExprWalker.convolvedDims.count(outputDim) && |
739 | !inputExprWalker.unConvolvedDims.count(outputDim) && |
740 | filterDims.count(outputDim)) { |
741 | // Output channel dimension. |
742 | if (iteratorTypes[outputDim] != utils::IteratorType::parallel) |
743 | return MatchConvolutionResult::OutputDimsNotParallel; |
744 | allLoopDims.insert(outputDim); |
745 | continue; |
746 | } |
747 | if (inputExprWalker.unConvolvedDims.count(outputDim) && |
748 | filterDims.count(outputDim)) { |
749 | // Depth multiplier. |
750 | if (iteratorTypes[outputDim] != utils::IteratorType::parallel) |
751 | return MatchConvolutionResult::OutputDimsNotParallel; |
752 | allLoopDims.insert(outputDim); |
753 | continue; |
754 | } |
755 | return MatchConvolutionResult::NonConvolutionLoop; |
756 | } |
757 | for (auto filterExpr : indexingMaps[1].getResults()) { |
758 | int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition(); |
759 | if (outputDims.count(filterDim) && |
760 | !inputExprWalker.unConvolvedDims.count(filterDim) && |
761 | !inputExprWalker.convolvedDims.count(filterDim)) { |
762 | // Output channel dimension. This is already seen, continue; |
763 | continue; |
764 | } |
765 | if (inputExprWalker.convolvedDims.count(filterDim) && |
766 | !outputDims.count(filterDim)) { |
767 | // Filter loop dimension. |
768 | if (iteratorTypes[filterDim] != utils::IteratorType::reduction) |
769 | return MatchConvolutionResult::NonOutputDimNotReduction; |
770 | if (allLoopDims.count(filterDim)) |
771 | return MatchConvolutionResult::NonConvolutionLoop; |
772 | allLoopDims.insert(filterDim); |
773 | continue; |
774 | } |
775 | if (inputExprWalker.unConvolvedDims.count(filterDim) && |
776 | !outputDims.count(filterDim)) { |
777 | // Input channel dimension. |
778 | if (iteratorTypes[filterDim] != utils::IteratorType::reduction) |
779 | return MatchConvolutionResult::NonOutputDimNotReduction; |
780 | if (allLoopDims.count(filterDim)) |
781 | return MatchConvolutionResult::NonConvolutionLoop; |
782 | allLoopDims.insert(filterDim); |
783 | continue; |
784 | } |
785 | if (inputExprWalker.unConvolvedDims.count(filterDim) && |
786 | outputDims.count(filterDim)) { |
787 | // Depthwise loop. Already seen. |
788 | continue; |
789 | } |
790 | return MatchConvolutionResult::NonConvolutionLoop; |
791 | } |
792 | // All loops must be covered now. |
793 | if (allLoopDims.size() != linalgOp.getNumLoops()) |
794 | return MatchConvolutionResult::NonConvolutionLoop; |
795 | |
796 | if (dimensions) { |
797 | FailureOr<ConvolutionDimensions> res = |
798 | inferConvolutionDimsImpl(linalgOp, inputExprWalker, |
799 | /*allowEmptyConvolvedDims=*/true); |
800 | assert(succeeded(res) && "unexpected failure to infer convolution dims" ); |
801 | *dimensions = *res; |
802 | } |
803 | |
804 | return MatchConvolutionResult::Success; |
805 | } |
806 | |
807 | StringRef |
808 | mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) { |
809 | switch (res) { |
810 | case MatchConvolutionResult::NotLinalgOp: |
811 | return "expected a LinalgOp" ; |
812 | case MatchConvolutionResult::WrongNumOperands: |
813 | return "expected op with 2 inputs and 1 output" ; |
814 | case MatchConvolutionResult::WrongInputIndexingMap: |
815 | return "unexpected input index map for convolutions" ; |
816 | case MatchConvolutionResult::NotProjectedPermutations: |
817 | return "expected output/filter indexing maps to be projected permutations" ; |
818 | case MatchConvolutionResult::NonConvolutionLoop: |
819 | return "unexpected loop dimension for convolution op" ; |
820 | case MatchConvolutionResult::OutputDimsNotParallel: |
821 | return "expected all iterators used to access outputs to be parallel" ; |
822 | case MatchConvolutionResult::NonOutputDimNotReduction: |
823 | return "expected all iterators not used to access outputs to be reduction" ; |
824 | case MatchConvolutionResult::Success: |
825 | return "" ; |
826 | } |
827 | llvm_unreachable("unhandled MatchConvolutionResult case" ); |
828 | } |
829 | |
830 | bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) { |
831 | return linalg::detail::isConvolutionInterfaceImpl(op: linalgOp.getOperation()) == |
832 | linalg::detail::MatchConvolutionResult::Success; |
833 | } |
834 | |
835 | LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { |
836 | MatchConvolutionResult res = isConvolutionInterfaceImpl(op); |
837 | if (res != MatchConvolutionResult::Success) |
838 | return op->emitError(message: getMatchConvolutionMessage(res)); |
839 | return success(); |
840 | } |
841 | |
842 | //===----------------------------------------------------------------------===// |
843 | // FillOpInterface implementation |
844 | //===----------------------------------------------------------------------===// |
845 | |
/// Outcome of structurally matching an operation against the fill interface.
enum class MatchFillResult {
  Success = 0,      // Valid fill: LinalgOp with 1 scalar input and 1 init.
  NotLinalgOp,      // The op does not implement LinalgOp.
  WrongNumOperands, // Not exactly 1 input and 1 init.
  NotScalarInput    // The single input operand is not a scalar.
};
852 | |
853 | static MatchFillResult isFillInterfaceImpl(Operation *op) { |
854 | auto linalgOp = dyn_cast<linalg::LinalgOp>(op); |
855 | if (!linalgOp) |
856 | return MatchFillResult::NotLinalgOp; |
857 | if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) |
858 | return MatchFillResult::WrongNumOperands; |
859 | |
860 | OpOperand *value = linalgOp.getDpsInputOperand(0); |
861 | if (!linalgOp.isScalar(value)) |
862 | return MatchFillResult::NotScalarInput; |
863 | |
864 | return MatchFillResult::Success; |
865 | } |
866 | |
867 | LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { |
868 | auto res = isFillInterfaceImpl(op); |
869 | if (res == MatchFillResult::NotLinalgOp) |
870 | return op->emitError(message: "expected a LinalgOp" ); |
871 | if (res == MatchFillResult::WrongNumOperands) |
872 | return op->emitError(message: "expected op with 1 input and 1 output" ); |
873 | if (res == MatchFillResult::NotScalarInput) |
874 | return op->emitError(message: "expected op with scalar input" ); |
875 | |
876 | return success(); |
877 | } |
878 | |
879 | //===----------------------------------------------------------------------===// |
880 | // StructuredOpInterface implementation |
881 | //===----------------------------------------------------------------------===// |
882 | |
883 | SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b, |
884 | Location loc) { |
885 | SmallVector<OpFoldResult> res; |
886 | for (OpOperand &opOperand : getOperation()->getOpOperands()) { |
887 | for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i) |
888 | res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i)); |
889 | } |
890 | return res; |
891 | } |
892 | |
893 | SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() { |
894 | SmallVector<int64_t, 4> res; |
895 | assert(!hasDynamicShape() && "expected operands to have static shapes" ); |
896 | for (OpOperand &opOperand : getOperation()->getOpOperands()) |
897 | llvm::append_range(res, getShape(&opOperand)); |
898 | return res; |
899 | } |
900 | |
901 | SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { |
902 | AffineMap map = getLoopsToShapesMap(); |
903 | unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); |
904 | auto viewSizes = createFlatListOfOperandDims(b, loc); |
905 | SmallVector<Range, 4> res(numDims); |
906 | for (unsigned idx = 0; idx < numRes; ++idx) { |
907 | auto result = map.getResult(idx); |
908 | if (auto d = dyn_cast<AffineDimExpr>(result)) { |
909 | if (res[d.getPosition()].offset) |
910 | continue; |
911 | res[d.getPosition()] = |
912 | Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)}; |
913 | } |
914 | } |
915 | return res; |
916 | } |
917 | |
918 | SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() { |
919 | AffineMap map = getLoopsToShapesMap(); |
920 | unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); |
921 | SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims(); |
922 | SmallVector<int64_t, 4> res(numDims, 0); |
923 | for (unsigned idx = 0; idx < numRes; ++idx) { |
924 | auto result = map.getResult(idx); |
925 | if (auto d = dyn_cast<AffineDimExpr>(result)) |
926 | res[d.getPosition()] = allShapeSizes[idx]; |
927 | } |
928 | return res; |
929 | } |
930 | |
931 | /// Visitor to check if any of the given set of positions from AffineDimExprs |
932 | /// are used within an AffineExpr. |
933 | struct HasAffineDimExprVisitor |
934 | : public AffineExprVisitor<HasAffineDimExprVisitor, bool> { |
935 | HasAffineDimExprVisitor(llvm::SmallBitVector positions) |
936 | : positions(std::move(positions)) {} |
937 | |
938 | bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) { |
939 | return visit(expr: binaryOpExpr.getLHS()) || visit(expr: binaryOpExpr.getRHS()); |
940 | } |
941 | |
942 | bool visitDimExpr(AffineDimExpr dimExpr) { |
943 | return positions.test(Idx: dimExpr.getPosition()); |
944 | } |
945 | |
946 | bool visitConstantExpr(AffineConstantExpr constExpr) { return false; } |
947 | |
948 | bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; } |
949 | |
950 | private: |
951 | llvm::SmallBitVector positions; |
952 | }; |
953 | |
954 | static std::pair<int64_t, int64_t> |
955 | getResultsPositionInLoopsToShapeMap(LinalgOp &op) { |
956 | int64_t inputRankSum = 0; |
957 | int64_t outputRankSum = 0; |
958 | for (OpOperand *input : op.getDpsInputOperands()) |
959 | inputRankSum += op.getRank(input); |
960 | for (OpOperand &output : op.getDpsInitsMutable()) |
961 | outputRankSum += op.getRank(&output); |
962 | return {inputRankSum, inputRankSum + outputRankSum}; |
963 | } |
964 | |
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  // Express the result shapes in terms of the flattened operand shape dims.
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  // Materialize every result-shape expression applied to the operands' dims.
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  // `pos` walks the flattened (init operand, dim) result positions in lockstep
  // with `shapeExprs` / `allResultDimValues`.
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value. If the shape expression still references
        // an output dimension (i.e. it cannot be expressed purely in terms of
        // the inputs), fall back to a dim op on the init operand itself.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}
1024 | |
1025 | /// Return the index in the indexingMaps vector that corresponds to this |
1026 | /// `opOperand`. |
1027 | int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) { |
1028 | auto operandNumber = opOperand->getOperandNumber(); |
1029 | auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation()); |
1030 | if (!dpsIface.isDpsInput(opOperand)) |
1031 | return operandNumber; |
1032 | unsigned start = dpsIface.getDpsInits().getBeginOperandIndex(); |
1033 | assert(!dpsIface.isDpsInit(opOperand)); |
1034 | // Account for potential inputs that are not DPS and may not appear in |
1035 | // `indexingMaps`. |
1036 | return cast<DestinationStyleOpInterface>(*this->getOperation()) |
1037 | .getNumDpsInputs() + |
1038 | operandNumber - start; |
1039 | } |
1040 | |
1041 | LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { |
1042 | LinalgOp linalgOp = cast<LinalgOp>(op); |
1043 | |
1044 | // Mixed tensor/buffer operands are not allowed. |
1045 | if (!linalgOp.hasPureTensorSemantics() && |
1046 | !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0) |
1047 | return op->emitOpError(message: "expected to have pure tensor or buffer semantics" ); |
1048 | |
1049 | // Before checking indexing maps, we need to make sure the attributes |
1050 | // referenced by it are valid. |
1051 | if (linalgOp.hasDynamicIndexingMaps()) |
1052 | if (failed(linalgOp.verifyIndexingMapRequiredAttributes())) |
1053 | return failure(); |
1054 | |
1055 | // All input/output operands must be indexed. |
1056 | if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) != |
1057 | linalgOp->getNumOperands()) |
1058 | return op->emitOpError(message: "expected the number of indexing_map (" ) |
1059 | << linalgOp.getIndexingMapsArray().size() |
1060 | << ") to be equal to the number of input/output operands (" |
1061 | << linalgOp->getNumOperands() << ")" ; |
1062 | |
1063 | for (OpOperand &opOperand : linalgOp->getOpOperands()) { |
1064 | AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); |
1065 | |
1066 | // Symbols disallowed. |
1067 | if (indexingMap.getNumSymbols() != 0) |
1068 | return op->emitOpError("unexpected symbols in indexing_map #" ) |
1069 | << opOperand.getOperandNumber(); |
1070 | |
1071 | // Domain must be consistent. |
1072 | unsigned numLoops = linalgOp.getNumLoops(); |
1073 | if (indexingMap.getNumDims() != numLoops) |
1074 | return op->emitOpError("expected indexing_map #" ) |
1075 | << opOperand.getOperandNumber() << " to have " << numLoops |
1076 | << " dim(s) to match the number of loops" ; |
1077 | |
1078 | int64_t rank = linalgOp.getRank(&opOperand); |
1079 | if (indexingMap.getNumResults() != rank) |
1080 | return op->emitOpError("expected operand rank (" ) |
1081 | << rank << ") to match the result rank of indexing_map #" |
1082 | << opOperand.getOperandNumber() << " (" |
1083 | << indexingMap.getNumResults() << ")" ; |
1084 | } |
1085 | |
1086 | SmallVector<unsigned> redDims; |
1087 | linalgOp.getReductionDims(redDims); |
1088 | |
1089 | if (!linalgOp.getShapesToLoopsMap()) |
1090 | return op->emitOpError(message: "expected the shape-to-loops map to be non-null" ); |
1091 | |
1092 | // Check if given shapes match to inferred shapes. |
1093 | SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges(); |
1094 | SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0); |
1095 | |
1096 | // Verify only static cases since we can't get exact dimension sizes and loop |
1097 | // ranges for dynamic cases in this stage. |
1098 | if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { |
1099 | for (int64_t &range : endLoopRangeValues) |
1100 | range -= 1; |
1101 | for (OpOperand &opOperand : linalgOp->getOpOperands()) { |
1102 | AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); |
1103 | SmallVector<int64_t, 4> startIndices = |
1104 | indexingMap.compose(startLoopRangeValues); |
1105 | SmallVector<int64_t, 4> endIndices = |
1106 | indexingMap.compose(endLoopRangeValues); |
1107 | ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand); |
1108 | for (auto dim : llvm::seq<int64_t>(0, shape.size())) { |
1109 | // Ignore dynamic dimension or the case that the dimension size is 0 |
1110 | if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) |
1111 | continue; |
1112 | |
1113 | // The first index or last index should be the maximum or the minimum in |
1114 | // the inferred index ranges since the range is increasing or |
1115 | // decreasing. The size of dimensions of input/output operands and the |
1116 | // maximum value + 1 in the inferred range should be the same. But, for |
1117 | // now we check if the inferred ranges are in boundary of input/output |
1118 | // operands' size or not in case that Affine Expressions are complicated |
1119 | // such as d0 * 3 |
1120 | // + d1 since it is not easy to handle the issues. |
1121 | // Found the case that this solution can't check, for example, (d0, d1) |
1122 | // -> (d1 - d0) |
1123 | int64_t inferredDimSize = |
1124 | std::max(startIndices[dim], endIndices[dim]) + 1; |
1125 | if (std::min(startIndices[dim], endIndices[dim]) < 0) { |
1126 | std::string mapStr; |
1127 | { |
1128 | llvm::raw_string_ostream os(mapStr); |
1129 | os << indexingMap; |
1130 | } |
1131 | return op->emitOpError( |
1132 | "unexpected result less than 0 at expression #" ) |
1133 | << dim << " in " << mapStr; |
1134 | } |
1135 | if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) { |
1136 | if (inferredDimSize != shape[dim]) { |
1137 | return op->emitOpError("inferred input/output operand #" ) |
1138 | << opOperand.getOperandNumber() << " has shape's dimension #" |
1139 | << dim << " to be " << inferredDimSize << ", but found " |
1140 | << shape[dim]; |
1141 | } |
1142 | } else { |
1143 | if (inferredDimSize > shape[dim]) { |
1144 | return op->emitOpError("inferred input/output operand #" ) |
1145 | << opOperand.getOperandNumber() << " has shape's dimension #" |
1146 | << dim << " to be greater than or equal to " |
1147 | << inferredDimSize << ", but found " << shape[dim]; |
1148 | } |
1149 | } |
1150 | } |
1151 | } |
1152 | } |
1153 | |
1154 | // Check the region has exactly one block. |
1155 | if (linalgOp->getNumRegions() != 1 || |
1156 | !llvm::hasSingleElement(linalgOp->getRegion(0))) |
1157 | return op->emitOpError(message: "expects to have 1 region with 1 block" ); |
1158 | |
1159 | // Simplifying assumption: bbargs match 1-1 with shape operands elemental |
1160 | // types. |
1161 | // TODO: once ranked shape types are plugged in, we may want to drop the |
1162 | // corresponding bbargs, that can never be read from. This will be subject to |
1163 | // consistency discussions (i.e. what to do with output tensors whose bbarg is |
1164 | // not used). |
1165 | Block &block = linalgOp->getRegion(0).front(); |
1166 | |
1167 | if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments()) |
1168 | return op->emitOpError(message: "expected as many non-induction variable region " |
1169 | "arguments as the number of input/output operands" ); |
1170 | |
1171 | for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { |
1172 | Type elementType = opOperand->get().getType(); |
1173 | if (isa<MemRefType, RankedTensorType>(elementType)) |
1174 | elementType = getElementTypeOrSelf(opOperand->get().getType()); |
1175 | Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); |
1176 | if (elementType != argType) |
1177 | return op->emitOpError("expected type of bb argument #" ) |
1178 | << opOperand->getOperandNumber() << " (" << argType << ")" |
1179 | << " to match element or self type of the corresponding operand (" |
1180 | << elementType << ")" ; |
1181 | } |
1182 | |
1183 | return success(); |
1184 | } |
1185 | |