1 | //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
12 | #include "mlir/Dialect/UB/IR/UBOps.h" |
13 | #include "mlir/IR/AffineExprVisitor.h" |
14 | #include "mlir/IR/IRMapping.h" |
15 | #include "mlir/IR/IntegerSet.h" |
16 | #include "mlir/IR/Matchers.h" |
17 | #include "mlir/IR/OpDefinition.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/Interfaces/ShapedOpInterfaces.h" |
20 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
21 | #include "mlir/Support/MathExtras.h" |
22 | #include "mlir/Transforms/InliningUtils.h" |
23 | #include "llvm/ADT/ScopeExit.h" |
24 | #include "llvm/ADT/SmallBitVector.h" |
25 | #include "llvm/ADT/SmallVectorExtras.h" |
26 | #include "llvm/ADT/TypeSwitch.h" |
27 | #include "llvm/Support/Debug.h" |
28 | #include <numeric> |
29 | #include <optional> |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::affine; |
33 | |
34 | #define DEBUG_TYPE "affine-ops" |
35 | |
36 | #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc" |
37 | |
38 | /// A utility function to check if a value is defined at the top level of |
39 | /// `region` or is an argument of `region`. A value of index type defined at the |
40 | /// top level of a `AffineScope` region is always a valid symbol for all |
41 | /// uses in that region. |
42 | bool mlir::affine::isTopLevelValue(Value value, Region *region) { |
43 | if (auto arg = llvm::dyn_cast<BlockArgument>(value)) |
44 | return arg.getParentRegion() == region; |
45 | return value.getDefiningOp()->getParentRegion() == region; |
46 | } |
47 | |
48 | /// Checks if `value` known to be a legal affine dimension or symbol in `src` |
49 | /// region remains legal if the operation that uses it is inlined into `dest` |
50 | /// with the given value mapping. `legalityCheck` is either `isValidDim` or |
51 | /// `isValidSymbol`, depending on the value being required to remain a valid |
52 | /// dimension or symbol. |
53 | static bool |
54 | remainsLegalAfterInline(Value value, Region *src, Region *dest, |
55 | const IRMapping &mapping, |
56 | function_ref<bool(Value, Region *)> legalityCheck) { |
57 | // If the value is a valid dimension for any other reason than being |
58 | // a top-level value, it will remain valid: constants get inlined |
59 | // with the function, transitive affine applies also get inlined and |
60 | // will be checked themselves, etc. |
61 | if (!isTopLevelValue(value, region: src)) |
62 | return true; |
63 | |
64 | // If it's a top-level value because it's a block operand, i.e. a |
65 | // function argument, check whether the value replacing it after |
66 | // inlining is a valid dimension in the new region. |
67 | if (llvm::isa<BlockArgument>(Val: value)) |
68 | return legalityCheck(mapping.lookup(from: value), dest); |
69 | |
70 | // If it's a top-level value because it's defined in the region, |
71 | // it can only be inlined if the defining op is a constant or a |
72 | // `dim`, which can appear anywhere and be valid, since the defining |
73 | // op won't be top-level anymore after inlining. |
74 | Attribute operandCst; |
75 | bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp()); |
76 | return matchPattern(op: value.getDefiningOp(), pattern: m_Constant(bind_value: &operandCst)) || |
77 | isDimLikeOp; |
78 | } |
79 | |
80 | /// Checks if all values known to be legal affine dimensions or symbols in `src` |
81 | /// remain so if their respective users are inlined into `dest`. |
82 | static bool |
83 | remainsLegalAfterInline(ValueRange values, Region *src, Region *dest, |
84 | const IRMapping &mapping, |
85 | function_ref<bool(Value, Region *)> legalityCheck) { |
86 | return llvm::all_of(Range&: values, P: [&](Value v) { |
87 | return remainsLegalAfterInline(value: v, src, dest, mapping, legalityCheck); |
88 | }); |
89 | } |
90 | |
91 | /// Checks if an affine read or write operation remains legal after inlining |
92 | /// from `src` to `dest`. |
93 | template <typename OpTy> |
94 | static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest, |
95 | const IRMapping &mapping) { |
96 | static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface, |
97 | AffineWriteOpInterface>::value, |
98 | "only ops with affine read/write interface are supported" ); |
99 | |
100 | AffineMap map = op.getAffineMap(); |
101 | ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims()); |
102 | ValueRange symbolOperands = |
103 | op.getMapOperands().take_back(map.getNumSymbols()); |
104 | if (!remainsLegalAfterInline( |
105 | values: dimOperands, src, dest, mapping, |
106 | legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidDim))) |
107 | return false; |
108 | if (!remainsLegalAfterInline( |
109 | values: symbolOperands, src, dest, mapping, |
110 | legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidSymbol))) |
111 | return false; |
112 | return true; |
113 | } |
114 | |
115 | /// Checks if an affine apply operation remains legal after inlining from `src` |
116 | /// to `dest`. |
117 | // Use "unused attribute" marker to silence clang-tidy warning stemming from |
118 | // the inability to see through "llvm::TypeSwitch". |
119 | template <> |
120 | bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op, |
121 | Region *src, Region *dest, |
122 | const IRMapping &mapping) { |
123 | // If it's a valid dimension, we need to check that it remains so. |
124 | if (isValidDim(op.getResult(), src)) |
125 | return remainsLegalAfterInline( |
126 | op.getMapOperands(), src, dest, mapping, |
127 | static_cast<bool (*)(Value, Region *)>(isValidDim)); |
128 | |
129 | // Otherwise it must be a valid symbol, check that it remains so. |
130 | return remainsLegalAfterInline( |
131 | op.getMapOperands(), src, dest, mapping, |
132 | static_cast<bool (*)(Value, Region *)>(isValidSymbol)); |
133 | } |
134 | |
135 | //===----------------------------------------------------------------------===// |
136 | // AffineDialect Interfaces |
137 | //===----------------------------------------------------------------------===// |
138 | |
139 | namespace { |
140 | /// This class defines the interface for handling inlining with affine |
141 | /// operations. |
142 | struct AffineInlinerInterface : public DialectInlinerInterface { |
143 | using DialectInlinerInterface::DialectInlinerInterface; |
144 | |
145 | //===--------------------------------------------------------------------===// |
146 | // Analysis Hooks |
147 | //===--------------------------------------------------------------------===// |
148 | |
149 | /// Returns true if the given region 'src' can be inlined into the region |
150 | /// 'dest' that is attached to an operation registered to the current dialect. |
151 | /// 'wouldBeCloned' is set if the region is cloned into its new location |
152 | /// rather than moved, indicating there may be other users. |
153 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
154 | IRMapping &valueMapping) const final { |
155 | // We can inline into affine loops and conditionals if this doesn't break |
156 | // affine value categorization rules. |
157 | Operation *destOp = dest->getParentOp(); |
158 | if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp)) |
159 | return false; |
160 | |
161 | // Multi-block regions cannot be inlined into affine constructs, all of |
162 | // which require single-block regions. |
163 | if (!llvm::hasSingleElement(C&: *src)) |
164 | return false; |
165 | |
166 | // Side-effecting operations that the affine dialect cannot understand |
167 | // should not be inlined. |
168 | Block &srcBlock = src->front(); |
169 | for (Operation &op : srcBlock) { |
170 | // Ops with no side effects are fine, |
171 | if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) { |
172 | if (iface.hasNoEffect()) |
173 | continue; |
174 | } |
175 | |
176 | // Assuming the inlined region is valid, we only need to check if the |
177 | // inlining would change it. |
178 | bool remainsValid = |
179 | llvm::TypeSwitch<Operation *, bool>(&op) |
180 | .Case<AffineApplyOp, AffineReadOpInterface, |
181 | AffineWriteOpInterface>([&](auto op) { |
182 | return remainsLegalAfterInline(op, src, dest, valueMapping); |
183 | }) |
184 | .Default([](Operation *) { |
185 | // Conservatively disallow inlining ops we cannot reason about. |
186 | return false; |
187 | }); |
188 | |
189 | if (!remainsValid) |
190 | return false; |
191 | } |
192 | |
193 | return true; |
194 | } |
195 | |
196 | /// Returns true if the given operation 'op', that is registered to this |
197 | /// dialect, can be inlined into the given region, false otherwise. |
198 | bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, |
199 | IRMapping &valueMapping) const final { |
200 | // Always allow inlining affine operations into a region that is marked as |
201 | // affine scope, or into affine loops and conditionals. There are some edge |
202 | // cases when inlining *into* affine structures, but that is handled in the |
203 | // other 'isLegalToInline' hook above. |
204 | Operation *parentOp = region->getParentOp(); |
205 | return parentOp->hasTrait<OpTrait::AffineScope>() || |
206 | isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp); |
207 | } |
208 | |
209 | /// Affine regions should be analyzed recursively. |
210 | bool shouldAnalyzeRecursively(Operation *op) const final { return true; } |
211 | }; |
212 | } // namespace |
213 | |
214 | //===----------------------------------------------------------------------===// |
215 | // AffineDialect |
216 | //===----------------------------------------------------------------------===// |
217 | |
/// Registers all affine dialect operations (the ODS-generated op list plus the
/// hand-written DMA ops), the dialect's inliner interface, and interfaces that
/// are declared as promised (their implementations are registered elsewhere).
void AffineDialect::initialize() {
  // AffineDmaStartOp/AffineDmaWaitOp are not ODS-generated, so they are listed
  // explicitly ahead of the generated op list.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface>();
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
                            AffineMinOp>();
}
227 | |
228 | /// Materialize a single constant operation from a given attribute value with |
229 | /// the desired resultant type. |
230 | Operation *AffineDialect::materializeConstant(OpBuilder &builder, |
231 | Attribute value, Type type, |
232 | Location loc) { |
233 | if (auto poison = dyn_cast<ub::PoisonAttr>(value)) |
234 | return builder.create<ub::PoisonOp>(loc, type, poison); |
235 | return arith::ConstantOp::materialize(builder, value, type, loc); |
236 | } |
237 | |
238 | /// A utility function to check if a value is defined at the top level of an |
239 | /// op with trait `AffineScope`. If the value is defined in an unlinked region, |
240 | /// conservatively assume it is not top-level. A value of index type defined at |
241 | /// the top level is always a valid symbol. |
242 | bool mlir::affine::isTopLevelValue(Value value) { |
243 | if (auto arg = llvm::dyn_cast<BlockArgument>(value)) { |
244 | // The block owning the argument may be unlinked, e.g. when the surrounding |
245 | // region has not yet been attached to an Op, at which point the parent Op |
246 | // is null. |
247 | Operation *parentOp = arg.getOwner()->getParentOp(); |
248 | return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); |
249 | } |
250 | // The defining Op may live in an unlinked block so its parent Op may be null. |
251 | Operation *parentOp = value.getDefiningOp()->getParentOp(); |
252 | return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); |
253 | } |
254 | |
255 | /// Returns the closest region enclosing `op` that is held by an operation with |
256 | /// trait `AffineScope`; `nullptr` if there is no such region. |
257 | Region *mlir::affine::getAffineScope(Operation *op) { |
258 | auto *curOp = op; |
259 | while (auto *parentOp = curOp->getParentOp()) { |
260 | if (parentOp->hasTrait<OpTrait::AffineScope>()) |
261 | return curOp->getParentRegion(); |
262 | curOp = parentOp; |
263 | } |
264 | return nullptr; |
265 | } |
266 | |
267 | // A Value can be used as a dimension id iff it meets one of the following |
268 | // conditions: |
269 | // *) It is valid as a symbol. |
270 | // *) It is an induction variable. |
271 | // *) It is the result of affine apply operation with dimension id arguments. |
272 | bool mlir::affine::isValidDim(Value value) { |
273 | // The value must be an index type. |
274 | if (!value.getType().isIndex()) |
275 | return false; |
276 | |
277 | if (auto *defOp = value.getDefiningOp()) |
278 | return isValidDim(value, region: getAffineScope(op: defOp)); |
279 | |
280 | // This value has to be a block argument for an op that has the |
281 | // `AffineScope` trait or for an affine.for or affine.parallel. |
282 | auto *parentOp = llvm::cast<BlockArgument>(Val&: value).getOwner()->getParentOp(); |
283 | return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() || |
284 | isa<AffineForOp, AffineParallelOp>(parentOp)); |
285 | } |
286 | |
287 | // Value can be used as a dimension id iff it meets one of the following |
288 | // conditions: |
289 | // *) It is valid as a symbol. |
290 | // *) It is an induction variable. |
291 | // *) It is the result of an affine apply operation with dimension id operands. |
292 | bool mlir::affine::isValidDim(Value value, Region *region) { |
293 | // The value must be an index type. |
294 | if (!value.getType().isIndex()) |
295 | return false; |
296 | |
297 | // All valid symbols are okay. |
298 | if (isValidSymbol(value, region)) |
299 | return true; |
300 | |
301 | auto *op = value.getDefiningOp(); |
302 | if (!op) { |
303 | // This value has to be a block argument for an affine.for or an |
304 | // affine.parallel. |
305 | auto *parentOp = llvm::cast<BlockArgument>(Val&: value).getOwner()->getParentOp(); |
306 | return isa<AffineForOp, AffineParallelOp>(parentOp); |
307 | } |
308 | |
309 | // Affine apply operation is ok if all of its operands are ok. |
310 | if (auto applyOp = dyn_cast<AffineApplyOp>(op)) |
311 | return applyOp.isValidDim(region); |
312 | // The dim op is okay if its operand memref/tensor is defined at the top |
313 | // level. |
314 | if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op)) |
315 | return isTopLevelValue(dimOp.getShapedValue()); |
316 | return false; |
317 | } |
318 | |
319 | /// Returns true if the 'index' dimension of the `memref` defined by |
320 | /// `memrefDefOp` is a statically shaped one or defined using a valid symbol |
321 | /// for `region`. |
322 | template <typename AnyMemRefDefOp> |
323 | static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, |
324 | Region *region) { |
325 | MemRefType memRefType = memrefDefOp.getType(); |
326 | |
327 | // Dimension index is out of bounds. |
328 | if (index >= memRefType.getRank()) { |
329 | return false; |
330 | } |
331 | |
332 | // Statically shaped. |
333 | if (!memRefType.isDynamicDim(index)) |
334 | return true; |
335 | // Get the position of the dimension among dynamic dimensions; |
336 | unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); |
337 | return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), |
338 | region); |
339 | } |
340 | |
341 | /// Returns true if the result of the dim op is a valid symbol for `region`. |
342 | static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) { |
343 | // The dim op is okay if its source is defined at the top level. |
344 | if (isTopLevelValue(dimOp.getShapedValue())) |
345 | return true; |
346 | |
347 | // Conservatively handle remaining BlockArguments as non-valid symbols. |
348 | // E.g. scf.for iterArgs. |
349 | if (llvm::isa<BlockArgument>(dimOp.getShapedValue())) |
350 | return false; |
351 | |
352 | // The dim op is also okay if its operand memref is a view/subview whose |
353 | // corresponding size is a valid symbol. |
354 | std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension()); |
355 | |
356 | // Be conservative if we can't understand the dimension. |
357 | if (!index.has_value()) |
358 | return false; |
359 | |
360 | // Skip over all memref.cast ops (if any). |
361 | Operation *op = dimOp.getShapedValue().getDefiningOp(); |
362 | while (auto castOp = dyn_cast<memref::CastOp>(op)) { |
363 | // Bail on unranked memrefs. |
364 | if (isa<UnrankedMemRefType>(castOp.getSource().getType())) |
365 | return false; |
366 | op = castOp.getSource().getDefiningOp(); |
367 | if (!op) |
368 | return false; |
369 | } |
370 | |
371 | int64_t i = index.value(); |
372 | return TypeSwitch<Operation *, bool>(op) |
373 | .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>( |
374 | [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); }) |
375 | .Default([](Operation *) { return false; }); |
376 | } |
377 | |
378 | // A value can be used as a symbol (at all its use sites) iff it meets one of |
379 | // the following conditions: |
380 | // *) It is a constant. |
381 | // *) Its defining op or block arg appearance is immediately enclosed by an op |
382 | // with `AffineScope` trait. |
383 | // *) It is the result of an affine.apply operation with symbol operands. |
384 | // *) It is a result of the dim op on a memref whose corresponding size is a |
385 | // valid symbol. |
386 | bool mlir::affine::isValidSymbol(Value value) { |
387 | if (!value) |
388 | return false; |
389 | |
390 | // The value must be an index type. |
391 | if (!value.getType().isIndex()) |
392 | return false; |
393 | |
394 | // Check that the value is a top level value. |
395 | if (isTopLevelValue(value)) |
396 | return true; |
397 | |
398 | if (auto *defOp = value.getDefiningOp()) |
399 | return isValidSymbol(value, region: getAffineScope(op: defOp)); |
400 | |
401 | return false; |
402 | } |
403 | |
404 | /// A value can be used as a symbol for `region` iff it meets one of the |
405 | /// following conditions: |
406 | /// *) It is a constant. |
407 | /// *) It is the result of an affine apply operation with symbol arguments. |
408 | /// *) It is a result of the dim op on a memref whose corresponding size is |
409 | /// a valid symbol. |
410 | /// *) It is defined at the top level of 'region' or is its argument. |
411 | /// *) It dominates `region`'s parent op. |
412 | /// If `region` is null, conservatively assume the symbol definition scope does |
413 | /// not exist and only accept the values that would be symbols regardless of |
414 | /// the surrounding region structure, i.e. the first three cases above. |
415 | bool mlir::affine::isValidSymbol(Value value, Region *region) { |
416 | // The value must be an index type. |
417 | if (!value.getType().isIndex()) |
418 | return false; |
419 | |
420 | // A top-level value is a valid symbol. |
421 | if (region && ::isTopLevelValue(value, region)) |
422 | return true; |
423 | |
424 | auto *defOp = value.getDefiningOp(); |
425 | if (!defOp) { |
426 | // A block argument that is not a top-level value is a valid symbol if it |
427 | // dominates region's parent op. |
428 | Operation *regionOp = region ? region->getParentOp() : nullptr; |
429 | if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
430 | if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) |
431 | return isValidSymbol(value, region: parentOpRegion); |
432 | return false; |
433 | } |
434 | |
435 | // Constant operation is ok. |
436 | Attribute operandCst; |
437 | if (matchPattern(op: defOp, pattern: m_Constant(bind_value: &operandCst))) |
438 | return true; |
439 | |
440 | // Affine apply operation is ok if all of its operands are ok. |
441 | if (auto applyOp = dyn_cast<AffineApplyOp>(defOp)) |
442 | return applyOp.isValidSymbol(region); |
443 | |
444 | // Dim op results could be valid symbols at any level. |
445 | if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp)) |
446 | return isDimOpValidSymbol(dimOp, region); |
447 | |
448 | // Check for values dominating `region`'s parent op. |
449 | Operation *regionOp = region ? region->getParentOp() : nullptr; |
450 | if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
451 | if (auto *parentRegion = region->getParentOp()->getParentRegion()) |
452 | return isValidSymbol(value, region: parentRegion); |
453 | |
454 | return false; |
455 | } |
456 | |
457 | // Returns true if 'value' is a valid index to an affine operation (e.g. |
458 | // affine.load, affine.store, affine.dma_start, affine.dma_wait) where |
459 | // `region` provides the polyhedral symbol scope. Returns false otherwise. |
460 | static bool isValidAffineIndexOperand(Value value, Region *region) { |
461 | return isValidDim(value, region) || isValidSymbol(value, region); |
462 | } |
463 | |
464 | /// Prints dimension and symbol list. |
465 | static void printDimAndSymbolList(Operation::operand_iterator begin, |
466 | Operation::operand_iterator end, |
467 | unsigned numDims, OpAsmPrinter &printer) { |
468 | OperandRange operands(begin, end); |
469 | printer << '(' << operands.take_front(n: numDims) << ')'; |
470 | if (operands.size() > numDims) |
471 | printer << '[' << operands.drop_front(n: numDims) << ']'; |
472 | } |
473 | |
474 | /// Parses dimension and symbol list and returns true if parsing failed. |
475 | ParseResult mlir::affine::parseDimAndSymbolList( |
476 | OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) { |
477 | SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos; |
478 | if (parser.parseOperandList(result&: opInfos, delimiter: OpAsmParser::Delimiter::Paren)) |
479 | return failure(); |
480 | // Store number of dimensions for validation by caller. |
481 | numDims = opInfos.size(); |
482 | |
483 | // Parse the optional symbol operands. |
484 | auto indexTy = parser.getBuilder().getIndexType(); |
485 | return failure(parser.parseOperandList( |
486 | result&: opInfos, delimiter: OpAsmParser::Delimiter::OptionalSquare) || |
487 | parser.resolveOperands(opInfos, indexTy, operands)); |
488 | } |
489 | |
490 | /// Utility function to verify that a set of operands are valid dimension and |
491 | /// symbol identifiers. The operands should be laid out such that the dimension |
492 | /// operands are before the symbol operands. This function returns failure if |
493 | /// there was an invalid operand. An operation is provided to emit any necessary |
494 | /// errors. |
495 | template <typename OpTy> |
496 | static LogicalResult |
497 | verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, |
498 | unsigned numDims) { |
499 | unsigned opIt = 0; |
500 | for (auto operand : operands) { |
501 | if (opIt++ < numDims) { |
502 | if (!isValidDim(operand, getAffineScope(op))) |
503 | return op.emitOpError("operand cannot be used as a dimension id" ); |
504 | } else if (!isValidSymbol(operand, getAffineScope(op))) { |
505 | return op.emitOpError("operand cannot be used as a symbol" ); |
506 | } |
507 | } |
508 | return success(); |
509 | } |
510 | |
511 | //===----------------------------------------------------------------------===// |
512 | // AffineApplyOp |
513 | //===----------------------------------------------------------------------===// |
514 | |
/// Returns this operation as an AffineValueMap: the affine map together with
/// its operands and the single result.
AffineValueMap AffineApplyOp::getAffineValueMap() {
  return AffineValueMap(getAffineMap(), getOperands(), getResult());
}
518 | |
519 | ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) { |
520 | auto &builder = parser.getBuilder(); |
521 | auto indexTy = builder.getIndexType(); |
522 | |
523 | AffineMapAttr mapAttr; |
524 | unsigned numDims; |
525 | if (parser.parseAttribute(mapAttr, "map" , result.attributes) || |
526 | parseDimAndSymbolList(parser, result.operands, numDims) || |
527 | parser.parseOptionalAttrDict(result.attributes)) |
528 | return failure(); |
529 | auto map = mapAttr.getValue(); |
530 | |
531 | if (map.getNumDims() != numDims || |
532 | numDims + map.getNumSymbols() != result.operands.size()) { |
533 | return parser.emitError(parser.getNameLoc(), |
534 | "dimension or symbol index mismatch" ); |
535 | } |
536 | |
537 | result.types.append(map.getNumResults(), indexTy); |
538 | return success(); |
539 | } |
540 | |
/// Prints `affine.apply` as ` <map> (dims)[symbols] <attr-dict>`. The `map`
/// attribute is printed inline and therefore elided from the dictionary.
void AffineApplyOp::print(OpAsmPrinter &p) {
  p << " " << getMapAttr();
  printDimAndSymbolList(operand_begin(), operand_end(),
                        getAffineMap().getNumDims(), p);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
}
547 | |
548 | LogicalResult AffineApplyOp::verify() { |
549 | // Check input and output dimensions match. |
550 | AffineMap affineMap = getMap(); |
551 | |
552 | // Verify that operand count matches affine map dimension and symbol count. |
553 | if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols()) |
554 | return emitOpError( |
555 | "operand count and affine map dimension and symbol count must match" ); |
556 | |
557 | // Verify that the map only produces one result. |
558 | if (affineMap.getNumResults() != 1) |
559 | return emitOpError("mapping must produce one value" ); |
560 | |
561 | return success(); |
562 | } |
563 | |
564 | // The result of the affine apply operation can be used as a dimension id if all |
565 | // its operands are valid dimension ids. |
566 | bool AffineApplyOp::isValidDim() { |
567 | return llvm::all_of(getOperands(), |
568 | [](Value op) { return affine::isValidDim(op); }); |
569 | } |
570 | |
571 | // The result of the affine apply operation can be used as a dimension id if all |
572 | // its operands are valid dimension ids with the parent operation of `region` |
573 | // defining the polyhedral scope for symbols. |
574 | bool AffineApplyOp::isValidDim(Region *region) { |
575 | return llvm::all_of(getOperands(), |
576 | [&](Value op) { return ::isValidDim(op, region); }); |
577 | } |
578 | |
579 | // The result of the affine apply operation can be used as a symbol if all its |
580 | // operands are symbols. |
581 | bool AffineApplyOp::isValidSymbol() { |
582 | return llvm::all_of(getOperands(), |
583 | [](Value op) { return affine::isValidSymbol(op); }); |
584 | } |
585 | |
586 | // The result of the affine apply operation can be used as a symbol in `region` |
587 | // if all its operands are symbols in `region`. |
588 | bool AffineApplyOp::isValidSymbol(Region *region) { |
589 | return llvm::all_of(getOperands(), [&](Value operand) { |
590 | return affine::isValidSymbol(operand, region); |
591 | }); |
592 | } |
593 | |
594 | OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) { |
595 | auto map = getAffineMap(); |
596 | |
597 | // Fold dims and symbols to existing values. |
598 | auto expr = map.getResult(0); |
599 | if (auto dim = dyn_cast<AffineDimExpr>(expr)) |
600 | return getOperand(dim.getPosition()); |
601 | if (auto sym = dyn_cast<AffineSymbolExpr>(expr)) |
602 | return getOperand(map.getNumDims() + sym.getPosition()); |
603 | |
604 | // Otherwise, default to folding the map. |
605 | SmallVector<Attribute, 1> result; |
606 | bool hasPoison = false; |
607 | auto foldResult = |
608 | map.constantFold(adaptor.getMapOperands(), result, &hasPoison); |
609 | if (hasPoison) |
610 | return ub::PoisonAttr::get(getContext()); |
611 | if (failed(foldResult)) |
612 | return {}; |
613 | return result[0]; |
614 | } |
615 | |
616 | /// Returns the largest known divisor of `e`. Exploits information from the |
617 | /// values in `operands`. |
618 | static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) { |
619 | // This method isn't aware of `operands`. |
620 | int64_t div = e.getLargestKnownDivisor(); |
621 | |
622 | // We now make use of operands for the case `e` is a dim expression. |
623 | // TODO: More powerful simplification would have to modify |
624 | // getLargestKnownDivisor to take `operands` and exploit that information as |
625 | // well for dim/sym expressions, but in that case, getLargestKnownDivisor |
626 | // can't be part of the IR library but of the `Analysis` library. The IR |
627 | // library can only really depend on simple O(1) checks. |
628 | auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e); |
629 | // If it's not a dim expr, `div` is the best we have. |
630 | if (!dimExpr) |
631 | return div; |
632 | |
633 | // We simply exploit information from loop IVs. |
634 | // We don't need to use mlir::getLargestKnownDivisorOfValue since the other |
635 | // desired simplifications are expected to be part of other |
636 | // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the |
637 | // LoopAnalysis library. |
638 | Value operand = operands[dimExpr.getPosition()]; |
639 | int64_t operandDivisor = 1; |
640 | // TODO: With the right accessors, this can be extended to |
641 | // LoopLikeOpInterface. |
642 | if (AffineForOp forOp = getForInductionVarOwner(operand)) { |
643 | if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) { |
644 | operandDivisor = forOp.getStepAsInt(); |
645 | } else { |
646 | uint64_t lbLargestKnownDivisor = |
647 | forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs(); |
648 | operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt()); |
649 | } |
650 | } |
651 | return operandDivisor; |
652 | } |
653 | |
654 | /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e` |
655 | /// being an affine dim expression or a constant. |
656 | static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands, |
657 | int64_t k) { |
658 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) { |
659 | int64_t constVal = constExpr.getValue(); |
660 | return constVal >= 0 && constVal < k; |
661 | } |
662 | auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e); |
663 | if (!dimExpr) |
664 | return false; |
665 | Value operand = operands[dimExpr.getPosition()]; |
666 | // TODO: With the right accessors, this can be extended to |
667 | // LoopLikeOpInterface. |
668 | if (AffineForOp forOp = getForInductionVarOwner(operand)) { |
669 | if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 && |
670 | forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) { |
671 | return true; |
672 | } |
673 | } |
674 | |
675 | // We don't consider other cases like `operand` being defined by a constant or |
676 | // an affine.apply op since such cases will already be handled by other |
677 | // patterns and propagation of loop IVs or constant would happen. |
678 | return false; |
679 | } |
680 | |
681 | /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d. |
682 | /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the |
683 | /// expression is in that form. |
684 | static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div, |
685 | AffineExpr "ientTimesDiv, AffineExpr &rem) { |
686 | auto bin = dyn_cast<AffineBinaryOpExpr>(Val&: e); |
687 | if (!bin || bin.getKind() != AffineExprKind::Add) |
688 | return false; |
689 | |
690 | AffineExpr llhs = bin.getLHS(); |
691 | AffineExpr rlhs = bin.getRHS(); |
692 | div = getLargestKnownDivisor(e: llhs, operands); |
693 | if (isNonNegativeBoundedBy(e: rlhs, operands, k: div)) { |
694 | quotientTimesDiv = llhs; |
695 | rem = rlhs; |
696 | return true; |
697 | } |
698 | div = getLargestKnownDivisor(e: rlhs, operands); |
699 | if (isNonNegativeBoundedBy(e: llhs, operands, k: div)) { |
700 | quotientTimesDiv = rlhs; |
701 | rem = llhs; |
702 | return true; |
703 | } |
704 | return false; |
705 | } |
706 | |
707 | /// Gets the constant lower bound on an `iv`. |
708 | static std::optional<int64_t> getLowerBound(Value iv) { |
709 | AffineForOp forOp = getForInductionVarOwner(iv); |
710 | if (forOp && forOp.hasConstantLowerBound()) |
711 | return forOp.getConstantLowerBound(); |
712 | return std::nullopt; |
713 | } |
714 | |
715 | /// Gets the constant upper bound on an affine.for `iv`. |
716 | static std::optional<int64_t> getUpperBound(Value iv) { |
717 | AffineForOp forOp = getForInductionVarOwner(iv); |
718 | if (!forOp || !forOp.hasConstantUpperBound()) |
719 | return std::nullopt; |
720 | |
721 | // If its lower bound is also known, we can get a more precise bound |
722 | // whenever the step is not one. |
723 | if (forOp.hasConstantLowerBound()) { |
724 | return forOp.getConstantUpperBound() - 1 - |
725 | (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) % |
726 | forOp.getStepAsInt(); |
727 | } |
728 | return forOp.getConstantUpperBound() - 1; |
729 | } |
730 | |
731 | /// Determine a constant upper bound for `expr` if one exists while exploiting |
732 | /// values in `operands`. Note that the upper bound is an inclusive one. `expr` |
733 | /// is guaranteed to be less than or equal to it. |
734 | static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims, |
735 | unsigned numSymbols, |
736 | ArrayRef<Value> operands) { |
737 | // Get the constant lower or upper bounds on the operands. |
738 | SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds; |
739 | constLowerBounds.reserve(N: operands.size()); |
740 | constUpperBounds.reserve(N: operands.size()); |
741 | for (Value operand : operands) { |
742 | constLowerBounds.push_back(Elt: getLowerBound(iv: operand)); |
743 | constUpperBounds.push_back(Elt: getUpperBound(iv: operand)); |
744 | } |
745 | |
746 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) |
747 | return constExpr.getValue(); |
748 | |
749 | return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds, |
750 | constUpperBounds, |
751 | /*isUpper=*/true); |
752 | } |
753 | |
754 | /// Determine a constant lower bound for `expr` if one exists while exploiting |
755 | /// values in `operands`. Note that the upper bound is an inclusive one. `expr` |
756 | /// is guaranteed to be less than or equal to it. |
757 | static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims, |
758 | unsigned numSymbols, |
759 | ArrayRef<Value> operands) { |
760 | // Get the constant lower or upper bounds on the operands. |
761 | SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds; |
762 | constLowerBounds.reserve(N: operands.size()); |
763 | constUpperBounds.reserve(N: operands.size()); |
764 | for (Value operand : operands) { |
765 | constLowerBounds.push_back(Elt: getLowerBound(iv: operand)); |
766 | constUpperBounds.push_back(Elt: getUpperBound(iv: operand)); |
767 | } |
768 | |
769 | std::optional<int64_t> lowerBound; |
770 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) { |
771 | lowerBound = constExpr.getValue(); |
772 | } else { |
773 | lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols, |
774 | constLowerBounds, constUpperBounds, |
775 | /*isUpper=*/false); |
776 | } |
777 | return lowerBound; |
778 | } |
779 | |
/// Simplify `expr` while exploiting information from the values in `operands`.
/// Only floordiv/ceildiv/mod expressions with a positive constant RHS are
/// simplified, using constant bounds derived from affine.for IV ranges among
/// `operands`. `expr` is rewritten in place; operands are not modified.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first (recursive, bottom-up).
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(expr&: lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(expr&: rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(kind: binExpr.getKind(), lhs, rhs);

  binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(expr: lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(expr: lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value if lhs is bounded in a range `c` that
    // has the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        floorDiv(lhs: lhsLbConstVal, rhs: rhsConstVal) ==
            floorDiv(lhs: lhsUbConstVal, rhs: rhsConstVal)) {
      expr =
          getAffineConstantExpr(constant: floorDiv(lhs: lhsLbConstVal, rhs: rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        ceilDiv(lhs: lhsLbConstVal, rhs: rhsConstVal) ==
            ceilDiv(lhs: lhsUbConstVal, rhs: rhsConstVal)) {
      expr =
          getAffineConstantExpr(constant: ceilDiv(lhs: lhsLbConstVal, rhs: rhsConstVal), context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(e: lhs, operands, div&: divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(other: rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs % c = 0.
  if ((isNonNegativeBoundedBy(e: lhs, operands, k: rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(e: lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(constant: 0, context: expr.getContext());
  }
}
878 | |
/// Simplify the expressions in `map` while making use of lower or upper bounds
/// of its operands. If `isMax` is true, the map is to be treated as a max of
/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
/// d1) can be simplified to (8) if the operands are respectively lower bounded
/// by 2 and 0 (the second expression can't be lower than 8).
static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
                                             ArrayRef<Value> operands,
                                             bool isMax) {
  // Can't simplify.
  if (operands.empty())
    return;

  // Get the upper or lower bound on an affine.for op IV using its range.
  // Get the constant lower or upper bounds on the operands.
  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
  constLowerBounds.reserve(N: operands.size());
  constUpperBounds.reserve(N: operands.size());
  for (Value operand : operands) {
    constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
    constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
  }

  // We will compute the lower and upper bounds on each of the expressions
  // Then, we will check (depending on max or min) as to whether a specific
  // bound is redundant by checking if its highest (in case of max) and its
  // lowest (in the case of min) value is already lower than (or higher than)
  // the lower bound (or upper bound in the case of min) of another bound.
  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
  lowerBounds.reserve(N: map.getNumResults());
  upperBounds.reserve(N: map.getNumResults());
  for (AffineExpr e : map.getResults()) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) {
      lowerBounds.push_back(Elt: constExpr.getValue());
      upperBounds.push_back(Elt: constExpr.getValue());
    } else {
      lowerBounds.push_back(
          Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
                                 constLowerBounds, constUpperBounds,
                                 /*isUpper=*/false));
      upperBounds.push_back(
          Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
                                 constLowerBounds, constUpperBounds,
                                 /*isUpper=*/true));
    }
  }

  // Collect expressions that are not redundant.
  SmallVector<AffineExpr, 4> irredundantExprs;
  for (auto exprEn : llvm::enumerate(First: map.getResults())) {
    AffineExpr e = exprEn.value();
    unsigned i = exprEn.index();
    // Some expressions can be turned into constants.
    if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
      e = getAffineConstantExpr(constant: *lowerBounds[i], context: e.getContext());

    // Check if the expression is redundant.
    if (isMax) {
      // An expression with no known upper bound can never be proven redundant.
      if (!upperBounds[i]) {
        irredundantExprs.push_back(Elt: e);
        continue;
      }
      // If there exists another expression such that its lower bound is greater
      // than this expression's upper bound, it's redundant.
      if (!llvm::any_of(Range: llvm::enumerate(First&: lowerBounds), P: [&](const auto &en) {
            auto otherLowerBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherLowerBound)
              return false;
            if (*otherLowerBound > *upperBounds[i])
              return true;
            if (*otherLowerBound < *upperBounds[i])
              return false;
            // Equality case. When both expressions are considered redundant, we
            // don't want to get both of them. We keep the one that appears
            // first.
            if (upperBounds[pos] && lowerBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherLowerBound == *upperBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(Elt: e);
    } else {
      // An expression with no known lower bound can never be proven redundant.
      if (!lowerBounds[i]) {
        irredundantExprs.push_back(Elt: e);
        continue;
      }
      // Likewise for the `min` case. Use the complement of the condition above.
      if (!llvm::any_of(Range: llvm::enumerate(First&: upperBounds), P: [&](const auto &en) {
            auto otherUpperBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherUpperBound)
              return false;
            if (*otherUpperBound < *lowerBounds[i])
              return true;
            if (*otherUpperBound > *lowerBounds[i])
              return false;
            if (lowerBounds[pos] && upperBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherUpperBound == lowerBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(Elt: e);
    }
  }

  // Create the map without the redundant expressions.
  map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: irredundantExprs,
                       context: map.getContext());
}
990 | |
991 | /// Simplify the map while exploiting information on the values in `operands`. |
992 | // Use "unused attribute" marker to silence warning stemming from the inability |
993 | // to see through the template expansion. |
994 | static void LLVM_ATTRIBUTE_UNUSED |
995 | simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) { |
996 | assert(map.getNumInputs() == operands.size() && "invalid operands for map" ); |
997 | SmallVector<AffineExpr> newResults; |
998 | newResults.reserve(N: map.getNumResults()); |
999 | for (AffineExpr expr : map.getResults()) { |
1000 | simplifyExprAndOperands(expr, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(), |
1001 | operands); |
1002 | newResults.push_back(Elt: expr); |
1003 | } |
1004 | map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newResults, |
1005 | context: map.getContext()); |
1006 | } |
1007 | |
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// Returns failure when the operand at `dimOrSymbolPosition` is null or is not
/// produced by an affine.apply op (i.e. nothing to compose).
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms) {
  MLIRContext *ctx = map->getContext();
  // Positions [0, dims.size()) index `dims`; the rest index `syms`.
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (!v)
    return failure();

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results" );
  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
                                     affineApply.getMapOperands().end());
  // Canonicalize the map to promote dims to symbols when possible. This is to
  // avoid generating invalid maps.
  canonicalizeMapAndOperands(map: &composeMap, operands: &composeOperands);
  // Shift the composed map's dims/syms past the existing ones so the appended
  // operands below line up with the expression's positions.
  AffineExpr replacementExpr =
      composeMap.shiftDims(shift: dims.size()).shiftSymbols(shift: syms.size()).getResult(idx: 0);
  ValueRange composeDims =
      ArrayRef<Value>(composeOperands).take_front(N: composeMap.getNumDims());
  ValueRange composeSyms =
      ArrayRef<Value>(composeOperands).take_back(N: composeMap.getNumSymbols());
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(position: pos, context: ctx)
                                          : getAffineSymbolExpr(position: pos, context: ctx);

  // Append the dims and symbols where relevant and perform the replacement.
  dims.append(in_start: composeDims.begin(), in_end: composeDims.end());
  syms.append(in_start: composeSyms.begin(), in_end: composeSyms.end());
  *map = map->replace(expr: toReplace, replacement: replacementExpr, numResultDims: dims.size(), numResultSyms: syms.size());

  return success();
}
1062 | |
/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
/// iteratively. Perform canonicalization of map and operands as well as
/// AffineMap simplification. `map` and `operands` are mutated in place.
static void composeAffineMapAndOperands(AffineMap *map,
                                        SmallVectorImpl<Value> *operands) {
  // Zero-result maps have nothing to compose; just canonicalize and simplify.
  if (map->getNumResults() == 0) {
    canonicalizeMapAndOperands(map, operands);
    *map = simplifyAffineMap(map: *map);
    return;
  }

  MLIRContext *ctx = map->getContext();
  // Split the operand list into dim operands and symbol operands, matching the
  // map's (dims..., symbols...) input convention.
  SmallVector<Value, 4> dims(operands->begin(),
                             operands->begin() + map->getNumDims());
  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
                             operands->end());

  // Iterate over dims and symbols coming from AffineApplyOp and replace until
  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
  // and `syms` can only increase by construction.
  // The implementation uses a `while` loop to support the case of symbols
  // that may be constructed from dims ;this may be overkill.
  while (true) {
    bool changed = false;
    for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
      if ((changed |= succeeded(result: replaceDimOrSym(map, dimOrSymbolPosition: pos, dims, syms))))
        break;
    if (!changed)
      break;
  }

  // Clear operands so we can fill them anew.
  operands->clear();

  // At this point we may have introduced null operands, prune them out before
  // canonicalizing map and operands.
  unsigned nDims = 0, nSyms = 0;
  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
  dimReplacements.reserve(N: dims.size());
  symReplacements.reserve(N: syms.size());
  for (auto *container : {&dims, &syms}) {
    bool isDim = (container == &dims);
    auto &repls = isDim ? dimReplacements : symReplacements;
    for (const auto &en : llvm::enumerate(First&: *container)) {
      Value v = en.value();
      if (!v) {
        // Nulled-out entries must no longer be referenced by the map; replace
        // the corresponding expression with a harmless constant 0.
        assert(isDim ? !map->isFunctionOfDim(en.index())
                     : !map->isFunctionOfSymbol(en.index()) &&
                           "map is function of unexpected expr@pos" );
        repls.push_back(Elt: getAffineConstantExpr(constant: 0, context: ctx));
        continue;
      }
      repls.push_back(Elt: isDim ? getAffineDimExpr(position: nDims++, context: ctx)
                             : getAffineSymbolExpr(position: nSyms++, context: ctx));
      operands->push_back(Elt: v);
    }
  }
  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, numResultDims: nDims,
                                    numResultSyms: nSyms);

  // Canonicalize and simplify before returning.
  canonicalizeMapAndOperands(map, operands);
  *map = simplifyAffineMap(map: *map);
}
1127 | |
1128 | void mlir::affine::fullyComposeAffineMapAndOperands( |
1129 | AffineMap *map, SmallVectorImpl<Value> *operands) { |
1130 | while (llvm::any_of(Range&: *operands, P: [](Value v) { |
1131 | return isa_and_nonnull<AffineApplyOp>(Val: v.getDefiningOp()); |
1132 | })) { |
1133 | composeAffineMapAndOperands(map, operands); |
1134 | } |
1135 | } |
1136 | |
1137 | AffineApplyOp |
1138 | mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, |
1139 | ArrayRef<OpFoldResult> operands) { |
1140 | SmallVector<Value> valueOperands; |
1141 | map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands); |
1142 | composeAffineMapAndOperands(map: &map, operands: &valueOperands); |
1143 | assert(map); |
1144 | return b.create<AffineApplyOp>(loc, map, valueOperands); |
1145 | } |
1146 | |
1147 | AffineApplyOp |
1148 | mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e, |
1149 | ArrayRef<OpFoldResult> operands) { |
1150 | return makeComposedAffineApply( |
1151 | b, loc, |
1152 | AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{e}, context: b.getContext()) |
1153 | .front(), |
1154 | operands); |
1155 | } |
1156 | |
/// Composes the given affine map with the given list of operands, pulling in
/// the maps from any affine.apply operations that supply the operands.
/// Works for multi-result maps (unlike composeAffineMapAndOperands' single
/// result composition step) by composing each result separately and merging
/// the per-result operand lists. `map` and `operands` are mutated in place.
static void composeMultiResultAffineMap(AffineMap &map,
                                        SmallVectorImpl<Value> &operands) {
  // Compose and canonicalize each expression in the map individually because
  // composition only applies to single-result maps, collecting potentially
  // duplicate operands in a single list with shifted dimensions and symbols.
  SmallVector<Value> dims, symbols;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(Begin: 0, End: map.getNumResults())) {
    // Compose result `i` in isolation with a copy of the operands.
    SmallVector<Value> submapOperands(operands.begin(), operands.end());
    AffineMap submap = map.getSubMap(resultPos: {i});
    fullyComposeAffineMapAndOperands(map: &submap, operands: &submapOperands);
    canonicalizeMapAndOperands(map: &submap, operands: &submapOperands);
    // Shift this submap's dims/syms past those accumulated so far so the
    // combined operand list stays aligned with the expressions.
    unsigned numNewDims = submap.getNumDims();
    submap = submap.shiftDims(shift: dims.size()).shiftSymbols(shift: symbols.size());
    llvm::append_range(C&: dims,
                       R: ArrayRef<Value>(submapOperands).take_front(N: numNewDims));
    llvm::append_range(C&: symbols,
                       R: ArrayRef<Value>(submapOperands).drop_front(N: numNewDims));
    exprs.push_back(Elt: submap.getResult(idx: 0));
  }

  // Canonicalize the map created from composed expressions to deduplicate the
  // dimension and symbol operands.
  operands = llvm::to_vector(Range: llvm::concat<Value>(Ranges&: dims, Ranges&: symbols));
  map = AffineMap::get(dimCount: dims.size(), symbolCount: symbols.size(), results: exprs, context: map.getContext());
  canonicalizeMapAndOperands(map: &map, operands: &operands);
}
1186 | |
// Builds a composed affine.apply and immediately tries to constant-fold it:
// returns an Attribute when folding succeeds (the temporary op is erased), or
// the materialized op's result otherwise.
OpFoldResult
mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<OpFoldResult> operands) {
  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result" );

  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());

  // Create op.
  AffineApplyOp applyOp =
      makeComposedAffineApply(newBuilder, loc, map, operands);

  // Get constant operands.
  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(applyOp->getOperand(i), m_Constant(bind_value: &constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(applyOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
    // Folding failed: the op survives, so notify the original builder's
    // listener now.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(op: applyOp, /*previous=*/{});
    return applyOp.getResult();
  }

  // Folding succeeded: the temporary op is no longer needed.
  applyOp->erase();
  assert(foldResults.size() == 1 && "expected 1 folded result" );
  return foldResults.front();
}
1222 | |
1223 | OpFoldResult |
1224 | mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc, |
1225 | AffineExpr expr, |
1226 | ArrayRef<OpFoldResult> operands) { |
1227 | return makeComposedFoldedAffineApply( |
1228 | b, loc, |
1229 | map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{expr}, context: b.getContext()) |
1230 | .front(), |
1231 | operands); |
1232 | } |
1233 | |
1234 | SmallVector<OpFoldResult> |
1235 | mlir::affine::makeComposedFoldedMultiResultAffineApply( |
1236 | OpBuilder &b, Location loc, AffineMap map, |
1237 | ArrayRef<OpFoldResult> operands) { |
1238 | return llvm::map_to_vector(C: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults()), |
1239 | F: [&](unsigned i) { |
1240 | return makeComposedFoldedAffineApply( |
1241 | b, loc, map: map.getSubMap(resultPos: {i}), operands); |
1242 | }); |
1243 | } |
1244 | |
1245 | template <typename OpTy> |
1246 | static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, |
1247 | ArrayRef<OpFoldResult> operands) { |
1248 | SmallVector<Value> valueOperands; |
1249 | map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands); |
1250 | composeMultiResultAffineMap(map, operands&: valueOperands); |
1251 | return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands); |
1252 | } |
1253 | |
/// Creates an affine.min whose map and operands are the composition of `map`
/// with any affine.apply ops feeding `operands`.
AffineMinOp
mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
                                    ArrayRef<OpFoldResult> operands) {
  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
}
1259 | |
/// Shared implementation for makeComposedFoldedAffineMin/Max: materializes a
/// composed min/max op with a listener-free builder, then attempts to
/// constant-fold it, erasing the temporary op when folding succeeds.
template <typename OpTy>
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ArrayRef<OpFoldResult> operands) {
  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());

  // Create op.
  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);

  // Get constant operands.
  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(minMaxOp->getOperand(i), m_Constant(bind_value: &constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
    // Folding failed: the op survives, so notify the caller's listener now.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(op: minMaxOp, /*previous=*/{});
    return minMaxOp.getResult();
  }

  // Folding succeeded: discard the temporary op.
  minMaxOp->erase();
  assert(foldResults.size() == 1 && "expected 1 folded result" );
  return foldResults.front();
}
1292 | |
/// Builds a composed affine.min and immediately tries to fold it; returns an
/// Attribute when folding succeeds, or the op's result otherwise.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
}
1299 | |
/// Builds a composed affine.max and immediately tries to fold it; returns an
/// Attribute when folding succeeds, or the op's result otherwise.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
}
1306 | |
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols. The promoted
// operands are moved to the back of the operand list (after the dim operands),
// matching the map's (dims..., symbols...) convention.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands" );

  auto *context = mapOrSet->getContext();
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(N: operands->size());
  // Operands that were dims but get promoted to symbols; appended at the end.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(N: operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol(value: (*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        // New symbols are numbered after the pre-existing ones.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back(Elt: (*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back(Elt: (*operands)[i]);
      }
    } else {
      // Pre-existing symbol operands are kept as-is.
      resultOperands.push_back(Elt: (*operands)[i]);
    }
  }

  resultOperands.append(in_start: remappedSymbols.begin(), in_end: remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands" );
}
1350 | |
// Works for either an affine map or an integer set. Canonicalization consists
// of: promoting dims that are valid symbols, dropping unused dims/symbols,
// deduplicating repeated operands, and folding constant symbol operands into
// the map/set.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type" );

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands" );

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(Val&: expr))
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(N: operands->size());

  // Deduplicate dim operands: unused dims are simply dropped (no remapping
  // entry needed since the map/set never references them).
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find(Val: (*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back(Elt: (*operands)[i]);
        seenDims.insert(KV: std::make_pair(x&: (*operands)[i], y&: dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back(Elt: (*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1428 | |
/// Canonicalizes `*map` and `*operands` in place (drops unused/duplicate
/// operands, folds constant symbols); thin wrapper over the map/set-generic
/// helper.
void mlir::affine::canonicalizeMapAndOperands(
    AffineMap *map, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(mapOrSet: map, operands);
}
1433 | |
/// Canonicalizes `*set` and `*operands` in place; thin wrapper over the
/// map/set-generic helper.
void mlir::affine::canonicalizeSetAndOperands(
    IntegerSet *set, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(mapOrSet: set, operands);
}
1438 | |
1439 | namespace { |
1440 | /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing |
1441 | /// maps that supply results into them. |
1442 | /// |
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef<Value> mapOperands) const;

  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(
        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
        "expected" );
    // Compose any producing affine maps into this op's map, then canonicalize
    // and simplify the composed map and its operand list.
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector<Value, 8> resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    canonicalizeMapAndOperands(&map, &resultOperands);
    simplifyMapWithOperands(map, resultOperands);
    // Bail out if nothing changed, so the pattern doesn't loop forever.
    // (When map == oldMap the operand counts are equal, so comparing against
    // resultOperands.begin() is in bounds.)
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, mapOperands: resultOperands);
    return success();
  }
};
1475 | |
1476 | // Specialize the template to account for the different build signatures for |
1477 | // affine load, store, and apply ops. |
1478 | template <> |
1479 | void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp( |
1480 | PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, |
1481 | ArrayRef<Value> mapOperands) const { |
1482 | rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map, |
1483 | mapOperands); |
1484 | } |
1485 | template <> |
1486 | void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp( |
1487 | PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, |
1488 | ArrayRef<Value> mapOperands) const { |
1489 | rewriter.replaceOpWithNewOp<AffinePrefetchOp>( |
1490 | prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(), |
1491 | prefetch.getLocalityHint(), prefetch.getIsDataCache()); |
1492 | } |
1493 | template <> |
1494 | void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp( |
1495 | PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, |
1496 | ArrayRef<Value> mapOperands) const { |
1497 | rewriter.replaceOpWithNewOp<AffineStoreOp>( |
1498 | store, store.getValueToStore(), store.getMemRef(), map, mapOperands); |
1499 | } |
1500 | template <> |
1501 | void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp( |
1502 | PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map, |
1503 | ArrayRef<Value> mapOperands) const { |
1504 | rewriter.replaceOpWithNewOp<AffineVectorLoadOp>( |
1505 | vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map, |
1506 | mapOperands); |
1507 | } |
1508 | template <> |
1509 | void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp( |
1510 | PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map, |
1511 | ArrayRef<Value> mapOperands) const { |
1512 | rewriter.replaceOpWithNewOp<AffineVectorStoreOp>( |
1513 | vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map, |
1514 | mapOperands); |
1515 | } |
1516 | |
// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // Covers the remaining ops (e.g. apply/min/max) whose build signature takes
  // only a map and its operands.
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
1524 | } // namespace |
1525 | |
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Registers map composition/canonicalization for affine.apply.
  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
1530 | |
1531 | //===----------------------------------------------------------------------===// |
1532 | // AffineDmaStartOp |
1533 | //===----------------------------------------------------------------------===// |
1534 | |
1535 | // TODO: Check that map operands are loop IVs or symbols. |
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  // Operand order is significant and must match the op's accessors:
  // src memref, src indices, dst memref, dst indices, tag memref, tag
  // indices, numElements, then the optional (stride, elementsPerStride) pair.
  result.addOperands(newOperands: srcMemRef);
  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
  result.addOperands(newOperands: srcIndices);
  result.addOperands(newOperands: destMemRef);
  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
  result.addOperands(newOperands: destIndices);
  result.addOperands(newOperands: tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(newOperands: tagIndices);
  result.addOperands(newOperands: numElements);
  // Stride info is optional; when present both values are appended together.
  if (stride) {
    result.addOperands(newOperands: {stride, elementsPerStride});
  }
}
1557 | |
1558 | void AffineDmaStartOp::print(OpAsmPrinter &p) { |
1559 | p << " " << getSrcMemRef() << '['; |
1560 | p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); |
1561 | p << "], " << getDstMemRef() << '['; |
1562 | p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); |
1563 | p << "], " << getTagMemRef() << '['; |
1564 | p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); |
1565 | p << "], " << getNumElements(); |
1566 | if (isStrided()) { |
1567 | p << ", " << getStride(); |
1568 | p << ", " << getNumElementsPerStride(); |
1569 | } |
1570 | p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " |
1571 | << getTagMemRefType(); |
1572 | } |
1573 | |
1574 | // Parse AffineDmaStartOp. |
1575 | // Ex: |
1576 | // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, |
1577 | // %stride, %num_elt_per_stride |
1578 | // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> |
1579 | // |
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
  OpAsmParser::UnresolvedOperand tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) dst memref followed by its affine maps operands (in square brackets).
  // *) src memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  if (parser.parseOperand(result&: srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(operands&: srcMapOperands, map&: srcMapAttr,
                                    attrName: getSrcMapAttrStrName(),
                                    attrs&: result.attributes) ||
      parser.parseComma() || parser.parseOperand(result&: dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(operands&: dstMapOperands, map&: dstMapAttr,
                                    attrName: getDstMapAttrStrName(),
                                    attrs&: result.attributes) ||
      parser.parseComma() || parser.parseOperand(result&: tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(operands&: tagMapOperands, map&: tagMapAttr,
                                    attrName: getTagMapAttrStrName(),
                                    attrs&: result.attributes) ||
      parser.parseComma() || parser.parseOperand(result&: numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(result&: strideInfo))
    return failure();

  // Strides come as an all-or-nothing pair.
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "expected two stride related operands" );
  }
  bool isStrided = strideInfo.size() == 2;

  // Trailing types are exactly: src, dst, and tag memref types.
  if (parser.parseColonTypeList(result&: types))
    return failure();

  if (types.size() != 3)
    return parser.emitError(loc: parser.getNameLoc(), message: "expected three types" );

  // Resolution order mirrors the operand order laid down by build().
  if (parser.resolveOperand(operand: srcMemRefInfo, type: types[0], result&: result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(operand: dstMemRefInfo, type: types[1], result&: result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(operand: tagMemRefInfo, type: types[2], result&: result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(operand: numElementsInfo, type: indexType, result&: result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "memref operand count not equal to map.numInputs" );
  return success();
}
1655 | |
1656 | LogicalResult AffineDmaStartOp::verifyInvariantsImpl() { |
1657 | if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType())) |
1658 | return emitOpError("expected DMA source to be of memref type" ); |
1659 | if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType())) |
1660 | return emitOpError("expected DMA destination to be of memref type" ); |
1661 | if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType())) |
1662 | return emitOpError("expected DMA tag to be of memref type" ); |
1663 | |
1664 | unsigned numInputsAllMaps = getSrcMap().getNumInputs() + |
1665 | getDstMap().getNumInputs() + |
1666 | getTagMap().getNumInputs(); |
1667 | if (getNumOperands() != numInputsAllMaps + 3 + 1 && |
1668 | getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { |
1669 | return emitOpError("incorrect number of operands" ); |
1670 | } |
1671 | |
1672 | Region *scope = getAffineScope(*this); |
1673 | for (auto idx : getSrcIndices()) { |
1674 | if (!idx.getType().isIndex()) |
1675 | return emitOpError("src index to dma_start must have 'index' type" ); |
1676 | if (!isValidAffineIndexOperand(idx, scope)) |
1677 | return emitOpError( |
1678 | "src index must be a valid dimension or symbol identifier" ); |
1679 | } |
1680 | for (auto idx : getDstIndices()) { |
1681 | if (!idx.getType().isIndex()) |
1682 | return emitOpError("dst index to dma_start must have 'index' type" ); |
1683 | if (!isValidAffineIndexOperand(idx, scope)) |
1684 | return emitOpError( |
1685 | "dst index must be a valid dimension or symbol identifier" ); |
1686 | } |
1687 | for (auto idx : getTagIndices()) { |
1688 | if (!idx.getType().isIndex()) |
1689 | return emitOpError("tag index to dma_start must have 'index' type" ); |
1690 | if (!isValidAffineIndexOperand(idx, scope)) |
1691 | return emitOpError( |
1692 | "tag index must be a valid dimension or symbol identifier" ); |
1693 | } |
1694 | return success(); |
1695 | } |
1696 | |
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  // Folds in place by replacing memref.cast-produced operands with the cast
  // sources; returns success iff anything changed.
  return memref::foldMemRefCast(*this);
}
1702 | |
1703 | void AffineDmaStartOp::getEffects( |
1704 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
1705 | &effects) { |
1706 | effects.emplace_back(Args: MemoryEffects::Read::get(), Args: getSrcMemRef(), |
1707 | Args: SideEffects::DefaultResource::get()); |
1708 | effects.emplace_back(Args: MemoryEffects::Write::get(), Args: getDstMemRef(), |
1709 | Args: SideEffects::DefaultResource::get()); |
1710 | effects.emplace_back(Args: MemoryEffects::Read::get(), Args: getTagMemRef(), |
1711 | Args: SideEffects::DefaultResource::get()); |
1712 | } |
1713 | |
1714 | //===----------------------------------------------------------------------===// |
1715 | // AffineDmaWaitOp |
1716 | //===----------------------------------------------------------------------===// |
1717 | |
1718 | // TODO: Check that map operands are loop IVs or symbols. |
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  // Operand order must match the accessors: tag memref, tag indices, then
  // the number of elements.
  result.addOperands(newOperands: tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(newOperands: tagIndices);
  result.addOperands(newOperands: numElements);
}
1727 | |
1728 | void AffineDmaWaitOp::print(OpAsmPrinter &p) { |
1729 | p << " " << getTagMemRef() << '['; |
1730 | SmallVector<Value, 2> operands(getTagIndices()); |
1731 | p.printAffineMapOfSSAIds(getTagMapAttr(), operands); |
1732 | p << "], " ; |
1733 | p.printOperand(value: getNumElements()); |
1734 | p << " : " << getTagMemRef().getType(); |
1735 | } |
1736 | |
1737 | // Parse AffineDmaWaitOp. |
1738 | // Eg: |
1739 | // affine.dma_wait %tag[%index], %num_elements |
1740 | // : memref<1 x i32, (d0) -> (d0), 4> |
1741 | // |
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::UnresolvedOperand tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::UnresolvedOperand numElementsInfo;

  // Parse tag memref, its map operands, and dma size.
  if (parser.parseOperand(result&: tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(operands&: tagMapOperands, map&: tagMapAttr,
                                    attrName: getTagMapAttrStrName(),
                                    attrs&: result.attributes) ||
      parser.parseComma() || parser.parseOperand(result&: numElementsInfo) ||
      parser.parseColonType(result&: type) ||
      parser.resolveOperand(operand: tagMemRefInfo, type, result&: result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(operand: numElementsInfo, type: indexType, result&: result.operands))
    return failure();

  // The trailing type annotates the tag operand and must be a memref type.
  if (!llvm::isa<MemRefType>(Val: type))
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "expected tag to be of memref type" );

  // The tag map must consume exactly the parsed index operands.
  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "tag memref operand count != to map.numInputs" );
  return success();
}
1772 | |
LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
  // Operand 0 is the tag memref; it must have memref type.
  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
    return emitOpError("expected DMA tag to be of memref type" );
  Region *scope = getAffineScope(*this);
  // Each tag index must have 'index' type and be a valid affine dim/symbol
  // within the enclosing affine scope.
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type" );
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError(
          "index must be a valid dimension or symbol identifier" );
  }
  return success();
}
1786 | |
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  // Folds in place by replacing memref.cast-produced operands with the cast
  // sources; returns success iff anything changed.
  return memref::foldMemRefCast(*this);
}
1792 | |
void AffineDmaWaitOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // The wait only reads the tag memref.
  effects.emplace_back(Args: MemoryEffects::Read::get(), Args: getTagMemRef(),
                       Args: SideEffects::DefaultResource::get());
}
1799 | |
1800 | //===----------------------------------------------------------------------===// |
1801 | // AffineForOp |
1802 | //===----------------------------------------------------------------------===// |
1803 | |
/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
/// bodyBuilder are empty/null, we include default terminator op.
void AffineForOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map" );
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map" );
  assert(step > 0 && "step has to be a positive integer constant" );

  OpBuilder::InsertionGuard guard(builder);

  // Set variadic segment sizes.
  result.addAttribute(
      getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
                                    static_cast<int32_t>(ubOperands.size()),
                                    static_cast<int32_t>(iterArgs.size())}));

  // The loop results mirror the iter_args types one-to-one.
  for (Value val : iterArgs)
    result.addTypes(val.getType());

  // Add an attribute for the step.
  result.addAttribute(getStepAttrName(result.name),
                      builder.getIntegerAttr(builder.getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundMapAttrName(result.name),
                      AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundMapAttrName(result.name),
                      AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  result.addOperands(iterArgs);
  // Create a region and a block for the body. The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion);
  Value inductionVar =
      bodyBlock->addArgument(builder.getIndexType(), result.location);
  // Block arguments after the IV carry the loop-carried values.
  for (Value val : iterArgs)
    bodyBlock->addArgument(val.getType(), val.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilder(builder, result.location, inductionVar,
                bodyBlock->getArguments().drop_front());
  }
}
1866 | |
1867 | void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb, |
1868 | int64_t ub, int64_t step, ValueRange iterArgs, |
1869 | BodyBuilderFn bodyBuilder) { |
1870 | auto lbMap = AffineMap::getConstantMap(lb, builder.getContext()); |
1871 | auto ubMap = AffineMap::getConstantMap(ub, builder.getContext()); |
1872 | return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs, |
1873 | bodyBuilder); |
1874 | } |
1875 | |
LogicalResult AffineForOp::verifyRegions() {
  // Check that the body defines as single block argument for the induction
  // variable.
  auto *body = getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return emitOpError("expected body to have a single index argument for the "
                       "induction variable" );

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bound.
  if (getLowerBoundMap().getNumInputs() > 0)
    if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
                                             getLowerBoundMap().getNumDims())))
      return failure();
  /// Upper bound.
  if (getUpperBoundMap().getNumInputs() > 0)
    if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
                                             getUpperBoundMap().getNumDims())))
      return failure();

  // Loops without results need no iter-arg consistency checks.
  unsigned opNumResults = getNumResults();
  if (opNumResults == 0)
    return success();

  // If ForOp defines values, check that the number and types of the defined
  // values match ForOp initial iter operands and backedge basic block
  // arguments.
  if (getNumIterOperands() != opNumResults)
    return emitOpError(
        "mismatch between the number of loop-carried values and results" );
  if (getNumRegionIterArgs() != opNumResults)
    return emitOpError(
        "mismatch between the number of basic block args and results" );

  return success();
}
1912 | |
/// Parse a for operation loop bounds.
///
/// Accepts three syntactic forms: a single SSA value (treated as a
/// single-symbol identity map), a full affine map attribute followed by its
/// dim-and-symbol operand list, or a bare integer constant.
static ParseResult parseBound(bool isLower, OperationState &result,
                              OpAsmParser &p) {
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  // the map has multiple results.
  bool failedToParsedMinMax =
      failed(result: p.parseOptionalKeyword(keyword: isLower ? "max" : "min" ));

  auto &builder = p.getBuilder();
  auto boundAttrStrName =
      isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
              : AffineForOp::getUpperBoundMapAttrName(result.name);

  // Parse ssa-id as identity map.
  SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
  if (p.parseOperandList(result&: boundOpInfos))
    return failure();

  if (!boundOpInfos.empty()) {
    // Check that only one operand was parsed.
    if (boundOpInfos.size() > 1)
      return p.emitError(loc: p.getNameLoc(),
                         message: "expected only one loop bound operand" );

    // TODO: improve error message when SSA value is not of index type.
    // Currently it is 'use of value ... expects different type than prior uses'
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
                         result.operands))
      return failure();

    // Create an identity map using symbol id. This representation is optimized
    // for storage. Analysis passes may expand it into a multi-dimensional map
    // if desired.
    AffineMap map = builder.getSymbolIdentityMap();
    result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
    return success();
  }

  // Get the attribute location.
  SMLoc attrLoc = p.getCurrentLocation();

  Attribute boundAttr;
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
                       result.attributes))
    return failure();

  // Parse full form - affine map followed by dim and symbol list.
  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
    unsigned currentNumOperands = result.operands.size();
    unsigned numDims;
    if (parseDimAndSymbolList(parser&: p, operands&: result.operands, numDims))
      return failure();

    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
      return p.emitError(
          loc: p.getNameLoc(),
          message: "dim operand count and affine map dim count must match" );

    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      return p.emitError(
          loc: p.getNameLoc(),
          message: "symbol operand count and affine map symbol count must match" );

    // If the map has multiple results, make sure that we parsed the min/max
    // prefix.
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
      if (isLower) {
        return p.emitError(loc: attrLoc, message: "lower loop bound affine map with "
                                      "multiple results requires 'max' prefix" );
      }
      return p.emitError(loc: attrLoc, message: "upper loop bound affine map with multiple "
                                    "results requires 'min' prefix" );
    }
    return success();
  }

  // Parse custom assembly form.
  // A bare integer becomes a zero-operand constant map; the parsed integer
  // attribute is replaced by the equivalent map attribute.
  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
    result.attributes.pop_back();
    result.addAttribute(
        boundAttrStrName,
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
    return success();
  }

  return p.emitError(
      loc: p.getNameLoc(),
      message: "expected valid affine map representation for loop bounds" );
}
2005 | |
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::Argument inductionVariable;
  inductionVariable.type = builder.getIndexType();
  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  // Bound operand counts are recovered from the growth of result.operands.
  int64_t numOperands = result.operands.size();
  if (parseBound(/*isLower=*/true, result, parser))
    return failure();
  int64_t numLbOperands = result.operands.size() - numOperands;
  if (parser.parseKeyword("to" , " between bounds" ))
    return failure();
  numOperands = result.operands.size();
  if (parseBound(/*isLower=*/false, result, parser))
    return failure();
  int64_t numUbOperands = result.operands.size() - numOperands;

  // Parse the optional loop step, we default to 1 if one is not present.
  if (parser.parseOptionalKeyword("step" )) {
    result.addAttribute(
        getStepAttrName(result.name),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              getStepAttrName(result.name).data(),
                              result.attributes))
      return failure();

    if (stepAttr.getValue().isNegative())
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer" );
  }

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Induction variable.
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args" ))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();
    // Resolve input operands.
    // Each region iter-arg takes the type of its corresponding result; the
    // induction variable (first region arg) is skipped.
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  // Record the lb/ub/iter-arg operand group sizes for the variadic segments.
  result.addAttribute(
      getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
                                    static_cast<int32_t>(numUbOperands),
                                    static_cast<int32_t>(operands.size())}));

  // Parse the body region.
  Region *body = result.addRegion();
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch between the number of loop-carried values and results" );
  if (parser.parseRegion(*body, regionArgs))
    return failure();

  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
2088 | |
2089 | static void printBound(AffineMapAttr boundMap, |
2090 | Operation::operand_range boundOperands, |
2091 | const char *prefix, OpAsmPrinter &p) { |
2092 | AffineMap map = boundMap.getValue(); |
2093 | |
2094 | // Check if this bound should be printed using custom assembly form. |
2095 | // The decision to restrict printing custom assembly form to trivial cases |
2096 | // comes from the will to roundtrip MLIR binary -> text -> binary in a |
2097 | // lossless way. |
2098 | // Therefore, custom assembly form parsing and printing is only supported for |
2099 | // zero-operand constant maps and single symbol operand identity maps. |
2100 | if (map.getNumResults() == 1) { |
2101 | AffineExpr expr = map.getResult(idx: 0); |
2102 | |
2103 | // Print constant bound. |
2104 | if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { |
2105 | if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { |
2106 | p << constExpr.getValue(); |
2107 | return; |
2108 | } |
2109 | } |
2110 | |
2111 | // Print bound that consists of a single SSA symbol if the map is over a |
2112 | // single symbol. |
2113 | if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { |
2114 | if (dyn_cast<AffineSymbolExpr>(Val&: expr)) { |
2115 | p.printOperand(value: *boundOperands.begin()); |
2116 | return; |
2117 | } |
2118 | } |
2119 | } else { |
2120 | // Map has multiple results. Print 'min' or 'max' prefix. |
2121 | p << prefix << ' '; |
2122 | } |
2123 | |
2124 | // Print the map and its operands. |
2125 | p << boundMap; |
2126 | printDimAndSymbolList(begin: boundOperands.begin(), end: boundOperands.end(), |
2127 | numDims: map.getNumDims(), printer&: p); |
2128 | } |
2129 | |
2130 | unsigned AffineForOp::getNumIterOperands() { |
2131 | AffineMap lbMap = getLowerBoundMapAttr().getValue(); |
2132 | AffineMap ubMap = getUpperBoundMapAttr().getValue(); |
2133 | |
2134 | return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs(); |
2135 | } |
2136 | |
2137 | std::optional<MutableArrayRef<OpOperand>> |
2138 | AffineForOp::getYieldedValuesMutable() { |
2139 | return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable(); |
2140 | } |
2141 | |
/// Prints an affine.for op in its custom assembly form:
///   affine.for %iv = <lb> to <ub> [step <s>] [iter_args(...) -> (...)] {...}
void AffineForOp::print(OpAsmPrinter &p) {
  p << ' ';
  // The induction variable's type is always 'index' and thus elided.
  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
                        /*omitType=*/true);
  p << " = " ;
  // A lower bound is the max over its map results; an upper bound is the min.
  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max" , p);
  p << " to " ;
  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min" , p);

  // Elide the step when it is the implicit default of one.
  if (getStepAsInt() != 1)
    p << " step " << getStepAsInt();

  bool printBlockTerminators = false;
  if (getNumIterOperands() > 0) {
    // Print loop-carried values: "iter_args(%arg = %init, ...) -> (types)".
    p << " iter_args(" ;
    auto regionArgs = getRegionIterArgs();
    auto operands = getInits();

    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it);
    });
    p << ") -> (" << getResultTypes() << ")" ;
    // With iter_args present, the affine.yield carries operands and must be
    // printed explicitly.
    printBlockTerminators = true;
  }

  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);
  // Attributes already encoded in the custom form are elided from the
  // trailing attribute dictionary.
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
                       getUpperBoundMapAttrName(getOperation()->getName()),
                       getStepAttrName(getOperation()->getName()),
                       getOperandSegmentSizeAttr()});
}
2177 | |
2178 | /// Fold the constant bounds of a loop. |
2179 | static LogicalResult foldLoopBounds(AffineForOp forOp) { |
2180 | auto foldLowerOrUpperBound = [&forOp](bool lower) { |
2181 | // Check to see if each of the operands is the result of a constant. If |
2182 | // so, get the value. If not, ignore it. |
2183 | SmallVector<Attribute, 8> operandConstants; |
2184 | auto boundOperands = |
2185 | lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); |
2186 | for (auto operand : boundOperands) { |
2187 | Attribute operandCst; |
2188 | matchPattern(operand, m_Constant(&operandCst)); |
2189 | operandConstants.push_back(operandCst); |
2190 | } |
2191 | |
2192 | AffineMap boundMap = |
2193 | lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); |
2194 | assert(boundMap.getNumResults() >= 1 && |
2195 | "bound maps should have at least one result" ); |
2196 | SmallVector<Attribute, 4> foldedResults; |
2197 | if (failed(result: boundMap.constantFold(operandConstants, results&: foldedResults))) |
2198 | return failure(); |
2199 | |
2200 | // Compute the max or min as applicable over the results. |
2201 | assert(!foldedResults.empty() && "bounds should have at least one result" ); |
2202 | auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue(); |
2203 | for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { |
2204 | auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue(); |
2205 | maxOrMin = lower ? llvm::APIntOps::smax(A: maxOrMin, B: foldedResult) |
2206 | : llvm::APIntOps::smin(A: maxOrMin, B: foldedResult); |
2207 | } |
2208 | lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) |
2209 | : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); |
2210 | return success(); |
2211 | }; |
2212 | |
2213 | // Try to fold the lower bound. |
2214 | bool folded = false; |
2215 | if (!forOp.hasConstantLowerBound()) |
2216 | folded |= succeeded(result: foldLowerOrUpperBound(/*lower=*/true)); |
2217 | |
2218 | // Try to fold the upper bound. |
2219 | if (!forOp.hasConstantUpperBound()) |
2220 | folded |= succeeded(result: foldLowerOrUpperBound(/*lower=*/false)); |
2221 | return success(isSuccess: folded); |
2222 | } |
2223 | |
2224 | /// Canonicalize the bounds of the given loop. |
2225 | static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { |
2226 | SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); |
2227 | SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands()); |
2228 | |
2229 | auto lbMap = forOp.getLowerBoundMap(); |
2230 | auto ubMap = forOp.getUpperBoundMap(); |
2231 | auto prevLbMap = lbMap; |
2232 | auto prevUbMap = ubMap; |
2233 | |
2234 | composeAffineMapAndOperands(&lbMap, &lbOperands); |
2235 | canonicalizeMapAndOperands(&lbMap, &lbOperands); |
2236 | simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true); |
2237 | simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false); |
2238 | lbMap = removeDuplicateExprs(lbMap); |
2239 | |
2240 | composeAffineMapAndOperands(&ubMap, &ubOperands); |
2241 | canonicalizeMapAndOperands(&ubMap, &ubOperands); |
2242 | ubMap = removeDuplicateExprs(ubMap); |
2243 | |
2244 | // Any canonicalization change always leads to updated map(s). |
2245 | if (lbMap == prevLbMap && ubMap == prevUbMap) |
2246 | return failure(); |
2247 | |
2248 | if (lbMap != prevLbMap) |
2249 | forOp.setLowerBound(lbOperands, lbMap); |
2250 | if (ubMap != prevUbMap) |
2251 | forOp.setUpperBound(ubOperands, ubMap); |
2252 | return success(); |
2253 | } |
2254 | |
2255 | namespace { |
2256 | /// Returns constant trip count in trivial cases. |
2257 | static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) { |
2258 | int64_t step = forOp.getStepAsInt(); |
2259 | if (!forOp.hasConstantBounds() || step <= 0) |
2260 | return std::nullopt; |
2261 | int64_t lb = forOp.getConstantLowerBound(); |
2262 | int64_t ub = forOp.getConstantUpperBound(); |
2263 | return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step; |
2264 | } |
2265 | |
2266 | /// This is a pattern to fold trivially empty loop bodies. |
2267 | /// TODO: This should be moved into the folding hook. |
2268 | struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> { |
2269 | using OpRewritePattern<AffineForOp>::OpRewritePattern; |
2270 | |
2271 | LogicalResult matchAndRewrite(AffineForOp forOp, |
2272 | PatternRewriter &rewriter) const override { |
2273 | // Check that the body only contains a yield. |
2274 | if (!llvm::hasSingleElement(*forOp.getBody())) |
2275 | return failure(); |
2276 | if (forOp.getNumResults() == 0) |
2277 | return success(); |
2278 | std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp); |
2279 | if (tripCount && *tripCount == 0) { |
2280 | // The initial values of the iteration arguments would be the op's |
2281 | // results. |
2282 | rewriter.replaceOp(forOp, forOp.getInits()); |
2283 | return success(); |
2284 | } |
2285 | SmallVector<Value, 4> replacements; |
2286 | auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator()); |
2287 | auto iterArgs = forOp.getRegionIterArgs(); |
2288 | bool hasValDefinedOutsideLoop = false; |
2289 | bool iterArgsNotInOrder = false; |
2290 | for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) { |
2291 | Value val = yieldOp.getOperand(i); |
2292 | auto *iterArgIt = llvm::find(iterArgs, val); |
2293 | if (iterArgIt == iterArgs.end()) { |
2294 | // `val` is defined outside of the loop. |
2295 | assert(forOp.isDefinedOutsideOfLoop(val) && |
2296 | "must be defined outside of the loop" ); |
2297 | hasValDefinedOutsideLoop = true; |
2298 | replacements.push_back(Elt: val); |
2299 | } else { |
2300 | unsigned pos = std::distance(iterArgs.begin(), iterArgIt); |
2301 | if (pos != i) |
2302 | iterArgsNotInOrder = true; |
2303 | replacements.push_back(Elt: forOp.getInits()[pos]); |
2304 | } |
2305 | } |
2306 | // Bail out when the trip count is unknown and the loop returns any value |
2307 | // defined outside of the loop or any iterArg out of order. |
2308 | if (!tripCount.has_value() && |
2309 | (hasValDefinedOutsideLoop || iterArgsNotInOrder)) |
2310 | return failure(); |
2311 | // Bail out when the loop iterates more than once and it returns any iterArg |
2312 | // out of order. |
2313 | if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder) |
2314 | return failure(); |
2315 | rewriter.replaceOp(forOp, replacements); |
2316 | return success(); |
2317 | } |
2318 | }; |
2319 | } // namespace |
2320 | |
/// Registers affine.for canonicalization patterns with the given pattern set.
void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<AffineForEmptyLoopFolder>(context);
}
2325 | |
/// Returns the operands forwarded on entry from outside the loop: the init
/// values of the iter_args.
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert((point.isParent() || point == getRegion()) && "invalid region point" );

  // The initial operands map to the loop arguments after the induction
  // variable or are forwarded to the results when the trip count is zero.
  return getInits();
}
2333 | |
2334 | void AffineForOp::getSuccessorRegions( |
2335 | RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { |
2336 | assert((point.isParent() || point == getRegion()) && "expected loop region" ); |
2337 | // The loop may typically branch back to its body or to the parent operation. |
2338 | // If the predecessor is the parent op and the trip count is known to be at |
2339 | // least one, branch into the body using the iterator arguments. And in cases |
2340 | // we know the trip count is zero, it can only branch back to its parent. |
2341 | std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this); |
2342 | if (point.isParent() && tripCount.has_value()) { |
2343 | if (tripCount.value() > 0) { |
2344 | regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); |
2345 | return; |
2346 | } |
2347 | if (tripCount.value() == 0) { |
2348 | regions.push_back(RegionSuccessor(getResults())); |
2349 | return; |
2350 | } |
2351 | } |
2352 | |
2353 | // From the loop body, if the trip count is one, we can only branch back to |
2354 | // the parent. |
2355 | if (!point.isParent() && tripCount && *tripCount == 1) { |
2356 | regions.push_back(RegionSuccessor(getResults())); |
2357 | return; |
2358 | } |
2359 | |
2360 | // In all other cases, the loop may branch back to itself or the parent |
2361 | // operation. |
2362 | regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); |
2363 | regions.push_back(RegionSuccessor(getResults())); |
2364 | } |
2365 | |
2366 | /// Returns true if the affine.for has zero iterations in trivial cases. |
2367 | static bool hasTrivialZeroTripCount(AffineForOp op) { |
2368 | std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op); |
2369 | return tripCount && *tripCount == 0; |
2370 | } |
2371 | |
2372 | LogicalResult AffineForOp::fold(FoldAdaptor adaptor, |
2373 | SmallVectorImpl<OpFoldResult> &results) { |
2374 | bool folded = succeeded(foldLoopBounds(*this)); |
2375 | folded |= succeeded(canonicalizeLoopBounds(*this)); |
2376 | if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) { |
2377 | // The initial values of the loop-carried variables (iter_args) are the |
2378 | // results of the op. But this must be avoided for an affine.for op that |
2379 | // does not return any results. Since ops that do not return results cannot |
2380 | // be folded away, we would enter an infinite loop of folds on the same |
2381 | // affine.for op. |
2382 | results.assign(getInits().begin(), getInits().end()); |
2383 | folded = true; |
2384 | } |
2385 | return success(folded); |
2386 | } |
2387 | |
/// Returns the lower bound (map plus operands) packaged as an AffineBound.
AffineBound AffineForOp::getLowerBound() {
  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
}

/// Returns the upper bound (map plus operands) packaged as an AffineBound.
AffineBound AffineForOp::getUpperBound() {
  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
}
2395 | |
/// Replaces the lower bound with `map` applied to `lbOperands`. The operand
/// count must match the map's input count.
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result" );
  getLowerBoundOperandsMutable().assign(lbOperands);
  setLowerBoundMap(map);
}

/// Replaces the upper bound with `map` applied to `ubOperands`. The operand
/// count must match the map's input count.
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result" );
  getUpperBoundOperandsMutable().assign(ubOperands);
  setUpperBoundMap(map);
}
2409 | |
/// Returns true if the lower bound map is a single constant expression.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

/// Returns true if the upper bound map is a single constant expression.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

/// Returns the constant lower bound; only valid if hasConstantLowerBound().
int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

/// Returns the constant upper bound; only valid if hasConstantUpperBound().
int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

/// Sets the lower bound to the constant `value` via an operand-less map.
void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

/// Sets the upper bound to the constant `value` via an operand-less map.
void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
2433 | |
2434 | AffineForOp::operand_range AffineForOp::getControlOperands() { |
2435 | return {operand_begin(), operand_begin() + getLowerBoundOperands().size() + |
2436 | getUpperBoundOperands().size()}; |
2437 | } |
2438 | |
2439 | bool AffineForOp::matchingBoundOperandList() { |
2440 | auto lbMap = getLowerBoundMap(); |
2441 | auto ubMap = getUpperBoundMap(); |
2442 | if (lbMap.getNumDims() != ubMap.getNumDims() || |
2443 | lbMap.getNumSymbols() != ubMap.getNumSymbols()) |
2444 | return false; |
2445 | |
2446 | unsigned numOperands = lbMap.getNumInputs(); |
2447 | for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { |
2448 | // Compare Value 's. |
2449 | if (getOperand(i) != getOperand(numOperands + i)) |
2450 | return false; |
2451 | } |
2452 | return true; |
2453 | } |
2454 | |
/// The loop has a single region: its body.
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }

/// An affine.for always has exactly one induction variable.
std::optional<Value> AffineForOp::getSingleInductionVar() {
  return getInductionVar();
}
2460 | |
2461 | std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() { |
2462 | if (!hasConstantLowerBound()) |
2463 | return std::nullopt; |
2464 | OpBuilder b(getContext()); |
2465 | return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound())); |
2466 | } |
2467 | |
2468 | std::optional<OpFoldResult> AffineForOp::getSingleStep() { |
2469 | OpBuilder b(getContext()); |
2470 | return OpFoldResult(b.getI64IntegerAttr(getStepAsInt())); |
2471 | } |
2472 | |
2473 | std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() { |
2474 | if (!hasConstantUpperBound()) |
2475 | return std::nullopt; |
2476 | OpBuilder b(getContext()); |
2477 | return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound())); |
2478 | } |
2479 | |
/// Replaces this loop with a new affine.for that additionally carries
/// `newInitOperands` as iter_args, yielding the values produced by
/// `newYieldValuesFn`. The original loop's results map to the new loop's
/// leading results; the old op is erased.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
      getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);

  // Generate the new yield values and append them to the affine.yield
  // operation of the (old) body.
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  // The extra iter_args are the trailing block arguments of the new body.
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands" );
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op, remapping the old block arguments to
  // the new body's leading arguments.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only for uses nested inside the new loop.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
2531 | |
2532 | Speculation::Speculatability AffineForOp::getSpeculatability() { |
2533 | // `affine.for (I = Start; I < End; I += 1)` terminates for all values of |
2534 | // Start and End. |
2535 | // |
2536 | // For Step != 1, the loop may not terminate. We can add more smarts here if |
2537 | // needed. |
2538 | return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable |
2539 | : Speculation::NotSpeculatable; |
2540 | } |
2541 | |
/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
bool mlir::affine::isAffineForInductionVar(Value val) {
  return getForInductionVarOwner(val) != AffineForOp();
}

/// Returns true if the provided value is one of the induction variables of an
/// AffineParallelOp.
bool mlir::affine::isAffineParallelInductionVar(Value val) {
  return getAffineParallelInductionVarOwner(val) != nullptr;
}

/// Returns true if the provided value is an induction variable of either an
/// affine.for or an affine.parallel op.
bool mlir::affine::isAffineInductionVar(Value val) {
  return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
}
2555 | |
2556 | AffineForOp mlir::affine::getForInductionVarOwner(Value val) { |
2557 | auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val); |
2558 | if (!ivArg || !ivArg.getOwner()) |
2559 | return AffineForOp(); |
2560 | auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); |
2561 | if (auto forOp = dyn_cast<AffineForOp>(containingInst)) |
2562 | // Check to make sure `val` is the induction variable, not an iter_arg. |
2563 | return forOp.getInductionVar() == val ? forOp : AffineForOp(); |
2564 | return AffineForOp(); |
2565 | } |
2566 | |
2567 | AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) { |
2568 | auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val); |
2569 | if (!ivArg || !ivArg.getOwner()) |
2570 | return nullptr; |
2571 | Operation *containingOp = ivArg.getOwner()->getParentOp(); |
2572 | auto parallelOp = dyn_cast<AffineParallelOp>(containingOp); |
2573 | if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val)) |
2574 | return parallelOp; |
2575 | return nullptr; |
2576 | } |
2577 | |
2578 | /// Extracts the induction variables from a list of AffineForOps and returns |
2579 | /// them. |
2580 | void mlir::affine::(ArrayRef<AffineForOp> forInsts, |
2581 | SmallVectorImpl<Value> *ivs) { |
2582 | ivs->reserve(N: forInsts.size()); |
2583 | for (auto forInst : forInsts) |
2584 | ivs->push_back(forInst.getInductionVar()); |
2585 | } |
2586 | |
2587 | void mlir::affine::(ArrayRef<mlir::Operation *> affineOps, |
2588 | SmallVectorImpl<mlir::Value> &ivs) { |
2589 | ivs.reserve(N: affineOps.size()); |
2590 | for (Operation *op : affineOps) { |
2591 | // Add constraints from forOp's bounds. |
2592 | if (auto forOp = dyn_cast<AffineForOp>(op)) |
2593 | ivs.push_back(Elt: forOp.getInductionVar()); |
2594 | else if (auto parallelOp = dyn_cast<AffineParallelOp>(op)) |
2595 | for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++) |
2596 | ivs.push_back(Elt: parallelOp.getBody()->getArgument(i)); |
2597 | } |
2598 | } |
2599 | |
2600 | /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop |
2601 | /// operations. |
2602 | template <typename BoundListTy, typename LoopCreatorTy> |
2603 | static void buildAffineLoopNestImpl( |
2604 | OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, |
2605 | ArrayRef<int64_t> steps, |
2606 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, |
2607 | LoopCreatorTy &&loopCreatorFn) { |
2608 | assert(lbs.size() == ubs.size() && "Mismatch in number of arguments" ); |
2609 | assert(lbs.size() == steps.size() && "Mismatch in number of arguments" ); |
2610 | |
2611 | // If there are no loops to be constructed, construct the body anyway. |
2612 | OpBuilder::InsertionGuard guard(builder); |
2613 | if (lbs.empty()) { |
2614 | if (bodyBuilderFn) |
2615 | bodyBuilderFn(builder, loc, ValueRange()); |
2616 | return; |
2617 | } |
2618 | |
2619 | // Create the loops iteratively and store the induction variables. |
2620 | SmallVector<Value, 4> ivs; |
2621 | ivs.reserve(N: lbs.size()); |
2622 | for (unsigned i = 0, e = lbs.size(); i < e; ++i) { |
2623 | // Callback for creating the loop body, always creates the terminator. |
2624 | auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, |
2625 | ValueRange iterArgs) { |
2626 | ivs.push_back(Elt: iv); |
2627 | // In the innermost loop, call the body builder. |
2628 | if (i == e - 1 && bodyBuilderFn) { |
2629 | OpBuilder::InsertionGuard nestedGuard(nestedBuilder); |
2630 | bodyBuilderFn(nestedBuilder, nestedLoc, ivs); |
2631 | } |
2632 | nestedBuilder.create<AffineYieldOp>(nestedLoc); |
2633 | }; |
2634 | |
2635 | // Delegate actual loop creation to the callback in order to dispatch |
2636 | // between constant- and variable-bound loops. |
2637 | auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody); |
2638 | builder.setInsertionPointToStart(loop.getBody()); |
2639 | } |
2640 | } |
2641 | |
/// Creates an affine loop from the bounds known to be constants.
/// The loop has no iter_args; `bodyBuilderFn` populates its body.
static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
                             int64_t ub, int64_t step,
                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return builder.create<AffineForOp>(loc, lb, ub, step,
                                     /*iterArgs=*/std::nullopt, bodyBuilderFn);
}
2650 | |
2651 | /// Creates an affine loop from the bounds that may or may not be constants. |
2652 | static AffineForOp |
2653 | buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, |
2654 | int64_t step, |
2655 | AffineForOp::BodyBuilderFn bodyBuilderFn) { |
2656 | std::optional<int64_t> lbConst = getConstantIntValue(ofr: lb); |
2657 | std::optional<int64_t> ubConst = getConstantIntValue(ofr: ub); |
2658 | if (lbConst && ubConst) |
2659 | return buildAffineLoopFromConstants(builder, loc, lbConst.value(), |
2660 | ubConst.value(), step, bodyBuilderFn); |
2661 | return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub, |
2662 | builder.getDimIdentityMap(), step, |
2663 | /*iterArgs=*/std::nullopt, bodyBuilderFn); |
2664 | } |
2665 | |
/// Builds a perfect nest of affine.for loops with constant bounds; the
/// innermost body is filled in by `bodyBuilderFn`.
void mlir::affine::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromConstants);
}

/// Builds a perfect nest of affine.for loops with SSA-value bounds; the
/// innermost body is filled in by `bodyBuilderFn`.
void mlir::affine::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromValues);
}
2681 | |
2682 | //===----------------------------------------------------------------------===// |
2683 | // AffineIfOp |
2684 | //===----------------------------------------------------------------------===// |
2685 | |
2686 | namespace { |
2687 | /// Remove else blocks that have nothing other than a zero value yield. |
2688 | struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> { |
2689 | using OpRewritePattern<AffineIfOp>::OpRewritePattern; |
2690 | |
2691 | LogicalResult matchAndRewrite(AffineIfOp ifOp, |
2692 | PatternRewriter &rewriter) const override { |
2693 | if (ifOp.getElseRegion().empty() || |
2694 | !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) |
2695 | return failure(); |
2696 | |
2697 | rewriter.startOpModification(op: ifOp); |
2698 | rewriter.eraseBlock(block: ifOp.getElseBlock()); |
2699 | rewriter.finalizeOpModification(op: ifOp); |
2700 | return success(); |
2701 | } |
2702 | }; |
2703 | |
2704 | /// Removes affine.if cond if the condition is always true or false in certain |
2705 | /// trivial cases. Promotes the then/else block in the parent operation block. |
2706 | struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> { |
2707 | using OpRewritePattern<AffineIfOp>::OpRewritePattern; |
2708 | |
2709 | LogicalResult matchAndRewrite(AffineIfOp op, |
2710 | PatternRewriter &rewriter) const override { |
2711 | |
2712 | auto isTriviallyFalse = [](IntegerSet iSet) { |
2713 | return iSet.isEmptyIntegerSet(); |
2714 | }; |
2715 | |
2716 | auto isTriviallyTrue = [](IntegerSet iSet) { |
2717 | return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 && |
2718 | iSet.getConstraint(idx: 0) == 0); |
2719 | }; |
2720 | |
2721 | IntegerSet affineIfConditions = op.getIntegerSet(); |
2722 | Block *blockToMove; |
2723 | if (isTriviallyFalse(affineIfConditions)) { |
2724 | // The absence, or equivalently, the emptiness of the else region need not |
2725 | // be checked when affine.if is returning results because if an affine.if |
2726 | // operation is returning results, it always has a non-empty else region. |
2727 | if (op.getNumResults() == 0 && !op.hasElse()) { |
2728 | // If the else region is absent, or equivalently, empty, remove the |
2729 | // affine.if operation (which is not returning any results). |
2730 | rewriter.eraseOp(op: op); |
2731 | return success(); |
2732 | } |
2733 | blockToMove = op.getElseBlock(); |
2734 | } else if (isTriviallyTrue(affineIfConditions)) { |
2735 | blockToMove = op.getThenBlock(); |
2736 | } else { |
2737 | return failure(); |
2738 | } |
2739 | Operation *blockToMoveTerminator = blockToMove->getTerminator(); |
2740 | // Promote the "blockToMove" block to the parent operation block between the |
2741 | // prologue and epilogue of "op". |
2742 | rewriter.inlineBlockBefore(blockToMove, op); |
2743 | // Replace the "op" operation with the operands of the |
2744 | // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is |
2745 | // the affine.yield operation present in the "blockToMove" block. It has no |
2746 | // operands when affine.if is not returning results and therefore, in that |
2747 | // case, replaceOp just erases "op". When affine.if is not returning |
2748 | // results, the affine.yield operation can be omitted. It gets inserted |
2749 | // implicitly. |
2750 | rewriter.replaceOp(op, blockToMoveTerminator->getOperands()); |
2751 | // Erase the "blockToMoveTerminator" operation since it is now in the parent |
2752 | // operation block, which already has its own terminator. |
2753 | rewriter.eraseOp(op: blockToMoveTerminator); |
2754 | return success(); |
2755 | } |
2756 | }; |
2757 | } // namespace |
2758 | |
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // If the predecessor is an AffineIfOp, then branching into both `then` and
  // `else` region is valid.
  if (point.isParent()) {
    regions.reserve(2);
    regions.push_back(
        RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
    // If the "else" region is empty, branch back into parent.
    if (getElseRegion().empty()) {
      regions.push_back(getResults());
    } else {
      regions.push_back(
          RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
    }
    return;
  }

  // If the predecessor is the `else`/`then` region, then branching into parent
  // op is valid.
  regions.push_back(RegionSuccessor(getResults()));
}
2783 | |
/// Verifies the affine.if op: it must carry an IntegerSet 'condition'
/// attribute whose input count matches the operand count, and all operands
/// must be valid affine dims/symbols.
LogicalResult AffineIfOp::verify() {
  // Verify that we have a condition attribute.
  // FIXME: This should be specified in the arguments list in ODS.
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  if (!conditionAttr)
    return emitOpError("requires an integer set attribute named 'condition'" );

  // Verify that there are enough operands for the condition.
  IntegerSet condition = conditionAttr.getValue();
  if (getNumOperands() != condition.getNumInputs())
    return emitOpError("operand count and condition integer set dimension and "
                       "symbol count must match" );

  // Verify that the operands are valid dimension/symbols.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
                                           condition.getNumDims())))
    return failure();

  return success();
}
2805 | |
/// Parses an affine.if op in the form:
///   affine.if <integer-set-attr> (dims)[symbols] (`->` type-list)?
///     region (`else` region)? attr-dict?
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands: dim/symbol counts must agree with the set.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match" );
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match" );

  // Parse the optional result type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region; insert the implicit terminator if it is missing.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else" )) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2856 | |
/// Prints an affine.if op: the condition set, its dim/symbol operands, an
/// optional result type list, the `then` region, an optional `else` region,
/// and remaining attributes (the condition attribute is elided).
void AffineIfOp::print(OpAsmPrinter &p) {
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  p << " " << conditionAttr;
  printDimAndSymbolList(operand_begin(), operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  // Terminators are printed only when the op has results; the implicit
  // terminator of a result-less op is elided.
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());

  // Print the 'else' regions if it has any blocks.
  auto &elseRegion = this->getElseRegion();
  if (!elseRegion.empty()) {
    p << " else " ;
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/getNumResults());
  }

  // Print the attribute list.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/getConditionAttrStrName());
}
2881 | |
2882 | IntegerSet AffineIfOp::getIntegerSet() { |
2883 | return (*this) |
2884 | ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName()) |
2885 | .getValue(); |
2886 | } |
2887 | |
2888 | void AffineIfOp::setIntegerSet(IntegerSet newSet) { |
2889 | (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet)); |
2890 | } |
2891 | |
2892 | void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) { |
2893 | setIntegerSet(set); |
2894 | (*this)->setOperands(operands); |
2895 | } |
2896 | |
/// Builds an affine.if op with the given result types, condition set, and
/// operands. Creates the `then` block, and the `else` block only when
/// `withElseRegion` is set; the else region itself always exists.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  // An op returning values needs an else region to supply them on both paths.
  assert(resultTypes.empty() || withElseRegion);
  OpBuilder::InsertionGuard guard(builder);

  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));

  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  // Implicit terminators are only valid when the op yields no values;
  // otherwise the caller must create terminators with the proper operands.
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  // The else region must be created even if it stays empty.
  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    builder.createBlock(elseRegion);
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}
2919 | |
2920 | void AffineIfOp::build(OpBuilder &builder, OperationState &result, |
2921 | IntegerSet set, ValueRange args, bool withElseRegion) { |
2922 | AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args, |
2923 | withElseRegion); |
2924 | } |
2925 | |
2926 | /// Compose any affine.apply ops feeding into `operands` of the integer set |
2927 | /// `set` by composing the maps of such affine.apply ops with the integer |
2928 | /// set constraints. |
2929 | static void composeSetAndOperands(IntegerSet &set, |
2930 | SmallVectorImpl<Value> &operands) { |
2931 | // We will simply reuse the API of the map composition by viewing the LHSs of |
2932 | // the equalities and inequalities of `set` as the affine exprs of an affine |
2933 | // map. Convert to equivalent map, compose, and convert back to set. |
2934 | auto map = AffineMap::get(dimCount: set.getNumDims(), symbolCount: set.getNumSymbols(), |
2935 | results: set.getConstraints(), context: set.getContext()); |
2936 | // Check if any composition is possible. |
2937 | if (llvm::none_of(Range&: operands, |
2938 | P: [](Value v) { return v.getDefiningOp<AffineApplyOp>(); })) |
2939 | return; |
2940 | |
2941 | composeAffineMapAndOperands(map: &map, operands: &operands); |
2942 | set = IntegerSet::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), constraints: map.getResults(), |
2943 | eqFlags: set.getEqFlags()); |
2944 | } |
2945 | |
/// Canonicalize an affine if op's conditional (integer set + operands).
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  // Compose feeding affine.apply ops into the set, then simplify the set and
  // its operand list in place.
  composeSetAndOperands(set, operands);
  canonicalizeSetAndOperands(&set, &operands);

  // Check if the canonicalization or composition led to any change.
  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
    return failure();

  setConditional(set, operands);
  return success();
}
2960 | |
/// Registers the affine.if canonicalization patterns (defined earlier in this
/// file).
void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
}
2965 | |
2966 | //===----------------------------------------------------------------------===// |
2967 | // AffineLoadOp |
2968 | //===----------------------------------------------------------------------===// |
2969 | |
/// Builds an affine.load from a combined operand list: operands[0] is the
/// memref and the remaining operands feed the subscript map.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands" );
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
  // The loaded value has the memref's element type.
  result.types.push_back(memrefType.getElementType());
}
2979 | |
/// Builds an affine.load from a memref, a subscript map, and the map's
/// operands.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info" );
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  // The loaded value has the memref's element type.
  result.types.push_back(memrefType.getElementType());
}
2989 | |
2990 | void AffineLoadOp::build(OpBuilder &builder, OperationState &result, |
2991 | Value memref, ValueRange indices) { |
2992 | auto memrefType = llvm::cast<MemRefType>(memref.getType()); |
2993 | int64_t rank = memrefType.getRank(); |
2994 | // Create identity map for memrefs with at least one dimension or () -> () |
2995 | // for zero-dimensional memrefs. |
2996 | auto map = |
2997 | rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
2998 | build(builder, result, memref, map, indices); |
2999 | } |
3000 | |
/// Parses an affine.load op:
///   affine.load %memref[affine-map-of-ssa-ids] attr-dict : memref-type
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // Map operands resolve to 'index'; the result type is the memref's element
  // type.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
3020 | |
/// Prints an affine.load op; the subscript map is printed inline with its SSA
/// operands and elided from the attribute dictionary.
void AffineLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3031 | |
3032 | /// Verify common indexing invariants of affine.load, affine.store, |
3033 | /// affine.vector_load and affine.vector_store. |
3034 | static LogicalResult |
3035 | verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, |
3036 | Operation::operand_range mapOperands, |
3037 | MemRefType memrefType, unsigned numIndexOperands) { |
3038 | AffineMap map = mapAttr.getValue(); |
3039 | if (map.getNumResults() != memrefType.getRank()) |
3040 | return op->emitOpError(message: "affine map num results must equal memref rank" ); |
3041 | if (map.getNumInputs() != numIndexOperands) |
3042 | return op->emitOpError(message: "expects as many subscripts as affine map inputs" ); |
3043 | |
3044 | Region *scope = getAffineScope(op); |
3045 | for (auto idx : mapOperands) { |
3046 | if (!idx.getType().isIndex()) |
3047 | return op->emitOpError(message: "index to load must have 'index' type" ); |
3048 | if (!isValidAffineIndexOperand(value: idx, region: scope)) |
3049 | return op->emitOpError( |
3050 | message: "index must be a valid dimension or symbol identifier" ); |
3051 | } |
3052 | |
3053 | return success(); |
3054 | } |
3055 | |
/// Verifies that the result type matches the memref element type and that the
/// subscript map and its operands are well formed.
LogicalResult AffineLoadOp::verify() {
  auto memrefType = getMemRefType();
  if (getType() != memrefType.getElementType())
    return emitOpError("result type must match element type of memref" );

  // All operands after the memref (operand 0) are map operands.
  if (failed(verifyMemoryOpIndexing(
          getOperation(),
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();

  return success();
}
3070 | |
/// Registers the map/operand simplification pattern for affine.load.
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
3075 | |
/// Folds loads through memref.cast ops and loads from constant
/// memref.global definitions.
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  // All map results are constants; use them directly as element indices.
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
3110 | |
3111 | //===----------------------------------------------------------------------===// |
3112 | // AffineStoreOp |
3113 | //===----------------------------------------------------------------------===// |
3114 | |
/// Builds an affine.store of `valueToStore` into `memref`, subscripted by
/// `map` applied to `mapOperands`.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info" );
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
}
3124 | |
3125 | // Use identity map. |
3126 | void AffineStoreOp::build(OpBuilder &builder, OperationState &result, |
3127 | Value valueToStore, Value memref, |
3128 | ValueRange indices) { |
3129 | auto memrefType = llvm::cast<MemRefType>(memref.getType()); |
3130 | int64_t rank = memrefType.getRank(); |
3131 | // Create identity map for memrefs with at least one dimension or () -> () |
3132 | // for zero-dimensional memrefs. |
3133 | auto map = |
3134 | rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
3135 | build(builder, result, valueToStore, memref, map, indices); |
3136 | } |
3137 | |
/// Parses an affine.store op:
///   affine.store %value, %memref[affine-map-of-ssa-ids] attr-dict
///     : memref-type
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // The stored value resolves to the memref's element type; map operands
  // resolve to 'index'.
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
3158 | |
/// Prints an affine.store op; the subscript map is printed inline with its
/// SSA operands and elided from the attribute dictionary.
void AffineStoreOp::print(OpAsmPrinter &p) {
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3170 | |
/// Verifies that the stored value's type matches the memref element type and
/// that the subscript map and its operands are well formed.
LogicalResult AffineStoreOp::verify() {
  // The value to store must have the same type as memref element type.
  auto memrefType = getMemRefType();
  if (getValueToStore().getType() != memrefType.getElementType())
    return emitOpError(
        "value to store must have the same type as memref element type" );

  // Operands after the stored value (0) and memref (1) are map operands.
  if (failed(verifyMemoryOpIndexing(
          getOperation(),
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 2)))
    return failure();

  return success();
}
3187 | |
/// Registers the map/operand simplification pattern for affine.store.
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
3192 | |
/// Folds a store through a memref.cast of the base memref.
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return memref::foldMemRefCast(*this, getValueToStore());
}
3198 | |
3199 | //===----------------------------------------------------------------------===// |
3200 | // AffineMinMaxOpBase |
3201 | //===----------------------------------------------------------------------===// |
3202 | |
3203 | template <typename T> |
3204 | static LogicalResult verifyAffineMinMaxOp(T op) { |
3205 | // Verify that operand count matches affine map dimension and symbol count. |
3206 | if (op.getNumOperands() != |
3207 | op.getMap().getNumDims() + op.getMap().getNumSymbols()) |
3208 | return op.emitOpError( |
3209 | "operand count and affine map dimension and symbol count must match" ); |
3210 | return success(); |
3211 | } |
3212 | |
3213 | template <typename T> |
3214 | static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { |
3215 | p << ' ' << op->getAttr(T::getMapAttrStrName()); |
3216 | auto operands = op.getOperands(); |
3217 | unsigned numDims = op.getMap().getNumDims(); |
3218 | p << '(' << operands.take_front(numDims) << ')'; |
3219 | |
3220 | if (operands.size() != numDims) |
3221 | p << '[' << operands.drop_front(numDims) << ']'; |
3222 | p.printOptionalAttrDict(attrs: op->getAttrs(), |
3223 | /*elidedAttrs=*/{T::getMapAttrStrName()}); |
3224 | } |
3225 | |
3226 | template <typename T> |
3227 | static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, |
3228 | OperationState &result) { |
3229 | auto &builder = parser.getBuilder(); |
3230 | auto indexType = builder.getIndexType(); |
3231 | SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos; |
3232 | SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos; |
3233 | AffineMapAttr mapAttr; |
3234 | return failure( |
3235 | parser.parseAttribute(mapAttr, T::getMapAttrStrName(), |
3236 | result.attributes) || |
3237 | parser.parseOperandList(result&: dimInfos, delimiter: OpAsmParser::Delimiter::Paren) || |
3238 | parser.parseOperandList(result&: symInfos, |
3239 | delimiter: OpAsmParser::Delimiter::OptionalSquare) || |
3240 | parser.parseOptionalAttrDict(result&: result.attributes) || |
3241 | parser.resolveOperands(dimInfos, indexType, result.operands) || |
3242 | parser.resolveOperands(symInfos, indexType, result.operands) || |
3243 | parser.addTypeToList(type: indexType, result&: result.types)); |
3244 | } |
3245 | |
3246 | /// Fold an affine min or max operation with the given operands. The operand |
3247 | /// list may contain nulls, which are interpreted as the operand not being a |
3248 | /// constant. |
3249 | template <typename T> |
3250 | static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) { |
3251 | static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value, |
3252 | "expected affine min or max op" ); |
3253 | |
3254 | // Fold the affine map. |
3255 | // TODO: Fold more cases: |
3256 | // min(some_affine, some_affine + constant, ...), etc. |
3257 | SmallVector<int64_t, 2> results; |
3258 | auto foldedMap = op.getMap().partialConstantFold(operands, &results); |
3259 | |
3260 | if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity()) |
3261 | return op.getOperand(0); |
3262 | |
3263 | // If some of the map results are not constant, try changing the map in-place. |
3264 | if (results.empty()) { |
3265 | // If the map is the same, report that folding did not happen. |
3266 | if (foldedMap == op.getMap()) |
3267 | return {}; |
3268 | op->setAttr("map" , AffineMapAttr::get(foldedMap)); |
3269 | return op.getResult(); |
3270 | } |
3271 | |
3272 | // Otherwise, completely fold the op into a constant. |
3273 | auto resultIt = std::is_same<T, AffineMinOp>::value |
3274 | ? llvm::min_element(Range&: results) |
3275 | : llvm::max_element(Range&: results); |
3276 | if (resultIt == results.end()) |
3277 | return {}; |
3278 | return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt); |
3279 | } |
3280 | |
3281 | /// Remove duplicated expressions in affine min/max ops. |
3282 | template <typename T> |
3283 | struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> { |
3284 | using OpRewritePattern<T>::OpRewritePattern; |
3285 | |
3286 | LogicalResult matchAndRewrite(T affineOp, |
3287 | PatternRewriter &rewriter) const override { |
3288 | AffineMap oldMap = affineOp.getAffineMap(); |
3289 | |
3290 | SmallVector<AffineExpr, 4> newExprs; |
3291 | for (AffineExpr expr : oldMap.getResults()) { |
3292 | // This is a linear scan over newExprs, but it should be fine given that |
3293 | // we typically just have a few expressions per op. |
3294 | if (!llvm::is_contained(Range&: newExprs, Element: expr)) |
3295 | newExprs.push_back(Elt: expr); |
3296 | } |
3297 | |
3298 | if (newExprs.size() == oldMap.getNumResults()) |
3299 | return failure(); |
3300 | |
3301 | auto newMap = AffineMap::get(dimCount: oldMap.getNumDims(), symbolCount: oldMap.getNumSymbols(), |
3302 | results: newExprs, context: rewriter.getContext()); |
3303 | rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands()); |
3304 | |
3305 | return success(); |
3306 | } |
3307 | }; |
3308 | |
3309 | /// Merge an affine min/max op to its consumers if its consumer is also an |
3310 | /// affine min/max op. |
3311 | /// |
3312 | /// This pattern requires the producer affine min/max op is bound to a |
3313 | /// dimension/symbol that is used as a standalone expression in the consumer |
3314 | /// affine op's map. |
3315 | /// |
3316 | /// For example, a pattern like the following: |
3317 | /// |
3318 | /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1] |
3319 | /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2] |
3320 | /// |
3321 | /// Can be turned into: |
3322 | /// |
3323 | /// %1 = affine.min affine_map< |
3324 | /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1] |
3325 | template <typename T> |
3326 | struct MergeAffineMinMaxOp : public OpRewritePattern<T> { |
3327 | using OpRewritePattern<T>::OpRewritePattern; |
3328 | |
3329 | LogicalResult matchAndRewrite(T affineOp, |
3330 | PatternRewriter &rewriter) const override { |
3331 | AffineMap oldMap = affineOp.getAffineMap(); |
3332 | ValueRange dimOperands = |
3333 | affineOp.getMapOperands().take_front(oldMap.getNumDims()); |
3334 | ValueRange symOperands = |
3335 | affineOp.getMapOperands().take_back(oldMap.getNumSymbols()); |
3336 | |
3337 | auto newDimOperands = llvm::to_vector<8>(Range&: dimOperands); |
3338 | auto newSymOperands = llvm::to_vector<8>(Range&: symOperands); |
3339 | SmallVector<AffineExpr, 4> newExprs; |
3340 | SmallVector<T, 4> producerOps; |
3341 | |
3342 | // Go over each expression to see whether it's a single dimension/symbol |
3343 | // with the corresponding operand which is the result of another affine |
3344 | // min/max op. If So it can be merged into this affine op. |
3345 | for (AffineExpr expr : oldMap.getResults()) { |
3346 | if (auto symExpr = dyn_cast<AffineSymbolExpr>(Val&: expr)) { |
3347 | Value symValue = symOperands[symExpr.getPosition()]; |
3348 | if (auto producerOp = symValue.getDefiningOp<T>()) { |
3349 | producerOps.push_back(producerOp); |
3350 | continue; |
3351 | } |
3352 | } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) { |
3353 | Value dimValue = dimOperands[dimExpr.getPosition()]; |
3354 | if (auto producerOp = dimValue.getDefiningOp<T>()) { |
3355 | producerOps.push_back(producerOp); |
3356 | continue; |
3357 | } |
3358 | } |
3359 | // For the above cases we will remove the expression by merging the |
3360 | // producer affine min/max's affine expressions. Otherwise we need to |
3361 | // keep the existing expression. |
3362 | newExprs.push_back(Elt: expr); |
3363 | } |
3364 | |
3365 | if (producerOps.empty()) |
3366 | return failure(); |
3367 | |
3368 | unsigned numUsedDims = oldMap.getNumDims(); |
3369 | unsigned numUsedSyms = oldMap.getNumSymbols(); |
3370 | |
3371 | // Now go over all producer affine ops and merge their expressions. |
3372 | for (T producerOp : producerOps) { |
3373 | AffineMap producerMap = producerOp.getAffineMap(); |
3374 | unsigned numProducerDims = producerMap.getNumDims(); |
3375 | unsigned numProducerSyms = producerMap.getNumSymbols(); |
3376 | |
3377 | // Collect all dimension/symbol values. |
3378 | ValueRange dimValues = |
3379 | producerOp.getMapOperands().take_front(numProducerDims); |
3380 | ValueRange symValues = |
3381 | producerOp.getMapOperands().take_back(numProducerSyms); |
3382 | newDimOperands.append(in_start: dimValues.begin(), in_end: dimValues.end()); |
3383 | newSymOperands.append(in_start: symValues.begin(), in_end: symValues.end()); |
3384 | |
3385 | // For expressions we need to shift to avoid overlap. |
3386 | for (AffineExpr expr : producerMap.getResults()) { |
3387 | newExprs.push_back(Elt: expr.shiftDims(numDims: numProducerDims, shift: numUsedDims) |
3388 | .shiftSymbols(numSymbols: numProducerSyms, shift: numUsedSyms)); |
3389 | } |
3390 | |
3391 | numUsedDims += numProducerDims; |
3392 | numUsedSyms += numProducerSyms; |
3393 | } |
3394 | |
3395 | auto newMap = AffineMap::get(dimCount: numUsedDims, symbolCount: numUsedSyms, results: newExprs, |
3396 | context: rewriter.getContext()); |
3397 | auto newOperands = |
3398 | llvm::to_vector<8>(Range: llvm::concat<Value>(Ranges&: newDimOperands, Ranges&: newSymOperands)); |
3399 | rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands); |
3400 | |
3401 | return success(); |
3402 | } |
3403 | }; |
3404 | |
3405 | /// Canonicalize the result expression order of an affine map and return success |
3406 | /// if the order changed. |
3407 | /// |
3408 | /// The function flattens the map's affine expressions to coefficient arrays and |
3409 | /// sorts them in lexicographic order. A coefficient array contains a multiplier |
3410 | /// for every dimension/symbol and a constant term. The canonicalization fails |
3411 | /// if a result expression is not pure or if the flattening requires local |
3412 | /// variables that, unlike dimensions and symbols, have no global order. |
3413 | static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) { |
3414 | SmallVector<SmallVector<int64_t>> flattenedExprs; |
3415 | for (const AffineExpr &resultExpr : map.getResults()) { |
3416 | // Fail if the expression is not pure. |
3417 | if (!resultExpr.isPureAffine()) |
3418 | return failure(); |
3419 | |
3420 | SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols()); |
3421 | auto flattenResult = flattener.walkPostOrder(expr: resultExpr); |
3422 | if (failed(result: flattenResult)) |
3423 | return failure(); |
3424 | |
3425 | // Fail if the flattened expression has local variables. |
3426 | if (flattener.operandExprStack.back().size() != |
3427 | map.getNumDims() + map.getNumSymbols() + 1) |
3428 | return failure(); |
3429 | |
3430 | flattenedExprs.emplace_back(Args: flattener.operandExprStack.back().begin(), |
3431 | Args: flattener.operandExprStack.back().end()); |
3432 | } |
3433 | |
3434 | // Fail if sorting is not necessary. |
3435 | if (llvm::is_sorted(Range&: flattenedExprs)) |
3436 | return failure(); |
3437 | |
3438 | // Reorder the result expressions according to their flattened form. |
3439 | SmallVector<unsigned> resultPermutation = |
3440 | llvm::to_vector(Range: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults())); |
3441 | llvm::sort(C&: resultPermutation, Comp: [&](unsigned lhs, unsigned rhs) { |
3442 | return flattenedExprs[lhs] < flattenedExprs[rhs]; |
3443 | }); |
3444 | SmallVector<AffineExpr> newExprs; |
3445 | for (unsigned idx : resultPermutation) |
3446 | newExprs.push_back(Elt: map.getResult(idx)); |
3447 | |
3448 | map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newExprs, |
3449 | context: map.getContext()); |
3450 | return success(); |
3451 | } |
3452 | |
3453 | /// Canonicalize the affine map result expression order of an affine min/max |
3454 | /// operation. |
3455 | /// |
3456 | /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result |
3457 | /// expressions and replaces the operation if the order changed. |
3458 | /// |
3459 | /// For example, the following operation: |
3460 | /// |
3461 | /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1) |
3462 | /// |
3463 | /// Turns into: |
3464 | /// |
3465 | /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1) |
3466 | template <typename T> |
3467 | struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> { |
3468 | using OpRewritePattern<T>::OpRewritePattern; |
3469 | |
3470 | LogicalResult matchAndRewrite(T affineOp, |
3471 | PatternRewriter &rewriter) const override { |
3472 | AffineMap map = affineOp.getAffineMap(); |
3473 | if (failed(result: canonicalizeMapExprAndTermOrder(map))) |
3474 | return failure(); |
3475 | rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands()); |
3476 | return success(); |
3477 | } |
3478 | }; |
3479 | |
/// Rewrites an affine min/max op whose map has exactly one result into an
/// equivalent affine.apply: min/max of a single expression is the expression
/// itself.
template <typename T>
struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    if (affineOp.getMap().getNumResults() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
                                               affineOp.getOperands());
    return success();
  }
};
3493 | |
3494 | //===----------------------------------------------------------------------===// |
3495 | // AffineMinOp |
3496 | //===----------------------------------------------------------------------===// |
3497 | // |
3498 | // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) |
3499 | // |
3500 | |
/// Folds affine.min via the shared min/max folder.
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3504 | |
/// Registers the affine.min canonicalization patterns defined above.
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
               MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
      context);
}
3513 | |
/// Delegates to the shared min/max operand-count verifier.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3515 | |
/// Delegates to the shared min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3519 | |
/// Delegates to the shared min/max printer.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3521 | |
3522 | //===----------------------------------------------------------------------===// |
3523 | // AffineMaxOp |
3524 | //===----------------------------------------------------------------------===// |
3525 | // |
3526 | // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) |
3527 | // |
3528 | |
/// Folds affine.max by delegating to the shared min/max folding helper.
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3532 | |
3533 | void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
3534 | MLIRContext *context) { |
3535 | patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>, |
3536 | DeduplicateAffineMinMaxExpressions<AffineMaxOp>, |
3537 | MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>, |
3538 | CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>( |
3539 | context); |
3540 | } |
3541 | |
/// Delegates to the shared affine min/max verifier.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3543 | |
/// Parses affine.max using the shared min/max parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3547 | |
/// Prints affine.max using the shared min/max printer.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3549 | |
3550 | //===----------------------------------------------------------------------===// |
3551 | // AffinePrefetchOp |
3552 | //===----------------------------------------------------------------------===// |
3553 | |
3554 | // |
3555 | // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> |
3556 | // |
3557 | ParseResult AffinePrefetchOp::parse(OpAsmParser &parser, |
3558 | OperationState &result) { |
3559 | auto &builder = parser.getBuilder(); |
3560 | auto indexTy = builder.getIndexType(); |
3561 | |
3562 | MemRefType type; |
3563 | OpAsmParser::UnresolvedOperand memrefInfo; |
3564 | IntegerAttr hintInfo; |
3565 | auto i32Type = parser.getBuilder().getIntegerType(32); |
3566 | StringRef readOrWrite, cacheType; |
3567 | |
3568 | AffineMapAttr mapAttr; |
3569 | SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; |
3570 | if (parser.parseOperand(memrefInfo) || |
3571 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
3572 | AffinePrefetchOp::getMapAttrStrName(), |
3573 | result.attributes) || |
3574 | parser.parseComma() || parser.parseKeyword(&readOrWrite) || |
3575 | parser.parseComma() || parser.parseKeyword("locality" ) || |
3576 | parser.parseLess() || |
3577 | parser.parseAttribute(hintInfo, i32Type, |
3578 | AffinePrefetchOp::getLocalityHintAttrStrName(), |
3579 | result.attributes) || |
3580 | parser.parseGreater() || parser.parseComma() || |
3581 | parser.parseKeyword(&cacheType) || |
3582 | parser.parseOptionalAttrDict(result.attributes) || |
3583 | parser.parseColonType(type) || |
3584 | parser.resolveOperand(memrefInfo, type, result.operands) || |
3585 | parser.resolveOperands(mapOperands, indexTy, result.operands)) |
3586 | return failure(); |
3587 | |
3588 | if (!readOrWrite.equals("read" ) && !readOrWrite.equals("write" )) |
3589 | return parser.emitError(parser.getNameLoc(), |
3590 | "rw specifier has to be 'read' or 'write'" ); |
3591 | result.addAttribute( |
3592 | AffinePrefetchOp::getIsWriteAttrStrName(), |
3593 | parser.getBuilder().getBoolAttr(readOrWrite.equals("write" ))); |
3594 | |
3595 | if (!cacheType.equals("data" ) && !cacheType.equals("instr" )) |
3596 | return parser.emitError(parser.getNameLoc(), |
3597 | "cache type has to be 'data' or 'instr'" ); |
3598 | |
3599 | result.addAttribute( |
3600 | AffinePrefetchOp::getIsDataCacheAttrStrName(), |
3601 | parser.getBuilder().getBoolAttr(cacheType.equals("data" ))); |
3602 | |
3603 | return success(); |
3604 | } |
3605 | |
/// Prints affine.prefetch in its custom form:
///   affine.prefetch %memref[<map of ssa ids>], read|write, locality<N>,
///   data|instr : memref-type
void AffinePrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  AffineMapAttr mapAttr =
      (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr)
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
    << "locality<" << getLocalityHint() << ">, "
    << (getIsDataCache() ? "data" : "instr");
  // Attributes already encoded in the custom syntax are elided.
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
                       getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
  p << " : " << getMemRefType();
}
3621 | |
/// Verifies affine.prefetch: the index map (when present) must produce one
/// result per memref dimension and take one operand per map input (plus the
/// memref itself), and every index operand must be a valid affine dim or
/// symbol in the enclosing affine scope.
LogicalResult AffinePrefetchOp::verify() {
  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr) {
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != getMemRefType().getRank())
      return emitOpError("affine.prefetch affine map num results must equal"
                         " memref rank");
    // Operands are the memref plus one value per map input.
    if (map.getNumInputs() + 1 != getNumOperands())
      return emitOpError("too few operands");
  } else {
    if (getNumOperands() != 1)
      return emitOpError("too few operands");
  }

  Region *scope = getAffineScope(*this);
  for (auto idx : getMapOperands()) {
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError(
          "index must be a valid dimension or symbol identifier");
  }
  return success();
}
3644 | |
/// Registers canonicalization patterns for affine.prefetch.
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3650 | |
/// Folds a memref.cast feeding the prefetched memref into the op itself.
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return memref::foldMemRefCast(*this);
}
3656 | |
3657 | //===----------------------------------------------------------------------===// |
3658 | // AffineParallelOp |
3659 | //===----------------------------------------------------------------------===// |
3660 | |
3661 | void AffineParallelOp::build(OpBuilder &builder, OperationState &result, |
3662 | TypeRange resultTypes, |
3663 | ArrayRef<arith::AtomicRMWKind> reductions, |
3664 | ArrayRef<int64_t> ranges) { |
3665 | SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0)); |
3666 | auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) { |
3667 | return builder.getConstantAffineMap(value); |
3668 | })); |
3669 | SmallVector<int64_t> steps(ranges.size(), 1); |
3670 | build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs, |
3671 | /*ubArgs=*/{}, steps); |
3672 | } |
3673 | |
/// Builds an affine.parallel from per-IV lower/upper bound maps and their
/// shared bound operands, plus one constant step per IV and one reduction per
/// result. The per-IV maps are flattened into one map per side, with a
/// `groups` attribute recording how many results belong to each IV.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrStrName(),
                      builder.getArrayAttr(reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    if (maps.empty())
      return AffineMap::get(builder.getContext());
    SmallVector<AffineExpr> exprs;
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    for (AffineMap m : maps) {
      llvm::append_range(exprs, m.getResults());
      groups.push_back(m.getNumResults());
    }
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
                          maps[0].getContext());
  };

  // Set up the bounds.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(lbGroups));
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(ubGroups));
  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  Block *body = builder.createBlock(bodyRegion);

  // Add all the block arguments.
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(IndexType::get(builder.getContext()), result.location);
  // Only ops without results get an implicit terminator; ops with results
  // need an explicit affine.yield of the reduced values.
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
3755 | |
/// Returns the single region holding the loop body.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}
3759 | |
/// Returns the number of induction variables (one step per IV).
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3761 | |
/// Lower bound operands are the leading operands, one per lower bounds map
/// input.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}
3765 | |
/// Upper bound operands are everything after the lower bound operands.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
3769 | |
/// Returns the lower bound map for the IV at `pos`, sliced out of the
/// flattened lower bounds map using the groups attribute.
AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
  auto values = getLowerBoundsGroups().getValues<int32_t>();
  unsigned start = 0;
  // Sum the group sizes before `pos` to find the slice offset.
  for (unsigned i = 0; i < pos; ++i)
    start += values[i];
  return getLowerBoundsMap().getSliceMap(start, values[pos]);
}
3777 | |
/// Returns the upper bound map for the IV at `pos`, sliced out of the
/// flattened upper bounds map using the groups attribute.
AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
  auto values = getUpperBoundsGroups().getValues<int32_t>();
  unsigned start = 0;
  // Sum the group sizes before `pos` to find the slice offset.
  for (unsigned i = 0; i < pos; ++i)
    start += values[i];
  return getUpperBoundsMap().getSliceMap(start, values[pos]);
}
3785 | |
/// Bundles the flattened lower bounds map with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
3789 | |
/// Bundles the flattened upper bounds map with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
3793 | |
3794 | std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() { |
3795 | if (hasMinMaxBounds()) |
3796 | return std::nullopt; |
3797 | |
3798 | // Try to convert all the ranges to constant expressions. |
3799 | SmallVector<int64_t, 8> out; |
3800 | AffineValueMap rangesValueMap; |
3801 | AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), |
3802 | &rangesValueMap); |
3803 | out.reserve(rangesValueMap.getNumResults()); |
3804 | for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) { |
3805 | auto expr = rangesValueMap.getResult(i); |
3806 | auto cst = dyn_cast<AffineConstantExpr>(expr); |
3807 | if (!cst) |
3808 | return std::nullopt; |
3809 | out.push_back(cst.getValue()); |
3810 | } |
3811 | return out; |
3812 | } |
3813 | |
/// Returns the single block of the body region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3815 | |
/// Returns a builder inserting into the body just before its last operation.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
3819 | |
3820 | void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) { |
3821 | assert(lbOperands.size() == map.getNumInputs() && |
3822 | "operands to map must match number of inputs" ); |
3823 | |
3824 | auto ubOperands = getUpperBoundsOperands(); |
3825 | |
3826 | SmallVector<Value, 4> newOperands(lbOperands); |
3827 | newOperands.append(ubOperands.begin(), ubOperands.end()); |
3828 | (*this)->setOperands(newOperands); |
3829 | |
3830 | setLowerBoundsMapAttr(AffineMapAttr::get(map)); |
3831 | } |
3832 | |
3833 | void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) { |
3834 | assert(ubOperands.size() == map.getNumInputs() && |
3835 | "operands to map must match number of inputs" ); |
3836 | |
3837 | SmallVector<Value, 4> newOperands(getLowerBoundsOperands()); |
3838 | newOperands.append(ubOperands.begin(), ubOperands.end()); |
3839 | (*this)->setOperands(newOperands); |
3840 | |
3841 | setUpperBoundsMapAttr(AffineMapAttr::get(map)); |
3842 | } |
3843 | |
/// Replaces the constant steps, one per induction variable.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
3847 | |
3848 | // check whether resultType match op or not in affine.parallel |
3849 | static bool isResultTypeMatchAtomicRMWKind(Type resultType, |
3850 | arith::AtomicRMWKind op) { |
3851 | switch (op) { |
3852 | case arith::AtomicRMWKind::addf: |
3853 | return isa<FloatType>(Val: resultType); |
3854 | case arith::AtomicRMWKind::addi: |
3855 | return isa<IntegerType>(Val: resultType); |
3856 | case arith::AtomicRMWKind::assign: |
3857 | return true; |
3858 | case arith::AtomicRMWKind::mulf: |
3859 | return isa<FloatType>(Val: resultType); |
3860 | case arith::AtomicRMWKind::muli: |
3861 | return isa<IntegerType>(Val: resultType); |
3862 | case arith::AtomicRMWKind::maximumf: |
3863 | return isa<FloatType>(Val: resultType); |
3864 | case arith::AtomicRMWKind::minimumf: |
3865 | return isa<FloatType>(Val: resultType); |
3866 | case arith::AtomicRMWKind::maxs: { |
3867 | auto intType = llvm::dyn_cast<IntegerType>(resultType); |
3868 | return intType && intType.isSigned(); |
3869 | } |
3870 | case arith::AtomicRMWKind::mins: { |
3871 | auto intType = llvm::dyn_cast<IntegerType>(resultType); |
3872 | return intType && intType.isSigned(); |
3873 | } |
3874 | case arith::AtomicRMWKind::maxu: { |
3875 | auto intType = llvm::dyn_cast<IntegerType>(resultType); |
3876 | return intType && intType.isUnsigned(); |
3877 | } |
3878 | case arith::AtomicRMWKind::minu: { |
3879 | auto intType = llvm::dyn_cast<IntegerType>(resultType); |
3880 | return intType && intType.isUnsigned(); |
3881 | } |
3882 | case arith::AtomicRMWKind::ori: |
3883 | return isa<IntegerType>(Val: resultType); |
3884 | case arith::AtomicRMWKind::andi: |
3885 | return isa<IntegerType>(Val: resultType); |
3886 | default: |
3887 | return false; |
3888 | } |
3889 | } |
3890 | |
/// Verifies affine.parallel: the region argument count, bound group counts,
/// and step count must all agree; the flattened bound maps must have as many
/// results as their groups attributes claim; each op result needs a reduction
/// whose kind is valid and matches the result type; and all bound operands
/// must be valid affine dims/symbols.
LogicalResult AffineParallelOp::verify() {
  auto numDims = getNumDims();
  if (getLowerBoundsGroups().getNumElements() != numDims ||
      getUpperBoundsGroups().getNumElements() != numDims ||
      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
    return emitOpError() << "the number of region arguments ("
                         << getBody()->getNumArguments()
                         << ") and the number of map groups for lower ("
                         << getLowerBoundsGroups().getNumElements()
                         << ") and upper bound ("
                         << getUpperBoundsGroups().getNumElements()
                         << "), and the number of steps (" << getSteps().size()
                         << ") must all match";
  }

  // The flattened maps must provide exactly the results the groups promise.
  unsigned expectedNumLBResults = 0;
  for (APInt v : getLowerBoundsGroups())
    expectedNumLBResults += v.getZExtValue();
  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
    return emitOpError() << "expected lower bounds map to have "
                         << expectedNumLBResults << " results";
  unsigned expectedNumUBResults = 0;
  for (APInt v : getUpperBoundsGroups())
    expectedNumUBResults += v.getZExtValue();
  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
    return emitOpError() << "expected upper bounds map to have "
                         << expectedNumUBResults << " results";

  if (getReductions().size() != getNumResults())
    return emitOpError("a reduction must be specified for each output");

  // Verify reduction ops are all valid and each result type matches reduction
  // ops
  for (auto it : llvm::enumerate((getReductions()))) {
    Attribute attr = it.value();
    auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
      return emitOpError("invalid reduction attribute");
    auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
    if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
      return emitOpError("result type cannot match reduction attribute");
  }

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
                                           getLowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
                                           getUpperBoundsMap().getNumDims())))
    return failure();
  return success();
}
3945 | |
3946 | LogicalResult AffineValueMap::canonicalize() { |
3947 | SmallVector<Value, 4> newOperands{operands}; |
3948 | auto newMap = getAffineMap(); |
3949 | composeAffineMapAndOperands(map: &newMap, operands: &newOperands); |
3950 | if (newMap == getAffineMap() && newOperands == operands) |
3951 | return failure(); |
3952 | reset(map: newMap, operands: newOperands); |
3953 | return success(); |
3954 | } |
3955 | |
3956 | /// Canonicalize the bounds of the given loop. |
3957 | static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) { |
3958 | AffineValueMap lb = op.getLowerBoundsValueMap(); |
3959 | bool lbCanonicalized = succeeded(result: lb.canonicalize()); |
3960 | |
3961 | AffineValueMap ub = op.getUpperBoundsValueMap(); |
3962 | bool ubCanonicalized = succeeded(result: ub.canonicalize()); |
3963 | |
3964 | // Any canonicalization change always leads to updated map(s). |
3965 | if (!lbCanonicalized && !ubCanonicalized) |
3966 | return failure(); |
3967 | |
3968 | if (lbCanonicalized) |
3969 | op.setLowerBounds(lb.getOperands(), lb.getAffineMap()); |
3970 | if (ubCanonicalized) |
3971 | op.setUpperBounds(ub.getOperands(), ub.getAffineMap()); |
3972 | |
3973 | return success(); |
3974 | } |
3975 | |
/// Folding an affine.parallel canonicalizes its bound maps/operands in place.
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return canonicalizeLoopBounds(*this);
}
3980 | |
3981 | /// Prints a lower(upper) bound of an affine parallel loop with max(min) |
3982 | /// conditions in it. `mapAttr` is a flat list of affine expressions and `group` |
3983 | /// identifies which of the those expressions form max/min groups. `operands` |
3984 | /// are the SSA values of dimensions and symbols and `keyword` is either "min" |
3985 | /// or "max". |
3986 | static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, |
3987 | DenseIntElementsAttr group, ValueRange operands, |
3988 | StringRef keyword) { |
3989 | AffineMap map = mapAttr.getValue(); |
3990 | unsigned numDims = map.getNumDims(); |
3991 | ValueRange dimOperands = operands.take_front(n: numDims); |
3992 | ValueRange symOperands = operands.drop_front(n: numDims); |
3993 | unsigned start = 0; |
3994 | for (llvm::APInt groupSize : group) { |
3995 | if (start != 0) |
3996 | p << ", " ; |
3997 | |
3998 | unsigned size = groupSize.getZExtValue(); |
3999 | if (size == 1) { |
4000 | p.printAffineExprOfSSAIds(expr: map.getResult(idx: start), dimOperands, symOperands); |
4001 | ++start; |
4002 | } else { |
4003 | p << keyword << '('; |
4004 | AffineMap submap = map.getSliceMap(start, length: size); |
4005 | p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands); |
4006 | p << ')'; |
4007 | start += size; |
4008 | } |
4009 | } |
4010 | } |
4011 | |
/// Prints affine.parallel: IVs, max-grouped lower bounds, min-grouped upper
/// bounds, steps (elided when all are 1), an optional reduce clause with
/// result types, and the body region. Attributes encoded in the custom syntax
/// are elided from the trailing attribute dictionary.
void AffineParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getBody()->getArguments() << ") = (";
  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
                   getLowerBoundsOperands(), "max");
  p << ") to (";
  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
                   getUpperBoundsOperands(), "min");
  p << ')';
  SmallVector<int64_t, 8> steps = getSteps();
  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
  if (!elideSteps) {
    p << " step (";
    llvm::interleaveComma(steps, p);
    p << ')';
  }
  if (getNumResults()) {
    p << " reduce (";
    llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
      arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
          llvm::cast<IntegerAttr>(attr).getInt());
      p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
    });
    p << ") -> (" << getResultTypes() << ")";
  }

  p << ' ';
  // Terminators are printed only when the op has results (explicit yields).
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
                       AffineParallelOp::getLowerBoundsMapAttrStrName(),
                       AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
                       AffineParallelOp::getUpperBoundsMapAttrStrName(),
                       AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
                       AffineParallelOp::getStepsAttrStrName()});
}
4049 | |
4050 | /// Given a list of lists of parsed operands, populates `uniqueOperands` with |
4051 | /// unique operands. Also populates `replacements with affine expressions of |
4052 | /// `kind` that can be used to update affine maps previously accepting a |
4053 | /// `operands` to accept `uniqueOperands` instead. |
4054 | static ParseResult deduplicateAndResolveOperands( |
4055 | OpAsmParser &parser, |
4056 | ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands, |
4057 | SmallVectorImpl<Value> &uniqueOperands, |
4058 | SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) { |
4059 | assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) && |
4060 | "expected operands to be dim or symbol expression" ); |
4061 | |
4062 | Type indexType = parser.getBuilder().getIndexType(); |
4063 | for (const auto &list : operands) { |
4064 | SmallVector<Value> valueOperands; |
4065 | if (parser.resolveOperands(operands: list, type: indexType, result&: valueOperands)) |
4066 | return failure(); |
4067 | for (Value operand : valueOperands) { |
4068 | unsigned pos = std::distance(first: uniqueOperands.begin(), |
4069 | last: llvm::find(Range&: uniqueOperands, Val: operand)); |
4070 | if (pos == uniqueOperands.size()) |
4071 | uniqueOperands.push_back(Elt: operand); |
4072 | replacements.push_back( |
4073 | Elt: kind == AffineExprKind::DimId |
4074 | ? getAffineDimExpr(position: pos, context: parser.getContext()) |
4075 | : getAffineSymbolExpr(position: pos, context: parser.getContext())); |
4076 | } |
4077 | } |
4078 | return success(); |
4079 | } |
4080 | |
namespace {
/// Selects min-grouped (upper) vs. max-grouped (lower) parallel bound parsing.
enum class MinMaxKind { Min, Max };
} // namespace
4084 | |
4085 | /// Parses an affine map that can contain a min/max for groups of its results, |
4086 | /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates |
4087 | /// `result` attributes with the map (flat list of expressions) and the grouping |
4088 | /// (list of integers that specify how many expressions to put into each |
4089 | /// min/max) attributes. Deduplicates repeated operands. |
4090 | /// |
4091 | /// parallel-bound ::= `(` parallel-group-list `)` |
4092 | /// parallel-group-list ::= parallel-group (`,` parallel-group-list)? |
4093 | /// parallel-group ::= simple-group | min-max-group |
4094 | /// simple-group ::= expr-of-ssa-ids |
4095 | /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)` |
4096 | /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)? |
4097 | /// |
4098 | /// Examples: |
4099 | /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6)) |
4100 | /// (%0, max(%1 - 2 * %2)) |
4101 | static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, |
4102 | OperationState &result, |
4103 | MinMaxKind kind) { |
4104 | // Using `const` not `constexpr` below to workaround a MSVC optimizer bug, |
4105 | // see: https://reviews.llvm.org/D134227#3821753 |
4106 | const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map" ; |
4107 | |
4108 | StringRef mapName = kind == MinMaxKind::Min |
4109 | ? AffineParallelOp::getUpperBoundsMapAttrStrName() |
4110 | : AffineParallelOp::getLowerBoundsMapAttrStrName(); |
4111 | StringRef groupsName = |
4112 | kind == MinMaxKind::Min |
4113 | ? AffineParallelOp::getUpperBoundsGroupsAttrStrName() |
4114 | : AffineParallelOp::getLowerBoundsGroupsAttrStrName(); |
4115 | |
4116 | if (failed(result: parser.parseLParen())) |
4117 | return failure(); |
4118 | |
4119 | if (succeeded(result: parser.parseOptionalRParen())) { |
4120 | result.addAttribute( |
4121 | mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap())); |
4122 | result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr(values: {})); |
4123 | return success(); |
4124 | } |
4125 | |
4126 | SmallVector<AffineExpr> flatExprs; |
4127 | SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands; |
4128 | SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands; |
4129 | SmallVector<int32_t> numMapsPerGroup; |
4130 | SmallVector<OpAsmParser::UnresolvedOperand> mapOperands; |
4131 | auto parseOperands = [&]() { |
4132 | if (succeeded(result: parser.parseOptionalKeyword( |
4133 | keyword: kind == MinMaxKind::Min ? "min" : "max" ))) { |
4134 | mapOperands.clear(); |
4135 | AffineMapAttr map; |
4136 | if (failed(parser.parseAffineMapOfSSAIds(operands&: mapOperands, map&: map, attrName: tmpAttrStrName, |
4137 | attrs&: result.attributes, |
4138 | delimiter: OpAsmParser::Delimiter::Paren))) |
4139 | return failure(); |
4140 | result.attributes.erase(name: tmpAttrStrName); |
4141 | llvm::append_range(flatExprs, map.getValue().getResults()); |
4142 | auto operandsRef = llvm::ArrayRef(mapOperands); |
4143 | auto dimsRef = operandsRef.take_front(N: map.getValue().getNumDims()); |
4144 | SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(), |
4145 | dimsRef.end()); |
4146 | auto symsRef = operandsRef.drop_front(N: map.getValue().getNumDims()); |
4147 | SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(), |
4148 | symsRef.end()); |
4149 | flatDimOperands.append(map.getValue().getNumResults(), dims); |
4150 | flatSymOperands.append(map.getValue().getNumResults(), syms); |
4151 | numMapsPerGroup.push_back(Elt: map.getValue().getNumResults()); |
4152 | } else { |
4153 | if (failed(result: parser.parseAffineExprOfSSAIds(dimOperands&: flatDimOperands.emplace_back(), |
4154 | symbOperands&: flatSymOperands.emplace_back(), |
4155 | expr&: flatExprs.emplace_back()))) |
4156 | return failure(); |
4157 | numMapsPerGroup.push_back(Elt: 1); |
4158 | } |
4159 | return success(); |
4160 | }; |
4161 | if (parser.parseCommaSeparatedList(parseElementFn: parseOperands) || parser.parseRParen()) |
4162 | return failure(); |
4163 | |
4164 | unsigned totalNumDims = 0; |
4165 | unsigned totalNumSyms = 0; |
4166 | for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { |
4167 | unsigned numDims = flatDimOperands[i].size(); |
4168 | unsigned numSyms = flatSymOperands[i].size(); |
4169 | flatExprs[i] = flatExprs[i] |
4170 | .shiftDims(numDims, shift: totalNumDims) |
4171 | .shiftSymbols(numSymbols: numSyms, shift: totalNumSyms); |
4172 | totalNumDims += numDims; |
4173 | totalNumSyms += numSyms; |
4174 | } |
4175 | |
4176 | // Deduplicate map operands. |
4177 | SmallVector<Value> dimOperands, symOperands; |
4178 | SmallVector<AffineExpr> dimRplacements, symRepacements; |
4179 | if (deduplicateAndResolveOperands(parser, operands: flatDimOperands, uniqueOperands&: dimOperands, |
4180 | replacements&: dimRplacements, kind: AffineExprKind::DimId) || |
4181 | deduplicateAndResolveOperands(parser, operands: flatSymOperands, uniqueOperands&: symOperands, |
4182 | replacements&: symRepacements, kind: AffineExprKind::SymbolId)) |
4183 | return failure(); |
4184 | |
4185 | result.operands.append(in_start: dimOperands.begin(), in_end: dimOperands.end()); |
4186 | result.operands.append(in_start: symOperands.begin(), in_end: symOperands.end()); |
4187 | |
4188 | Builder &builder = parser.getBuilder(); |
4189 | auto flatMap = AffineMap::get(dimCount: totalNumDims, symbolCount: totalNumSyms, results: flatExprs, |
4190 | context: parser.getContext()); |
4191 | flatMap = flatMap.replaceDimsAndSymbols( |
4192 | dimReplacements: dimRplacements, symReplacements: symRepacements, numResultDims: dimOperands.size(), numResultSyms: symOperands.size()); |
4193 | |
4194 | result.addAttribute(mapName, AffineMapAttr::get(flatMap)); |
4195 | result.addAttribute(groupsName, builder.getI32TensorAttr(values: numMapsPerGroup)); |
4196 | return success(); |
4197 | } |
4198 | |
4199 | // |
4200 | // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound |
4201 | // `to` parallel-bound steps? region attr-dict? |
4202 | // steps ::= `steps` `(` integer-literals `)` |
4203 | // |
4204 | ParseResult AffineParallelOp::parse(OpAsmParser &parser, |
4205 | OperationState &result) { |
4206 | auto &builder = parser.getBuilder(); |
4207 | auto indexType = builder.getIndexType(); |
4208 | SmallVector<OpAsmParser::Argument, 4> ivs; |
4209 | if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || |
4210 | parser.parseEqual() || |
4211 | parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) || |
4212 | parser.parseKeyword("to" ) || |
4213 | parseAffineMapWithMinMax(parser, result, MinMaxKind::Min)) |
4214 | return failure(); |
4215 | |
4216 | AffineMapAttr stepsMapAttr; |
4217 | NamedAttrList stepsAttrs; |
4218 | SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands; |
4219 | if (failed(parser.parseOptionalKeyword("step" ))) { |
4220 | SmallVector<int64_t, 4> steps(ivs.size(), 1); |
4221 | result.addAttribute(AffineParallelOp::getStepsAttrStrName(), |
4222 | builder.getI64ArrayAttr(steps)); |
4223 | } else { |
4224 | if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr, |
4225 | AffineParallelOp::getStepsAttrStrName(), |
4226 | stepsAttrs, |
4227 | OpAsmParser::Delimiter::Paren)) |
4228 | return failure(); |
4229 | |
4230 | // Convert steps from an AffineMap into an I64ArrayAttr. |
4231 | SmallVector<int64_t, 4> steps; |
4232 | auto stepsMap = stepsMapAttr.getValue(); |
4233 | for (const auto &result : stepsMap.getResults()) { |
4234 | auto constExpr = dyn_cast<AffineConstantExpr>(result); |
4235 | if (!constExpr) |
4236 | return parser.emitError(parser.getNameLoc(), |
4237 | "steps must be constant integers" ); |
4238 | steps.push_back(constExpr.getValue()); |
4239 | } |
4240 | result.addAttribute(AffineParallelOp::getStepsAttrStrName(), |
4241 | builder.getI64ArrayAttr(steps)); |
4242 | } |
4243 | |
4244 | // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the |
4245 | // quoted strings are a member of the enum AtomicRMWKind. |
4246 | SmallVector<Attribute, 4> reductions; |
4247 | if (succeeded(parser.parseOptionalKeyword("reduce" ))) { |
4248 | if (parser.parseLParen()) |
4249 | return failure(); |
4250 | auto parseAttributes = [&]() -> ParseResult { |
4251 | // Parse a single quoted string via the attribute parsing, and then |
4252 | // verify it is a member of the enum and convert to it's integer |
4253 | // representation. |
4254 | StringAttr attrVal; |
4255 | NamedAttrList attrStorage; |
4256 | auto loc = parser.getCurrentLocation(); |
4257 | if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce" , |
4258 | attrStorage)) |
4259 | return failure(); |
4260 | std::optional<arith::AtomicRMWKind> reduction = |
4261 | arith::symbolizeAtomicRMWKind(attrVal.getValue()); |
4262 | if (!reduction) |
4263 | return parser.emitError(loc, "invalid reduction value: " ) << attrVal; |
4264 | reductions.push_back( |
4265 | builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value()))); |
4266 | // While we keep getting commas, keep parsing. |
4267 | return success(); |
4268 | }; |
4269 | if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen()) |
4270 | return failure(); |
4271 | } |
4272 | result.addAttribute(AffineParallelOp::getReductionsAttrStrName(), |
4273 | builder.getArrayAttr(reductions)); |
4274 | |
4275 | // Parse return types of reductions (if any) |
4276 | if (parser.parseOptionalArrowTypeList(result.types)) |
4277 | return failure(); |
4278 | |
4279 | // Now parse the body. |
4280 | Region *body = result.addRegion(); |
4281 | for (auto &iv : ivs) |
4282 | iv.type = indexType; |
4283 | if (parser.parseRegion(*body, ivs) || |
4284 | parser.parseOptionalAttrDict(result.attributes)) |
4285 | return failure(); |
4286 | |
4287 | // Add a terminator if none was parsed. |
4288 | AffineParallelOp::ensureTerminator(*body, builder, result.location); |
4289 | return success(); |
4290 | } |
4291 | |
4292 | //===----------------------------------------------------------------------===// |
4293 | // AffineYieldOp |
4294 | //===----------------------------------------------------------------------===// |
4295 | |
4296 | LogicalResult AffineYieldOp::verify() { |
4297 | auto *parentOp = (*this)->getParentOp(); |
4298 | auto results = parentOp->getResults(); |
4299 | auto operands = getOperands(); |
4300 | |
4301 | if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp)) |
4302 | return emitOpError() << "only terminates affine.if/for/parallel regions" ; |
4303 | if (parentOp->getNumResults() != getNumOperands()) |
4304 | return emitOpError() << "parent of yield must have same number of " |
4305 | "results as the yield operands" ; |
4306 | for (auto it : llvm::zip(results, operands)) { |
4307 | if (std::get<0>(it).getType() != std::get<1>(it).getType()) |
4308 | return emitOpError() << "types mismatch between yield op and its parent" ; |
4309 | } |
4310 | |
4311 | return success(); |
4312 | } |
4313 | |
4314 | //===----------------------------------------------------------------------===// |
4315 | // AffineVectorLoadOp |
4316 | //===----------------------------------------------------------------------===// |
4317 | |
4318 | void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, |
4319 | VectorType resultType, AffineMap map, |
4320 | ValueRange operands) { |
4321 | assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands" ); |
4322 | result.addOperands(operands); |
4323 | if (map) |
4324 | result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); |
4325 | result.types.push_back(resultType); |
4326 | } |
4327 | |
4328 | void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, |
4329 | VectorType resultType, Value memref, |
4330 | AffineMap map, ValueRange mapOperands) { |
4331 | assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info" ); |
4332 | result.addOperands(memref); |
4333 | result.addOperands(mapOperands); |
4334 | result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); |
4335 | result.types.push_back(resultType); |
4336 | } |
4337 | |
4338 | void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, |
4339 | VectorType resultType, Value memref, |
4340 | ValueRange indices) { |
4341 | auto memrefType = llvm::cast<MemRefType>(memref.getType()); |
4342 | int64_t rank = memrefType.getRank(); |
4343 | // Create identity map for memrefs with at least one dimension or () -> () |
4344 | // for zero-dimensional memrefs. |
4345 | auto map = |
4346 | rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
4347 | build(builder, result, resultType, memref, map, indices); |
4348 | } |
4349 | |
/// Registers the map-simplification/operand-composition canonicalization
/// pattern for affine.vector_load.
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
4354 | |
4355 | ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser, |
4356 | OperationState &result) { |
4357 | auto &builder = parser.getBuilder(); |
4358 | auto indexTy = builder.getIndexType(); |
4359 | |
4360 | MemRefType memrefType; |
4361 | VectorType resultType; |
4362 | OpAsmParser::UnresolvedOperand memrefInfo; |
4363 | AffineMapAttr mapAttr; |
4364 | SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; |
4365 | return failure( |
4366 | parser.parseOperand(memrefInfo) || |
4367 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
4368 | AffineVectorLoadOp::getMapAttrStrName(), |
4369 | result.attributes) || |
4370 | parser.parseOptionalAttrDict(result.attributes) || |
4371 | parser.parseColonType(memrefType) || parser.parseComma() || |
4372 | parser.parseType(resultType) || |
4373 | parser.resolveOperand(memrefInfo, memrefType, result.operands) || |
4374 | parser.resolveOperands(mapOperands, indexTy, result.operands) || |
4375 | parser.addTypeToList(resultType, result.types)); |
4376 | } |
4377 | |
4378 | void AffineVectorLoadOp::print(OpAsmPrinter &p) { |
4379 | p << " " << getMemRef() << '['; |
4380 | if (AffineMapAttr mapAttr = |
4381 | (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName())) |
4382 | p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); |
4383 | p << ']'; |
4384 | p.printOptionalAttrDict((*this)->getAttrs(), |
4385 | /*elidedAttrs=*/{getMapAttrStrName()}); |
4386 | p << " : " << getMemRefType() << ", " << getType(); |
4387 | } |
4388 | |
4389 | /// Verify common invariants of affine.vector_load and affine.vector_store. |
4390 | static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, |
4391 | VectorType vectorType) { |
4392 | // Check that memref and vector element types match. |
4393 | if (memrefType.getElementType() != vectorType.getElementType()) |
4394 | return op->emitOpError( |
4395 | message: "requires memref and vector types of the same elemental type" ); |
4396 | return success(); |
4397 | } |
4398 | |
4399 | LogicalResult AffineVectorLoadOp::verify() { |
4400 | MemRefType memrefType = getMemRefType(); |
4401 | if (failed(verifyMemoryOpIndexing( |
4402 | getOperation(), |
4403 | (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()), |
4404 | getMapOperands(), memrefType, |
4405 | /*numIndexOperands=*/getNumOperands() - 1))) |
4406 | return failure(); |
4407 | |
4408 | if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType()))) |
4409 | return failure(); |
4410 | |
4411 | return success(); |
4412 | } |
4413 | |
4414 | //===----------------------------------------------------------------------===// |
4415 | // AffineVectorStoreOp |
4416 | //===----------------------------------------------------------------------===// |
4417 | |
4418 | void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, |
4419 | Value valueToStore, Value memref, AffineMap map, |
4420 | ValueRange mapOperands) { |
4421 | assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info" ); |
4422 | result.addOperands(valueToStore); |
4423 | result.addOperands(memref); |
4424 | result.addOperands(mapOperands); |
4425 | result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map)); |
4426 | } |
4427 | |
4428 | // Use identity map. |
4429 | void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, |
4430 | Value valueToStore, Value memref, |
4431 | ValueRange indices) { |
4432 | auto memrefType = llvm::cast<MemRefType>(memref.getType()); |
4433 | int64_t rank = memrefType.getRank(); |
4434 | // Create identity map for memrefs with at least one dimension or () -> () |
4435 | // for zero-dimensional memrefs. |
4436 | auto map = |
4437 | rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
4438 | build(builder, result, valueToStore, memref, map, indices); |
4439 | } |
/// Registers the map-simplification/operand-composition canonicalization
/// pattern for affine.vector_store.
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
4444 | |
4445 | ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser, |
4446 | OperationState &result) { |
4447 | auto indexTy = parser.getBuilder().getIndexType(); |
4448 | |
4449 | MemRefType memrefType; |
4450 | VectorType resultType; |
4451 | OpAsmParser::UnresolvedOperand storeValueInfo; |
4452 | OpAsmParser::UnresolvedOperand memrefInfo; |
4453 | AffineMapAttr mapAttr; |
4454 | SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands; |
4455 | return failure( |
4456 | parser.parseOperand(storeValueInfo) || parser.parseComma() || |
4457 | parser.parseOperand(memrefInfo) || |
4458 | parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
4459 | AffineVectorStoreOp::getMapAttrStrName(), |
4460 | result.attributes) || |
4461 | parser.parseOptionalAttrDict(result.attributes) || |
4462 | parser.parseColonType(memrefType) || parser.parseComma() || |
4463 | parser.parseType(resultType) || |
4464 | parser.resolveOperand(storeValueInfo, resultType, result.operands) || |
4465 | parser.resolveOperand(memrefInfo, memrefType, result.operands) || |
4466 | parser.resolveOperands(mapOperands, indexTy, result.operands)); |
4467 | } |
4468 | |
4469 | void AffineVectorStoreOp::print(OpAsmPrinter &p) { |
4470 | p << " " << getValueToStore(); |
4471 | p << ", " << getMemRef() << '['; |
4472 | if (AffineMapAttr mapAttr = |
4473 | (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName())) |
4474 | p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); |
4475 | p << ']'; |
4476 | p.printOptionalAttrDict((*this)->getAttrs(), |
4477 | /*elidedAttrs=*/{getMapAttrStrName()}); |
4478 | p << " : " << getMemRefType() << ", " << getValueToStore().getType(); |
4479 | } |
4480 | |
4481 | LogicalResult AffineVectorStoreOp::verify() { |
4482 | MemRefType memrefType = getMemRefType(); |
4483 | if (failed(verifyMemoryOpIndexing( |
4484 | *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()), |
4485 | getMapOperands(), memrefType, |
4486 | /*numIndexOperands=*/getNumOperands() - 2))) |
4487 | return failure(); |
4488 | |
4489 | if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType()))) |
4490 | return failure(); |
4491 | |
4492 | return success(); |
4493 | } |
4494 | |
4495 | //===----------------------------------------------------------------------===// |
4496 | // DelinearizeIndexOp |
4497 | //===----------------------------------------------------------------------===// |
4498 | |
4499 | LogicalResult AffineDelinearizeIndexOp::inferReturnTypes( |
4500 | MLIRContext *context, std::optional<::mlir::Location> location, |
4501 | ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, |
4502 | RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { |
4503 | AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties, |
4504 | regions); |
4505 | inferredReturnTypes.assign(adaptor.getBasis().size(), |
4506 | IndexType::get(context)); |
4507 | return success(); |
4508 | } |
4509 | |
4510 | void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result, |
4511 | Value linearIndex, |
4512 | ArrayRef<OpFoldResult> basis) { |
4513 | result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType())); |
4514 | result.addOperands(linearIndex); |
4515 | SmallVector<Value> basisValues = |
4516 | llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value { |
4517 | std::optional<int64_t> staticDim = getConstantIntValue(ofr); |
4518 | if (staticDim.has_value()) |
4519 | return builder.create<arith::ConstantIndexOp>(result.location, |
4520 | *staticDim); |
4521 | return llvm::dyn_cast_if_present<Value>(ofr); |
4522 | }); |
4523 | result.addOperands(basisValues); |
4524 | } |
4525 | |
4526 | LogicalResult AffineDelinearizeIndexOp::verify() { |
4527 | if (getBasis().empty()) |
4528 | return emitOpError("basis should not be empty" ); |
4529 | if (getNumResults() != getBasis().size()) |
4530 | return emitOpError("should return an index for each basis element" ); |
4531 | return success(); |
4532 | } |
4533 | |
4534 | //===----------------------------------------------------------------------===// |
4535 | // TableGen'd op method definitions |
4536 | //===----------------------------------------------------------------------===// |
4537 | |
4538 | #define GET_OP_CLASSES |
4539 | #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" |
4540 | |