1//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Dialect/UB/IR/UBOps.h"
13#include "mlir/IR/AffineExprVisitor.h"
14#include "mlir/IR/IRMapping.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Interfaces/ShapedOpInterfaces.h"
20#include "mlir/Interfaces/ValueBoundsOpInterface.h"
21#include "mlir/Support/MathExtras.h"
22#include "mlir/Transforms/InliningUtils.h"
23#include "llvm/ADT/ScopeExit.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVectorExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/Debug.h"
28#include <numeric>
29#include <optional>
30
31using namespace mlir;
32using namespace mlir::affine;
33
34#define DEBUG_TYPE "affine-ops"
35
36#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
37
38/// A utility function to check if a value is defined at the top level of
39/// `region` or is an argument of `region`. A value of index type defined at the
40/// top level of a `AffineScope` region is always a valid symbol for all
41/// uses in that region.
42bool mlir::affine::isTopLevelValue(Value value, Region *region) {
43 if (auto arg = llvm::dyn_cast<BlockArgument>(value))
44 return arg.getParentRegion() == region;
45 return value.getDefiningOp()->getParentRegion() == region;
46}
47
48/// Checks if `value` known to be a legal affine dimension or symbol in `src`
49/// region remains legal if the operation that uses it is inlined into `dest`
50/// with the given value mapping. `legalityCheck` is either `isValidDim` or
51/// `isValidSymbol`, depending on the value being required to remain a valid
52/// dimension or symbol.
53static bool
54remainsLegalAfterInline(Value value, Region *src, Region *dest,
55 const IRMapping &mapping,
56 function_ref<bool(Value, Region *)> legalityCheck) {
57 // If the value is a valid dimension for any other reason than being
58 // a top-level value, it will remain valid: constants get inlined
59 // with the function, transitive affine applies also get inlined and
60 // will be checked themselves, etc.
61 if (!isTopLevelValue(value, region: src))
62 return true;
63
64 // If it's a top-level value because it's a block operand, i.e. a
65 // function argument, check whether the value replacing it after
66 // inlining is a valid dimension in the new region.
67 if (llvm::isa<BlockArgument>(Val: value))
68 return legalityCheck(mapping.lookup(from: value), dest);
69
70 // If it's a top-level value because it's defined in the region,
71 // it can only be inlined if the defining op is a constant or a
72 // `dim`, which can appear anywhere and be valid, since the defining
73 // op won't be top-level anymore after inlining.
74 Attribute operandCst;
75 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
76 return matchPattern(op: value.getDefiningOp(), pattern: m_Constant(bind_value: &operandCst)) ||
77 isDimLikeOp;
78}
79
80/// Checks if all values known to be legal affine dimensions or symbols in `src`
81/// remain so if their respective users are inlined into `dest`.
82static bool
83remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
84 const IRMapping &mapping,
85 function_ref<bool(Value, Region *)> legalityCheck) {
86 return llvm::all_of(Range&: values, P: [&](Value v) {
87 return remainsLegalAfterInline(value: v, src, dest, mapping, legalityCheck);
88 });
89}
90
91/// Checks if an affine read or write operation remains legal after inlining
92/// from `src` to `dest`.
93template <typename OpTy>
94static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
95 const IRMapping &mapping) {
96 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
97 AffineWriteOpInterface>::value,
98 "only ops with affine read/write interface are supported");
99
100 AffineMap map = op.getAffineMap();
101 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
102 ValueRange symbolOperands =
103 op.getMapOperands().take_back(map.getNumSymbols());
104 if (!remainsLegalAfterInline(
105 values: dimOperands, src, dest, mapping,
106 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidDim)))
107 return false;
108 if (!remainsLegalAfterInline(
109 values: symbolOperands, src, dest, mapping,
110 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
111 return false;
112 return true;
113}
114
115/// Checks if an affine apply operation remains legal after inlining from `src`
116/// to `dest`.
117// Use "unused attribute" marker to silence clang-tidy warning stemming from
118// the inability to see through "llvm::TypeSwitch".
119template <>
120bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
121 Region *src, Region *dest,
122 const IRMapping &mapping) {
123 // If it's a valid dimension, we need to check that it remains so.
124 if (isValidDim(op.getResult(), src))
125 return remainsLegalAfterInline(
126 op.getMapOperands(), src, dest, mapping,
127 static_cast<bool (*)(Value, Region *)>(isValidDim));
128
129 // Otherwise it must be a valid symbol, check that it remains so.
130 return remainsLegalAfterInline(
131 op.getMapOperands(), src, dest, mapping,
132 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
133}
134
135//===----------------------------------------------------------------------===//
136// AffineDialect Interfaces
137//===----------------------------------------------------------------------===//
138
139namespace {
140/// This class defines the interface for handling inlining with affine
141/// operations.
142struct AffineInlinerInterface : public DialectInlinerInterface {
143 using DialectInlinerInterface::DialectInlinerInterface;
144
145 //===--------------------------------------------------------------------===//
146 // Analysis Hooks
147 //===--------------------------------------------------------------------===//
148
149 /// Returns true if the given region 'src' can be inlined into the region
150 /// 'dest' that is attached to an operation registered to the current dialect.
151 /// 'wouldBeCloned' is set if the region is cloned into its new location
152 /// rather than moved, indicating there may be other users.
153 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
154 IRMapping &valueMapping) const final {
155 // We can inline into affine loops and conditionals if this doesn't break
156 // affine value categorization rules.
157 Operation *destOp = dest->getParentOp();
158 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
159 return false;
160
161 // Multi-block regions cannot be inlined into affine constructs, all of
162 // which require single-block regions.
163 if (!llvm::hasSingleElement(C&: *src))
164 return false;
165
166 // Side-effecting operations that the affine dialect cannot understand
167 // should not be inlined.
168 Block &srcBlock = src->front();
169 for (Operation &op : srcBlock) {
170 // Ops with no side effects are fine,
171 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
172 if (iface.hasNoEffect())
173 continue;
174 }
175
176 // Assuming the inlined region is valid, we only need to check if the
177 // inlining would change it.
178 bool remainsValid =
179 llvm::TypeSwitch<Operation *, bool>(&op)
180 .Case<AffineApplyOp, AffineReadOpInterface,
181 AffineWriteOpInterface>([&](auto op) {
182 return remainsLegalAfterInline(op, src, dest, valueMapping);
183 })
184 .Default([](Operation *) {
185 // Conservatively disallow inlining ops we cannot reason about.
186 return false;
187 });
188
189 if (!remainsValid)
190 return false;
191 }
192
193 return true;
194 }
195
196 /// Returns true if the given operation 'op', that is registered to this
197 /// dialect, can be inlined into the given region, false otherwise.
198 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
199 IRMapping &valueMapping) const final {
200 // Always allow inlining affine operations into a region that is marked as
201 // affine scope, or into affine loops and conditionals. There are some edge
202 // cases when inlining *into* affine structures, but that is handled in the
203 // other 'isLegalToInline' hook above.
204 Operation *parentOp = region->getParentOp();
205 return parentOp->hasTrait<OpTrait::AffineScope>() ||
206 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
207 }
208
209 /// Affine regions should be analyzed recursively.
210 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
211};
212} // namespace
213
214//===----------------------------------------------------------------------===//
215// AffineDialect
216//===----------------------------------------------------------------------===//
217
// Registers all affine operations, the inliner interface, and promised
// external interfaces for this dialect.
void AffineDialect::initialize() {
  // The DMA ops are defined manually; the remaining ops come from the
  // ODS-generated list pulled in via GET_OP_LIST.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface>();
  // Promise ValueBoundsOpInterface implementations that are registered
  // lazily elsewhere, avoiding a hard dependency here.
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
                            AffineMinOp>();
}
227
228/// Materialize a single constant operation from a given attribute value with
229/// the desired resultant type.
230Operation *AffineDialect::materializeConstant(OpBuilder &builder,
231 Attribute value, Type type,
232 Location loc) {
233 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
234 return builder.create<ub::PoisonOp>(loc, type, poison);
235 return arith::ConstantOp::materialize(builder, value, type, loc);
236}
237
238/// A utility function to check if a value is defined at the top level of an
239/// op with trait `AffineScope`. If the value is defined in an unlinked region,
240/// conservatively assume it is not top-level. A value of index type defined at
241/// the top level is always a valid symbol.
242bool mlir::affine::isTopLevelValue(Value value) {
243 if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
244 // The block owning the argument may be unlinked, e.g. when the surrounding
245 // region has not yet been attached to an Op, at which point the parent Op
246 // is null.
247 Operation *parentOp = arg.getOwner()->getParentOp();
248 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
249 }
250 // The defining Op may live in an unlinked block so its parent Op may be null.
251 Operation *parentOp = value.getDefiningOp()->getParentOp();
252 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
253}
254
255/// Returns the closest region enclosing `op` that is held by an operation with
256/// trait `AffineScope`; `nullptr` if there is no such region.
257Region *mlir::affine::getAffineScope(Operation *op) {
258 auto *curOp = op;
259 while (auto *parentOp = curOp->getParentOp()) {
260 if (parentOp->hasTrait<OpTrait::AffineScope>())
261 return curOp->getParentRegion();
262 curOp = parentOp;
263 }
264 return nullptr;
265}
266
267// A Value can be used as a dimension id iff it meets one of the following
268// conditions:
269// *) It is valid as a symbol.
270// *) It is an induction variable.
271// *) It is the result of affine apply operation with dimension id arguments.
272bool mlir::affine::isValidDim(Value value) {
273 // The value must be an index type.
274 if (!value.getType().isIndex())
275 return false;
276
277 if (auto *defOp = value.getDefiningOp())
278 return isValidDim(value, region: getAffineScope(op: defOp));
279
280 // This value has to be a block argument for an op that has the
281 // `AffineScope` trait or for an affine.for or affine.parallel.
282 auto *parentOp = llvm::cast<BlockArgument>(Val&: value).getOwner()->getParentOp();
283 return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
284 isa<AffineForOp, AffineParallelOp>(parentOp));
285}
286
287// Value can be used as a dimension id iff it meets one of the following
288// conditions:
289// *) It is valid as a symbol.
290// *) It is an induction variable.
291// *) It is the result of an affine apply operation with dimension id operands.
292bool mlir::affine::isValidDim(Value value, Region *region) {
293 // The value must be an index type.
294 if (!value.getType().isIndex())
295 return false;
296
297 // All valid symbols are okay.
298 if (isValidSymbol(value, region))
299 return true;
300
301 auto *op = value.getDefiningOp();
302 if (!op) {
303 // This value has to be a block argument for an affine.for or an
304 // affine.parallel.
305 auto *parentOp = llvm::cast<BlockArgument>(Val&: value).getOwner()->getParentOp();
306 return isa<AffineForOp, AffineParallelOp>(parentOp);
307 }
308
309 // Affine apply operation is ok if all of its operands are ok.
310 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
311 return applyOp.isValidDim(region);
312 // The dim op is okay if its operand memref/tensor is defined at the top
313 // level.
314 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
315 return isTopLevelValue(dimOp.getShapedValue());
316 return false;
317}
318
319/// Returns true if the 'index' dimension of the `memref` defined by
320/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
321/// for `region`.
322template <typename AnyMemRefDefOp>
323static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
324 Region *region) {
325 MemRefType memRefType = memrefDefOp.getType();
326
327 // Dimension index is out of bounds.
328 if (index >= memRefType.getRank()) {
329 return false;
330 }
331
332 // Statically shaped.
333 if (!memRefType.isDynamicDim(index))
334 return true;
335 // Get the position of the dimension among dynamic dimensions;
336 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
337 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
338 region);
339}
340
341/// Returns true if the result of the dim op is a valid symbol for `region`.
342static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
343 // The dim op is okay if its source is defined at the top level.
344 if (isTopLevelValue(dimOp.getShapedValue()))
345 return true;
346
347 // Conservatively handle remaining BlockArguments as non-valid symbols.
348 // E.g. scf.for iterArgs.
349 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
350 return false;
351
352 // The dim op is also okay if its operand memref is a view/subview whose
353 // corresponding size is a valid symbol.
354 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
355
356 // Be conservative if we can't understand the dimension.
357 if (!index.has_value())
358 return false;
359
360 // Skip over all memref.cast ops (if any).
361 Operation *op = dimOp.getShapedValue().getDefiningOp();
362 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
363 // Bail on unranked memrefs.
364 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
365 return false;
366 op = castOp.getSource().getDefiningOp();
367 if (!op)
368 return false;
369 }
370
371 int64_t i = index.value();
372 return TypeSwitch<Operation *, bool>(op)
373 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
374 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
375 .Default([](Operation *) { return false; });
376}
377
378// A value can be used as a symbol (at all its use sites) iff it meets one of
379// the following conditions:
380// *) It is a constant.
381// *) Its defining op or block arg appearance is immediately enclosed by an op
382// with `AffineScope` trait.
383// *) It is the result of an affine.apply operation with symbol operands.
384// *) It is a result of the dim op on a memref whose corresponding size is a
385// valid symbol.
386bool mlir::affine::isValidSymbol(Value value) {
387 if (!value)
388 return false;
389
390 // The value must be an index type.
391 if (!value.getType().isIndex())
392 return false;
393
394 // Check that the value is a top level value.
395 if (isTopLevelValue(value))
396 return true;
397
398 if (auto *defOp = value.getDefiningOp())
399 return isValidSymbol(value, region: getAffineScope(op: defOp));
400
401 return false;
402}
403
404/// A value can be used as a symbol for `region` iff it meets one of the
405/// following conditions:
406/// *) It is a constant.
407/// *) It is the result of an affine apply operation with symbol arguments.
408/// *) It is a result of the dim op on a memref whose corresponding size is
409/// a valid symbol.
410/// *) It is defined at the top level of 'region' or is its argument.
411/// *) It dominates `region`'s parent op.
412/// If `region` is null, conservatively assume the symbol definition scope does
413/// not exist and only accept the values that would be symbols regardless of
414/// the surrounding region structure, i.e. the first three cases above.
415bool mlir::affine::isValidSymbol(Value value, Region *region) {
416 // The value must be an index type.
417 if (!value.getType().isIndex())
418 return false;
419
420 // A top-level value is a valid symbol.
421 if (region && ::isTopLevelValue(value, region))
422 return true;
423
424 auto *defOp = value.getDefiningOp();
425 if (!defOp) {
426 // A block argument that is not a top-level value is a valid symbol if it
427 // dominates region's parent op.
428 Operation *regionOp = region ? region->getParentOp() : nullptr;
429 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
430 if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
431 return isValidSymbol(value, region: parentOpRegion);
432 return false;
433 }
434
435 // Constant operation is ok.
436 Attribute operandCst;
437 if (matchPattern(op: defOp, pattern: m_Constant(bind_value: &operandCst)))
438 return true;
439
440 // Affine apply operation is ok if all of its operands are ok.
441 if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
442 return applyOp.isValidSymbol(region);
443
444 // Dim op results could be valid symbols at any level.
445 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
446 return isDimOpValidSymbol(dimOp, region);
447
448 // Check for values dominating `region`'s parent op.
449 Operation *regionOp = region ? region->getParentOp() : nullptr;
450 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
451 if (auto *parentRegion = region->getParentOp()->getParentRegion())
452 return isValidSymbol(value, region: parentRegion);
453
454 return false;
455}
456
457// Returns true if 'value' is a valid index to an affine operation (e.g.
458// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
459// `region` provides the polyhedral symbol scope. Returns false otherwise.
460static bool isValidAffineIndexOperand(Value value, Region *region) {
461 return isValidDim(value, region) || isValidSymbol(value, region);
462}
463
464/// Prints dimension and symbol list.
465static void printDimAndSymbolList(Operation::operand_iterator begin,
466 Operation::operand_iterator end,
467 unsigned numDims, OpAsmPrinter &printer) {
468 OperandRange operands(begin, end);
469 printer << '(' << operands.take_front(n: numDims) << ')';
470 if (operands.size() > numDims)
471 printer << '[' << operands.drop_front(n: numDims) << ']';
472}
473
474/// Parses dimension and symbol list and returns true if parsing failed.
475ParseResult mlir::affine::parseDimAndSymbolList(
476 OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
477 SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
478 if (parser.parseOperandList(result&: opInfos, delimiter: OpAsmParser::Delimiter::Paren))
479 return failure();
480 // Store number of dimensions for validation by caller.
481 numDims = opInfos.size();
482
483 // Parse the optional symbol operands.
484 auto indexTy = parser.getBuilder().getIndexType();
485 return failure(parser.parseOperandList(
486 result&: opInfos, delimiter: OpAsmParser::Delimiter::OptionalSquare) ||
487 parser.resolveOperands(opInfos, indexTy, operands));
488}
489
490/// Utility function to verify that a set of operands are valid dimension and
491/// symbol identifiers. The operands should be laid out such that the dimension
492/// operands are before the symbol operands. This function returns failure if
493/// there was an invalid operand. An operation is provided to emit any necessary
494/// errors.
495template <typename OpTy>
496static LogicalResult
497verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
498 unsigned numDims) {
499 unsigned opIt = 0;
500 for (auto operand : operands) {
501 if (opIt++ < numDims) {
502 if (!isValidDim(operand, getAffineScope(op)))
503 return op.emitOpError("operand cannot be used as a dimension id");
504 } else if (!isValidSymbol(operand, getAffineScope(op))) {
505 return op.emitOpError("operand cannot be used as a symbol");
506 }
507 }
508 return success();
509}
510
511//===----------------------------------------------------------------------===//
512// AffineApplyOp
513//===----------------------------------------------------------------------===//
514
515AffineValueMap AffineApplyOp::getAffineValueMap() {
516 return AffineValueMap(getAffineMap(), getOperands(), getResult());
517}
518
519ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
520 auto &builder = parser.getBuilder();
521 auto indexTy = builder.getIndexType();
522
523 AffineMapAttr mapAttr;
524 unsigned numDims;
525 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
526 parseDimAndSymbolList(parser, result.operands, numDims) ||
527 parser.parseOptionalAttrDict(result.attributes))
528 return failure();
529 auto map = mapAttr.getValue();
530
531 if (map.getNumDims() != numDims ||
532 numDims + map.getNumSymbols() != result.operands.size()) {
533 return parser.emitError(parser.getNameLoc(),
534 "dimension or symbol index mismatch");
535 }
536
537 result.types.append(map.getNumResults(), indexTy);
538 return success();
539}
540
541void AffineApplyOp::print(OpAsmPrinter &p) {
542 p << " " << getMapAttr();
543 printDimAndSymbolList(operand_begin(), operand_end(),
544 getAffineMap().getNumDims(), p);
545 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
546}
547
548LogicalResult AffineApplyOp::verify() {
549 // Check input and output dimensions match.
550 AffineMap affineMap = getMap();
551
552 // Verify that operand count matches affine map dimension and symbol count.
553 if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
554 return emitOpError(
555 "operand count and affine map dimension and symbol count must match");
556
557 // Verify that the map only produces one result.
558 if (affineMap.getNumResults() != 1)
559 return emitOpError("mapping must produce one value");
560
561 return success();
562}
563
564// The result of the affine apply operation can be used as a dimension id if all
565// its operands are valid dimension ids.
566bool AffineApplyOp::isValidDim() {
567 return llvm::all_of(getOperands(),
568 [](Value op) { return affine::isValidDim(op); });
569}
570
571// The result of the affine apply operation can be used as a dimension id if all
572// its operands are valid dimension ids with the parent operation of `region`
573// defining the polyhedral scope for symbols.
574bool AffineApplyOp::isValidDim(Region *region) {
575 return llvm::all_of(getOperands(),
576 [&](Value op) { return ::isValidDim(op, region); });
577}
578
579// The result of the affine apply operation can be used as a symbol if all its
580// operands are symbols.
581bool AffineApplyOp::isValidSymbol() {
582 return llvm::all_of(getOperands(),
583 [](Value op) { return affine::isValidSymbol(op); });
584}
585
586// The result of the affine apply operation can be used as a symbol in `region`
587// if all its operands are symbols in `region`.
588bool AffineApplyOp::isValidSymbol(Region *region) {
589 return llvm::all_of(getOperands(), [&](Value operand) {
590 return affine::isValidSymbol(operand, region);
591 });
592}
593
594OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
595 auto map = getAffineMap();
596
597 // Fold dims and symbols to existing values.
598 auto expr = map.getResult(0);
599 if (auto dim = dyn_cast<AffineDimExpr>(expr))
600 return getOperand(dim.getPosition());
601 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
602 return getOperand(map.getNumDims() + sym.getPosition());
603
604 // Otherwise, default to folding the map.
605 SmallVector<Attribute, 1> result;
606 bool hasPoison = false;
607 auto foldResult =
608 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
609 if (hasPoison)
610 return ub::PoisonAttr::get(getContext());
611 if (failed(foldResult))
612 return {};
613 return result[0];
614}
615
616/// Returns the largest known divisor of `e`. Exploits information from the
617/// values in `operands`.
618static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
619 // This method isn't aware of `operands`.
620 int64_t div = e.getLargestKnownDivisor();
621
622 // We now make use of operands for the case `e` is a dim expression.
623 // TODO: More powerful simplification would have to modify
624 // getLargestKnownDivisor to take `operands` and exploit that information as
625 // well for dim/sym expressions, but in that case, getLargestKnownDivisor
626 // can't be part of the IR library but of the `Analysis` library. The IR
627 // library can only really depend on simple O(1) checks.
628 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e);
629 // If it's not a dim expr, `div` is the best we have.
630 if (!dimExpr)
631 return div;
632
633 // We simply exploit information from loop IVs.
634 // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
635 // desired simplifications are expected to be part of other
636 // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
637 // LoopAnalysis library.
638 Value operand = operands[dimExpr.getPosition()];
639 int64_t operandDivisor = 1;
640 // TODO: With the right accessors, this can be extended to
641 // LoopLikeOpInterface.
642 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
643 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
644 operandDivisor = forOp.getStepAsInt();
645 } else {
646 uint64_t lbLargestKnownDivisor =
647 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
648 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
649 }
650 }
651 return operandDivisor;
652}
653
654/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
655/// being an affine dim expression or a constant.
656static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
657 int64_t k) {
658 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) {
659 int64_t constVal = constExpr.getValue();
660 return constVal >= 0 && constVal < k;
661 }
662 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e);
663 if (!dimExpr)
664 return false;
665 Value operand = operands[dimExpr.getPosition()];
666 // TODO: With the right accessors, this can be extended to
667 // LoopLikeOpInterface.
668 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
669 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
670 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
671 return true;
672 }
673 }
674
675 // We don't consider other cases like `operand` being defined by a constant or
676 // an affine.apply op since such cases will already be handled by other
677 // patterns and propagation of loop IVs or constant would happen.
678 return false;
679}
680
681/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
682/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
683/// expression is in that form.
684static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
685 AffineExpr &quotientTimesDiv, AffineExpr &rem) {
686 auto bin = dyn_cast<AffineBinaryOpExpr>(Val&: e);
687 if (!bin || bin.getKind() != AffineExprKind::Add)
688 return false;
689
690 AffineExpr llhs = bin.getLHS();
691 AffineExpr rlhs = bin.getRHS();
692 div = getLargestKnownDivisor(e: llhs, operands);
693 if (isNonNegativeBoundedBy(e: rlhs, operands, k: div)) {
694 quotientTimesDiv = llhs;
695 rem = rlhs;
696 return true;
697 }
698 div = getLargestKnownDivisor(e: rlhs, operands);
699 if (isNonNegativeBoundedBy(e: llhs, operands, k: div)) {
700 quotientTimesDiv = rlhs;
701 rem = llhs;
702 return true;
703 }
704 return false;
705}
706
707/// Gets the constant lower bound on an `iv`.
708static std::optional<int64_t> getLowerBound(Value iv) {
709 AffineForOp forOp = getForInductionVarOwner(iv);
710 if (forOp && forOp.hasConstantLowerBound())
711 return forOp.getConstantLowerBound();
712 return std::nullopt;
713}
714
715/// Gets the constant upper bound on an affine.for `iv`.
716static std::optional<int64_t> getUpperBound(Value iv) {
717 AffineForOp forOp = getForInductionVarOwner(iv);
718 if (!forOp || !forOp.hasConstantUpperBound())
719 return std::nullopt;
720
721 // If its lower bound is also known, we can get a more precise bound
722 // whenever the step is not one.
723 if (forOp.hasConstantLowerBound()) {
724 return forOp.getConstantUpperBound() - 1 -
725 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
726 forOp.getStepAsInt();
727 }
728 return forOp.getConstantUpperBound() - 1;
729}
730
731/// Determine a constant upper bound for `expr` if one exists while exploiting
732/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
733/// is guaranteed to be less than or equal to it.
734static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
735 unsigned numSymbols,
736 ArrayRef<Value> operands) {
737 // Get the constant lower or upper bounds on the operands.
738 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
739 constLowerBounds.reserve(N: operands.size());
740 constUpperBounds.reserve(N: operands.size());
741 for (Value operand : operands) {
742 constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
743 constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
744 }
745
746 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr))
747 return constExpr.getValue();
748
749 return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
750 constUpperBounds,
751 /*isUpper=*/true);
752}
753
754/// Determine a constant lower bound for `expr` if one exists while exploiting
755/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
756/// is guaranteed to be less than or equal to it.
757static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
758 unsigned numSymbols,
759 ArrayRef<Value> operands) {
760 // Get the constant lower or upper bounds on the operands.
761 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
762 constLowerBounds.reserve(N: operands.size());
763 constUpperBounds.reserve(N: operands.size());
764 for (Value operand : operands) {
765 constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
766 constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
767 }
768
769 std::optional<int64_t> lowerBound;
770 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) {
771 lowerBound = constExpr.getValue();
772 } else {
773 lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
774 constLowerBounds, constUpperBounds,
775 /*isUpper=*/false);
776 }
777 return lowerBound;
778}
779
/// Simplify `expr` while exploiting information from the values in `operands`.
/// Only binary floordiv/ceildiv/mod expressions with a positive constant RHS
/// are rewritten; `expr` is updated in place when a simplification is found.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first (recursively), then rebuild `expr`
  // from the simplified children.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(expr&: lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(expr&: rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(kind: binExpr.getKind(), lhs, rhs);

  // Rebuilding may have folded `expr` into a different kind (even a
  // non-binary expression); only floordiv/ceildiv/mod are handled below.
  binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions (division/mod by zero or a negative value) aren't
  // touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(expr: lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(expr: lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value if lhs is bounded in a range `c` that
    // has the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        floorDiv(lhs: lhsLbConstVal, rhs: rhsConstVal) ==
            floorDiv(lhs: lhsUbConstVal, rhs: rhsConstVal)) {
      expr =
          getAffineConstantExpr(constant: floorDiv(lhs: lhsLbConstVal, rhs: rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        ceilDiv(lhs: lhsLbConstVal, rhs: rhsConstVal) ==
            ceilDiv(lhs: lhsUbConstVal, rhs: rhsConstVal)) {
      expr =
          getAffineConstantExpr(constant: ceilDiv(lhs: lhsLbConstVal, rhs: rhsConstVal), context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(e: lhs, operands, div&: divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(other: rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs % c = 0.
  if ((isNonNegativeBoundedBy(e: lhs, operands, k: rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(e: lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(constant: 0, context: expr.getContext());
  }
}
878
/// Simplify the expressions in `map` while making use of lower or upper bounds
/// of its operands. If `isMax` is true, the map is to be treated as a max of
/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
/// d1) can be simplified to (8) if the operands are respectively lower bounded
/// by 2 and 0 (the second expression can't be lower than 8).
static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
                                             ArrayRef<Value> operands,
                                             bool isMax) {
  // Can't simplify.
  if (operands.empty())
    return;

  // Get the upper or lower bound on an affine.for op IV using its range.
  // Get the constant lower or upper bounds on the operands.
  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
  constLowerBounds.reserve(N: operands.size());
  constUpperBounds.reserve(N: operands.size());
  for (Value operand : operands) {
    constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
    constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
  }

  // We will compute the lower and upper bounds on each of the expressions
  // Then, we will check (depending on max or min) as to whether a specific
  // bound is redundant by checking if its highest (in case of max) and its
  // lowest (in the case of min) value is already lower than (or higher than)
  // the lower bound (or upper bound in the case of min) of another bound.
  // A std::nullopt entry means no constant bound could be determined.
  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
  lowerBounds.reserve(N: map.getNumResults());
  upperBounds.reserve(N: map.getNumResults());
  for (AffineExpr e : map.getResults()) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) {
      lowerBounds.push_back(Elt: constExpr.getValue());
      upperBounds.push_back(Elt: constExpr.getValue());
    } else {
      lowerBounds.push_back(
          Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
                                  constLowerBounds, constUpperBounds,
                                  /*isUpper=*/false));
      upperBounds.push_back(
          Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
                                  constLowerBounds, constUpperBounds,
                                  /*isUpper=*/true));
    }
  }

  // Collect expressions that are not redundant.
  SmallVector<AffineExpr, 4> irredundantExprs;
  for (auto exprEn : llvm::enumerate(First: map.getResults())) {
    AffineExpr e = exprEn.value();
    unsigned i = exprEn.index();
    // Some expressions can be turned into constants.
    if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
      e = getAffineConstantExpr(constant: *lowerBounds[i], context: e.getContext());

    // Check if the expression is redundant.
    if (isMax) {
      // With no constant upper bound, the expression can never be proven
      // redundant; keep it.
      if (!upperBounds[i]) {
        irredundantExprs.push_back(Elt: e);
        continue;
      }
      // If there exists another expression such that its lower bound is greater
      // than this expression's upper bound, it's redundant.
      if (!llvm::any_of(Range: llvm::enumerate(First&: lowerBounds), P: [&](const auto &en) {
            auto otherLowerBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherLowerBound)
              return false;
            if (*otherLowerBound > *upperBounds[i])
              return true;
            if (*otherLowerBound < *upperBounds[i])
              return false;
            // Equality case. When both expressions are considered redundant, we
            // don't want to get both of them. We keep the one that appears
            // first.
            if (upperBounds[pos] && lowerBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherLowerBound == *upperBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(Elt: e);
    } else {
      // With no constant lower bound, the expression can never be proven
      // redundant; keep it.
      if (!lowerBounds[i]) {
        irredundantExprs.push_back(Elt: e);
        continue;
      }
      // Likewise for the `min` case. Use the complement of the condition above.
      if (!llvm::any_of(Range: llvm::enumerate(First&: upperBounds), P: [&](const auto &en) {
            auto otherUpperBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherUpperBound)
              return false;
            if (*otherUpperBound < *lowerBounds[i])
              return true;
            if (*otherUpperBound > *lowerBounds[i])
              return false;
            // Equality tie-break: keep only the first of two mutually
            // redundant expressions.
            if (lowerBounds[pos] && upperBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherUpperBound == lowerBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(Elt: e);
    }
  }

  // Create the map without the redundant expressions.
  map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: irredundantExprs,
                       context: map.getContext());
}
990
991/// Simplify the map while exploiting information on the values in `operands`.
992// Use "unused attribute" marker to silence warning stemming from the inability
993// to see through the template expansion.
994static void LLVM_ATTRIBUTE_UNUSED
995simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
996 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
997 SmallVector<AffineExpr> newResults;
998 newResults.reserve(N: map.getNumResults());
999 for (AffineExpr expr : map.getResults()) {
1000 simplifyExprAndOperands(expr, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
1001 operands);
1002 newResults.push_back(Elt: expr);
1003 }
1004 map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newResults,
1005 context: map.getContext());
1006}
1007
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// Returns failure when no replacement is performed (entry already consumed,
/// or the operand is not produced by an affine.apply).
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms) {
  MLIRContext *ctx = map->getContext();
  // Positions [0, dims.size()) index into `dims`; the rest index into `syms`.
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (!v)
    return failure();

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
                                     affineApply.getMapOperands().end());
  // Canonicalize the map to promote dims to symbols when possible. This is to
  // avoid generating invalid maps.
  canonicalizeMapAndOperands(map: &composeMap, operands: &composeOperands);
  // Shift the producer's dims/symbols past the existing ones so they can be
  // appended without clashing.
  AffineExpr replacementExpr =
      composeMap.shiftDims(shift: dims.size()).shiftSymbols(shift: syms.size()).getResult(idx: 0);
  ValueRange composeDims =
      ArrayRef<Value>(composeOperands).take_front(N: composeMap.getNumDims());
  ValueRange composeSyms =
      ArrayRef<Value>(composeOperands).take_back(N: composeMap.getNumSymbols());
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(position: pos, context: ctx)
                                          : getAffineSymbolExpr(position: pos, context: ctx);

  // Append the dims and symbols where relevant and perform the replacement.
  dims.append(in_start: composeDims.begin(), in_end: composeDims.end());
  syms.append(in_start: composeSyms.begin(), in_end: composeSyms.end());
  *map = map->replace(expr: toReplace, replacement: replacementExpr, numResultDims: dims.size(), numResultSyms: syms.size());

  return success();
}
1062
/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
/// iteratively. Perform canonicalization of map and operands as well as
/// AffineMap simplification. `map` and `operands` are mutated in place.
static void composeAffineMapAndOperands(AffineMap *map,
                                        SmallVectorImpl<Value> *operands) {
  // Zero-result maps have nothing to compose; just canonicalize and simplify.
  if (map->getNumResults() == 0) {
    canonicalizeMapAndOperands(map, operands);
    *map = simplifyAffineMap(map: *map);
    return;
  }

  MLIRContext *ctx = map->getContext();
  // Split operands into the dim-bound prefix and the symbol-bound suffix.
  SmallVector<Value, 4> dims(operands->begin(),
                             operands->begin() + map->getNumDims());
  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
                             operands->end());

  // Iterate over dims and symbols coming from AffineApplyOp and replace until
  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
  // and `syms` can only increase by construction.
  // The implementation uses a `while` loop to support the case of symbols
  // that may be constructed from dims; this may be overkill.
  while (true) {
    bool changed = false;
    // Restart scanning from position 0 after every successful replacement
    // since `dims`/`syms` may have grown and entries were nulled out.
    for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
      if ((changed |= succeeded(result: replaceDimOrSym(map, dimOrSymbolPosition: pos, dims, syms))))
        break;
    if (!changed)
      break;
  }

  // Clear operands so we can fill them anew.
  operands->clear();

  // At this point we may have introduced null operands, prune them out before
  // canonicalizing map and operands.
  unsigned nDims = 0, nSyms = 0;
  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
  dimReplacements.reserve(N: dims.size());
  symReplacements.reserve(N: syms.size());
  for (auto *container : {&dims, &syms}) {
    bool isDim = (container == &dims);
    auto &repls = isDim ? dimReplacements : symReplacements;
    for (const auto &en : llvm::enumerate(First&: *container)) {
      Value v = en.value();
      if (!v) {
        // Nulled-out entries must no longer be referenced by the map; remap
        // them to 0 as a placeholder.
        assert(isDim ? !map->isFunctionOfDim(en.index())
                     : !map->isFunctionOfSymbol(en.index()) &&
                           "map is function of unexpected expr@pos");
        repls.push_back(Elt: getAffineConstantExpr(constant: 0, context: ctx));
        continue;
      }
      repls.push_back(Elt: isDim ? getAffineDimExpr(position: nDims++, context: ctx)
                              : getAffineSymbolExpr(position: nSyms++, context: ctx));
      operands->push_back(Elt: v);
    }
  }
  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, numResultDims: nDims,
                                    numResultSyms: nSyms);

  // Canonicalize and simplify before returning.
  canonicalizeMapAndOperands(map, operands);
  *map = simplifyAffineMap(map: *map);
}
1127
1128void mlir::affine::fullyComposeAffineMapAndOperands(
1129 AffineMap *map, SmallVectorImpl<Value> *operands) {
1130 while (llvm::any_of(Range&: *operands, P: [](Value v) {
1131 return isa_and_nonnull<AffineApplyOp>(Val: v.getDefiningOp());
1132 })) {
1133 composeAffineMapAndOperands(map, operands);
1134 }
1135}
1136
1137AffineApplyOp
1138mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1139 ArrayRef<OpFoldResult> operands) {
1140 SmallVector<Value> valueOperands;
1141 map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands);
1142 composeAffineMapAndOperands(map: &map, operands: &valueOperands);
1143 assert(map);
1144 return b.create<AffineApplyOp>(loc, map, valueOperands);
1145}
1146
1147AffineApplyOp
1148mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1149 ArrayRef<OpFoldResult> operands) {
1150 return makeComposedAffineApply(
1151 b, loc,
1152 AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{e}, context: b.getContext())
1153 .front(),
1154 operands);
1155}
1156
/// Composes the given affine map with the given list of operands, pulling in
/// the maps from any affine.apply operations that supply the operands.
/// `map` and `operands` are updated in place.
static void composeMultiResultAffineMap(AffineMap &map,
                                        SmallVectorImpl<Value> &operands) {
  // Compose and canonicalize each expression in the map individually because
  // composition only applies to single-result maps, collecting potentially
  // duplicate operands in a single list with shifted dimensions and symbols.
  SmallVector<Value> dims, symbols;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(Begin: 0, End: map.getNumResults())) {
    // Compose the i-th result in isolation on a fresh copy of the operands.
    SmallVector<Value> submapOperands(operands.begin(), operands.end());
    AffineMap submap = map.getSubMap(resultPos: {i});
    fullyComposeAffineMapAndOperands(map: &submap, operands: &submapOperands);
    canonicalizeMapAndOperands(map: &submap, operands: &submapOperands);
    // Shift this submap's dims/symbols past those accumulated so far so its
    // operands can simply be appended.
    unsigned numNewDims = submap.getNumDims();
    submap = submap.shiftDims(shift: dims.size()).shiftSymbols(shift: symbols.size());
    llvm::append_range(C&: dims,
                       R: ArrayRef<Value>(submapOperands).take_front(N: numNewDims));
    llvm::append_range(C&: symbols,
                       R: ArrayRef<Value>(submapOperands).drop_front(N: numNewDims));
    exprs.push_back(Elt: submap.getResult(idx: 0));
  }

  // Canonicalize the map created from composed expressions to deduplicate the
  // dimension and symbol operands.
  operands = llvm::to_vector(Range: llvm::concat<Value>(Ranges&: dims, Ranges&: symbols));
  map = AffineMap::get(dimCount: dims.size(), symbolCount: symbols.size(), results: exprs, context: map.getContext());
  canonicalizeMapAndOperands(map: &map, operands: &operands);
}
1186
/// Constructs a composed affine.apply and immediately attempts to fold it,
/// returning either the folded attribute/value or the materialized op result.
OpFoldResult
mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<OpFoldResult> operands) {
  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());

  // Create op.
  AffineApplyOp applyOp =
      makeComposedAffineApply(newBuilder, loc, map, operands);

  // Get constant operands.
  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(applyOp->getOperand(i), m_Constant(bind_value: &constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(applyOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
    // Folding failed: keep the materialized op and only now notify the
    // original builder's listener about the insertion.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(op: applyOp, /*previous=*/{});
    return applyOp.getResult();
  }

  // Folding succeeded: the temporary op is no longer needed.
  applyOp->erase();
  assert(foldResults.size() == 1 && "expected 1 folded result");
  return foldResults.front();
}
1222
1223OpFoldResult
1224mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1225 AffineExpr expr,
1226 ArrayRef<OpFoldResult> operands) {
1227 return makeComposedFoldedAffineApply(
1228 b, loc,
1229 map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{expr}, context: b.getContext())
1230 .front(),
1231 operands);
1232}
1233
1234SmallVector<OpFoldResult>
1235mlir::affine::makeComposedFoldedMultiResultAffineApply(
1236 OpBuilder &b, Location loc, AffineMap map,
1237 ArrayRef<OpFoldResult> operands) {
1238 return llvm::map_to_vector(C: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults()),
1239 F: [&](unsigned i) {
1240 return makeComposedFoldedAffineApply(
1241 b, loc, map: map.getSubMap(resultPos: {i}), operands);
1242 });
1243}
1244
1245template <typename OpTy>
1246static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
1247 ArrayRef<OpFoldResult> operands) {
1248 SmallVector<Value> valueOperands;
1249 map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands);
1250 composeMultiResultAffineMap(map, operands&: valueOperands);
1251 return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1252}
1253
/// Public entry point: builds a composed affine.min op.
AffineMinOp
mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
                                    ArrayRef<OpFoldResult> operands) {
  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
}
1259
/// Builds a composed affine min/max op (per `OpTy`) and immediately attempts
/// to fold it, returning either the folded attribute/value or the op result.
/// Mirrors makeComposedFoldedAffineApply for min/max ops.
template <typename OpTy>
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ArrayRef<OpFoldResult> operands) {
  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());

  // Create op.
  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);

  // Get constant operands.
  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(minMaxOp->getOperand(i), m_Constant(bind_value: &constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
    // Folding failed: keep the op and only now notify the original builder's
    // listener about the insertion.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(op: minMaxOp, /*previous=*/{});
    return minMaxOp.getResult();
  }

  // Folding succeeded: the temporary op is no longer needed.
  minMaxOp->erase();
  assert(foldResults.size() == 1 && "expected 1 folded result");
  return foldResults.front();
}
1292
/// Public entry point: builds and folds a composed affine.min op.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
}
1299
/// Public entry point: builds and folds a composed affine.max op.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
}
1306
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols: such dims
// are moved to the symbol section (appended after the existing symbols) and
// the map/set expressions are remapped accordingly.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(N: operands->size());
  // Operands promoted to symbols; appended after all dim operands below.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(N: operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol(value: (*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back(Elt: (*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back(Elt: (*operands)[i]);
      }
    } else {
      // Existing symbol operands are kept as-is.
      resultOperands.push_back(Elt: (*operands)[i]);
    }
  }

  // Final operand order: remaining dims, old symbols, newly promoted symbols.
  resultOperands.append(in_start: remappedSymbols.begin(), in_end: remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}
1350
// Canonicalizes a map or set and its operands: promotes dims that are valid
// symbols, drops unused dims/symbols, deduplicates repeated operands, and
// folds constant symbolic operands into the expressions.
// Works for either an affine map or an integer set.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(Val&: expr))
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(N: operands->size());

  // Deduplicate dim operands: identical values share one dim position.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find(Val: (*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back(Elt: (*operands)[i]);
        seenDims.insert(KV: std::make_pair(x&: (*operands)[i], y&: dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back(Elt: (*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1428
/// Public entry point: canonicalizes an affine map and its operands.
void mlir::affine::canonicalizeMapAndOperands(
    AffineMap *map, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(mapOrSet: map, operands);
}
1433
/// Public entry point: canonicalizes an integer set and its operands.
void mlir::affine::canonicalizeSetAndOperands(
    IntegerSet *set, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(mapOrSet: set, operands);
}
1438
namespace {
/// Canonicalization pattern that simplifies the map (and operands) of an
/// affine op by composing producing affine.apply ops into it, canonicalizing,
/// and simplifying with operand bound information. Works for AffineApply,
/// AffineLoad, AffineStore, and the other ops listed in the static_assert.
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef<Value> mapOperands) const;

  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(
        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
        "expected");
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector<Value, 8> resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    canonicalizeMapAndOperands(&map, &resultOperands);
    simplifyMapWithOperands(map, resultOperands);
    // Only rewrite if something actually changed, to avoid an infinite
    // pattern application loop.
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, mapOperands: resultOperands);
    return success();
  }
};

// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
      prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
      prefetch.getLocalityHint(), prefetch.getIsDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
      mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
      mapOperands);
}

// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // namespace
1525
/// Registers the SimplifyAffineOp canonicalization pattern for affine.apply.
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
1530
1531//===----------------------------------------------------------------------===//
1532// AffineDmaStartOp
1533//===----------------------------------------------------------------------===//
1534
1535// TODO: Check that map operands are loop IVs or symbols.
1536void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1537 Value srcMemRef, AffineMap srcMap,
1538 ValueRange srcIndices, Value destMemRef,
1539 AffineMap dstMap, ValueRange destIndices,
1540 Value tagMemRef, AffineMap tagMap,
1541 ValueRange tagIndices, Value numElements,
1542 Value stride, Value elementsPerStride) {
1543 result.addOperands(newOperands: srcMemRef);
1544 result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1545 result.addOperands(newOperands: srcIndices);
1546 result.addOperands(newOperands: destMemRef);
1547 result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1548 result.addOperands(newOperands: destIndices);
1549 result.addOperands(newOperands: tagMemRef);
1550 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1551 result.addOperands(newOperands: tagIndices);
1552 result.addOperands(newOperands: numElements);
1553 if (stride) {
1554 result.addOperands(newOperands: {stride, elementsPerStride});
1555 }
1556}
1557
1558void AffineDmaStartOp::print(OpAsmPrinter &p) {
1559 p << " " << getSrcMemRef() << '[';
1560 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1561 p << "], " << getDstMemRef() << '[';
1562 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1563 p << "], " << getTagMemRef() << '[';
1564 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1565 p << "], " << getNumElements();
1566 if (isStrided()) {
1567 p << ", " << getStride();
1568 p << ", " << getNumElementsPerStride();
1569 }
1570 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1571 << getTagMemRefType();
1572}
1573
1574// Parse AffineDmaStartOp.
1575// Ex:
1576// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1577// %stride, %num_elt_per_stride
1578// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1579//
1580ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1581 OperationState &result) {
1582 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1583 AffineMapAttr srcMapAttr;
1584 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
1585 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1586 AffineMapAttr dstMapAttr;
1587 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
1588 OpAsmParser::UnresolvedOperand tagMemRefInfo;
1589 AffineMapAttr tagMapAttr;
1590 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
1591 OpAsmParser::UnresolvedOperand numElementsInfo;
1592 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1593
1594 SmallVector<Type, 3> types;
1595 auto indexType = parser.getBuilder().getIndexType();
1596
1597 // Parse and resolve the following list of operands:
1598 // *) dst memref followed by its affine maps operands (in square brackets).
1599 // *) src memref followed by its affine map operands (in square brackets).
1600 // *) tag memref followed by its affine map operands (in square brackets).
1601 // *) number of elements transferred by DMA operation.
1602 if (parser.parseOperand(result&: srcMemRefInfo) ||
1603 parser.parseAffineMapOfSSAIds(operands&: srcMapOperands, map&: srcMapAttr,
1604 attrName: getSrcMapAttrStrName(),
1605 attrs&: result.attributes) ||
1606 parser.parseComma() || parser.parseOperand(result&: dstMemRefInfo) ||
1607 parser.parseAffineMapOfSSAIds(operands&: dstMapOperands, map&: dstMapAttr,
1608 attrName: getDstMapAttrStrName(),
1609 attrs&: result.attributes) ||
1610 parser.parseComma() || parser.parseOperand(result&: tagMemRefInfo) ||
1611 parser.parseAffineMapOfSSAIds(operands&: tagMapOperands, map&: tagMapAttr,
1612 attrName: getTagMapAttrStrName(),
1613 attrs&: result.attributes) ||
1614 parser.parseComma() || parser.parseOperand(result&: numElementsInfo))
1615 return failure();
1616
1617 // Parse optional stride and elements per stride.
1618 if (parser.parseTrailingOperandList(result&: strideInfo))
1619 return failure();
1620
1621 if (!strideInfo.empty() && strideInfo.size() != 2) {
1622 return parser.emitError(loc: parser.getNameLoc(),
1623 message: "expected two stride related operands");
1624 }
1625 bool isStrided = strideInfo.size() == 2;
1626
1627 if (parser.parseColonTypeList(result&: types))
1628 return failure();
1629
1630 if (types.size() != 3)
1631 return parser.emitError(loc: parser.getNameLoc(), message: "expected three types");
1632
1633 if (parser.resolveOperand(operand: srcMemRefInfo, type: types[0], result&: result.operands) ||
1634 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1635 parser.resolveOperand(operand: dstMemRefInfo, type: types[1], result&: result.operands) ||
1636 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1637 parser.resolveOperand(operand: tagMemRefInfo, type: types[2], result&: result.operands) ||
1638 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1639 parser.resolveOperand(operand: numElementsInfo, type: indexType, result&: result.operands))
1640 return failure();
1641
1642 if (isStrided) {
1643 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1644 return failure();
1645 }
1646
1647 // Check that src/dst/tag operand counts match their map.numInputs.
1648 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1649 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1650 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1651 return parser.emitError(loc: parser.getNameLoc(),
1652 message: "memref operand count not equal to map.numInputs");
1653 return success();
1654}
1655
1656LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1657 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1658 return emitOpError("expected DMA source to be of memref type");
1659 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1660 return emitOpError("expected DMA destination to be of memref type");
1661 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1662 return emitOpError("expected DMA tag to be of memref type");
1663
1664 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1665 getDstMap().getNumInputs() +
1666 getTagMap().getNumInputs();
1667 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1668 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1669 return emitOpError("incorrect number of operands");
1670 }
1671
1672 Region *scope = getAffineScope(*this);
1673 for (auto idx : getSrcIndices()) {
1674 if (!idx.getType().isIndex())
1675 return emitOpError("src index to dma_start must have 'index' type");
1676 if (!isValidAffineIndexOperand(idx, scope))
1677 return emitOpError(
1678 "src index must be a valid dimension or symbol identifier");
1679 }
1680 for (auto idx : getDstIndices()) {
1681 if (!idx.getType().isIndex())
1682 return emitOpError("dst index to dma_start must have 'index' type");
1683 if (!isValidAffineIndexOperand(idx, scope))
1684 return emitOpError(
1685 "dst index must be a valid dimension or symbol identifier");
1686 }
1687 for (auto idx : getTagIndices()) {
1688 if (!idx.getType().isIndex())
1689 return emitOpError("tag index to dma_start must have 'index' type");
1690 if (!isValidAffineIndexOperand(idx, scope))
1691 return emitOpError(
1692 "tag index must be a valid dimension or symbol identifier");
1693 }
1694 return success();
1695}
1696
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  // Folds memref.cast producers directly into this op's memref operands.
  return memref::foldMemRefCast(*this);
}
1702
1703void AffineDmaStartOp::getEffects(
1704 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1705 &effects) {
1706 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: getSrcMemRef(),
1707 Args: SideEffects::DefaultResource::get());
1708 effects.emplace_back(Args: MemoryEffects::Write::get(), Args: getDstMemRef(),
1709 Args: SideEffects::DefaultResource::get());
1710 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: getTagMemRef(),
1711 Args: SideEffects::DefaultResource::get());
1712}
1713
1714//===----------------------------------------------------------------------===//
1715// AffineDmaWaitOp
1716//===----------------------------------------------------------------------===//
1717
1718// TODO: Check that map operands are loop IVs or symbols.
1719void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1720 Value tagMemRef, AffineMap tagMap,
1721 ValueRange tagIndices, Value numElements) {
1722 result.addOperands(newOperands: tagMemRef);
1723 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1724 result.addOperands(newOperands: tagIndices);
1725 result.addOperands(newOperands: numElements);
1726}
1727
1728void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1729 p << " " << getTagMemRef() << '[';
1730 SmallVector<Value, 2> operands(getTagIndices());
1731 p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1732 p << "], ";
1733 p.printOperand(value: getNumElements());
1734 p << " : " << getTagMemRef().getType();
1735}
1736
1737// Parse AffineDmaWaitOp.
1738// Eg:
1739// affine.dma_wait %tag[%index], %num_elements
1740// : memref<1 x i32, (d0) -> (d0), 4>
1741//
1742ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1743 OperationState &result) {
1744 OpAsmParser::UnresolvedOperand tagMemRefInfo;
1745 AffineMapAttr tagMapAttr;
1746 SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
1747 Type type;
1748 auto indexType = parser.getBuilder().getIndexType();
1749 OpAsmParser::UnresolvedOperand numElementsInfo;
1750
1751 // Parse tag memref, its map operands, and dma size.
1752 if (parser.parseOperand(result&: tagMemRefInfo) ||
1753 parser.parseAffineMapOfSSAIds(operands&: tagMapOperands, map&: tagMapAttr,
1754 attrName: getTagMapAttrStrName(),
1755 attrs&: result.attributes) ||
1756 parser.parseComma() || parser.parseOperand(result&: numElementsInfo) ||
1757 parser.parseColonType(result&: type) ||
1758 parser.resolveOperand(operand: tagMemRefInfo, type, result&: result.operands) ||
1759 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1760 parser.resolveOperand(operand: numElementsInfo, type: indexType, result&: result.operands))
1761 return failure();
1762
1763 if (!llvm::isa<MemRefType>(Val: type))
1764 return parser.emitError(loc: parser.getNameLoc(),
1765 message: "expected tag to be of memref type");
1766
1767 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1768 return parser.emitError(loc: parser.getNameLoc(),
1769 message: "tag memref operand count != to map.numInputs");
1770 return success();
1771}
1772
1773LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1774 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1775 return emitOpError("expected DMA tag to be of memref type");
1776 Region *scope = getAffineScope(*this);
1777 for (auto idx : getTagIndices()) {
1778 if (!idx.getType().isIndex())
1779 return emitOpError("index to dma_wait must have 'index' type");
1780 if (!isValidAffineIndexOperand(idx, scope))
1781 return emitOpError(
1782 "index must be a valid dimension or symbol identifier");
1783 }
1784 return success();
1785}
1786
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  // Folds memref.cast producers directly into this op's memref operand.
  return memref::foldMemRefCast(*this);
}
1792
1793void AffineDmaWaitOp::getEffects(
1794 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1795 &effects) {
1796 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: getTagMemRef(),
1797 Args: SideEffects::DefaultResource::get());
1798}
1799
1800//===----------------------------------------------------------------------===//
1801// AffineForOp
1802//===----------------------------------------------------------------------===//
1803
1804/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1805/// bodyBuilder are empty/null, we include default terminator op.
1806void AffineForOp::build(OpBuilder &builder, OperationState &result,
1807 ValueRange lbOperands, AffineMap lbMap,
1808 ValueRange ubOperands, AffineMap ubMap, int64_t step,
1809 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1810 assert(((!lbMap && lbOperands.empty()) ||
1811 lbOperands.size() == lbMap.getNumInputs()) &&
1812 "lower bound operand count does not match the affine map");
1813 assert(((!ubMap && ubOperands.empty()) ||
1814 ubOperands.size() == ubMap.getNumInputs()) &&
1815 "upper bound operand count does not match the affine map");
1816 assert(step > 0 && "step has to be a positive integer constant");
1817
1818 OpBuilder::InsertionGuard guard(builder);
1819
1820 // Set variadic segment sizes.
1821 result.addAttribute(
1822 getOperandSegmentSizeAttr(),
1823 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1824 static_cast<int32_t>(ubOperands.size()),
1825 static_cast<int32_t>(iterArgs.size())}));
1826
1827 for (Value val : iterArgs)
1828 result.addTypes(val.getType());
1829
1830 // Add an attribute for the step.
1831 result.addAttribute(getStepAttrName(result.name),
1832 builder.getIntegerAttr(builder.getIndexType(), step));
1833
1834 // Add the lower bound.
1835 result.addAttribute(getLowerBoundMapAttrName(result.name),
1836 AffineMapAttr::get(lbMap));
1837 result.addOperands(lbOperands);
1838
1839 // Add the upper bound.
1840 result.addAttribute(getUpperBoundMapAttrName(result.name),
1841 AffineMapAttr::get(ubMap));
1842 result.addOperands(ubOperands);
1843
1844 result.addOperands(iterArgs);
1845 // Create a region and a block for the body. The argument of the region is
1846 // the loop induction variable.
1847 Region *bodyRegion = result.addRegion();
1848 Block *bodyBlock = builder.createBlock(bodyRegion);
1849 Value inductionVar =
1850 bodyBlock->addArgument(builder.getIndexType(), result.location);
1851 for (Value val : iterArgs)
1852 bodyBlock->addArgument(val.getType(), val.getLoc());
1853
1854 // Create the default terminator if the builder is not provided and if the
1855 // iteration arguments are not provided. Otherwise, leave this to the caller
1856 // because we don't know which values to return from the loop.
1857 if (iterArgs.empty() && !bodyBuilder) {
1858 ensureTerminator(*bodyRegion, builder, result.location);
1859 } else if (bodyBuilder) {
1860 OpBuilder::InsertionGuard guard(builder);
1861 builder.setInsertionPointToStart(bodyBlock);
1862 bodyBuilder(builder, result.location, inductionVar,
1863 bodyBlock->getArguments().drop_front());
1864 }
1865}
1866
1867void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1868 int64_t ub, int64_t step, ValueRange iterArgs,
1869 BodyBuilderFn bodyBuilder) {
1870 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1871 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1872 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1873 bodyBuilder);
1874}
1875
1876LogicalResult AffineForOp::verifyRegions() {
1877 // Check that the body defines as single block argument for the induction
1878 // variable.
1879 auto *body = getBody();
1880 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1881 return emitOpError("expected body to have a single index argument for the "
1882 "induction variable");
1883
1884 // Verify that the bound operands are valid dimension/symbols.
1885 /// Lower bound.
1886 if (getLowerBoundMap().getNumInputs() > 0)
1887 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
1888 getLowerBoundMap().getNumDims())))
1889 return failure();
1890 /// Upper bound.
1891 if (getUpperBoundMap().getNumInputs() > 0)
1892 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
1893 getUpperBoundMap().getNumDims())))
1894 return failure();
1895
1896 unsigned opNumResults = getNumResults();
1897 if (opNumResults == 0)
1898 return success();
1899
1900 // If ForOp defines values, check that the number and types of the defined
1901 // values match ForOp initial iter operands and backedge basic block
1902 // arguments.
1903 if (getNumIterOperands() != opNumResults)
1904 return emitOpError(
1905 "mismatch between the number of loop-carried values and results");
1906 if (getNumRegionIterArgs() != opNumResults)
1907 return emitOpError(
1908 "mismatch between the number of basic block args and results");
1909
1910 return success();
1911}
1912
1913/// Parse a for operation loop bounds.
1914static ParseResult parseBound(bool isLower, OperationState &result,
1915 OpAsmParser &p) {
1916 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1917 // the map has multiple results.
1918 bool failedToParsedMinMax =
1919 failed(result: p.parseOptionalKeyword(keyword: isLower ? "max" : "min"));
1920
1921 auto &builder = p.getBuilder();
1922 auto boundAttrStrName =
1923 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1924 : AffineForOp::getUpperBoundMapAttrName(result.name);
1925
1926 // Parse ssa-id as identity map.
1927 SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
1928 if (p.parseOperandList(result&: boundOpInfos))
1929 return failure();
1930
1931 if (!boundOpInfos.empty()) {
1932 // Check that only one operand was parsed.
1933 if (boundOpInfos.size() > 1)
1934 return p.emitError(loc: p.getNameLoc(),
1935 message: "expected only one loop bound operand");
1936
1937 // TODO: improve error message when SSA value is not of index type.
1938 // Currently it is 'use of value ... expects different type than prior uses'
1939 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1940 result.operands))
1941 return failure();
1942
1943 // Create an identity map using symbol id. This representation is optimized
1944 // for storage. Analysis passes may expand it into a multi-dimensional map
1945 // if desired.
1946 AffineMap map = builder.getSymbolIdentityMap();
1947 result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1948 return success();
1949 }
1950
1951 // Get the attribute location.
1952 SMLoc attrLoc = p.getCurrentLocation();
1953
1954 Attribute boundAttr;
1955 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1956 result.attributes))
1957 return failure();
1958
1959 // Parse full form - affine map followed by dim and symbol list.
1960 if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1961 unsigned currentNumOperands = result.operands.size();
1962 unsigned numDims;
1963 if (parseDimAndSymbolList(parser&: p, operands&: result.operands, numDims))
1964 return failure();
1965
1966 auto map = affineMapAttr.getValue();
1967 if (map.getNumDims() != numDims)
1968 return p.emitError(
1969 loc: p.getNameLoc(),
1970 message: "dim operand count and affine map dim count must match");
1971
1972 unsigned numDimAndSymbolOperands =
1973 result.operands.size() - currentNumOperands;
1974 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1975 return p.emitError(
1976 loc: p.getNameLoc(),
1977 message: "symbol operand count and affine map symbol count must match");
1978
1979 // If the map has multiple results, make sure that we parsed the min/max
1980 // prefix.
1981 if (map.getNumResults() > 1 && failedToParsedMinMax) {
1982 if (isLower) {
1983 return p.emitError(loc: attrLoc, message: "lower loop bound affine map with "
1984 "multiple results requires 'max' prefix");
1985 }
1986 return p.emitError(loc: attrLoc, message: "upper loop bound affine map with multiple "
1987 "results requires 'min' prefix");
1988 }
1989 return success();
1990 }
1991
1992 // Parse custom assembly form.
1993 if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
1994 result.attributes.pop_back();
1995 result.addAttribute(
1996 boundAttrStrName,
1997 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1998 return success();
1999 }
2000
2001 return p.emitError(
2002 loc: p.getNameLoc(),
2003 message: "expected valid affine map representation for loop bounds");
2004}
2005
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::Argument inductionVariable;
  inductionVariable.type = builder.getIndexType();
  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  // Count how many operands each bound contributed so the operand segment
  // sizes can be recorded below.
  int64_t numOperands = result.operands.size();
  if (parseBound(/*isLower=*/true, result, parser))
    return failure();
  int64_t numLbOperands = result.operands.size() - numOperands;
  if (parser.parseKeyword("to", " between bounds"))
    return failure();
  numOperands = result.operands.size();
  if (parseBound(/*isLower=*/false, result, parser))
    return failure();
  int64_t numUbOperands = result.operands.size() - numOperands;

  // Parse the optional loop step, we default to 1 if one is not present.
  if (parser.parseOptionalKeyword("step")) {
    result.addAttribute(
        getStepAttrName(result.name),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              getStepAttrName(result.name).data(),
                              result.attributes))
      return failure();

    // Steps must be positive.
    if (stepAttr.getValue().isNegative())
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer");
  }

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Induction variable.
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();
    // Resolve input operands. Each region iter-arg takes the type of the
    // corresponding result; drop_begin skips the induction variable.
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  result.addAttribute(
      getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
                                    static_cast<int32_t>(numUbOperands),
                                    static_cast<int32_t>(operands.size())}));

  // Parse the body region.
  Region *body = result.addRegion();
  // regionArgs = induction variable + one arg per result.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch between the number of loop-carried values and results");
  if (parser.parseRegion(*body, regionArgs))
    return failure();

  // Add an implicit affine.yield if the body has none.
  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
2088
2089static void printBound(AffineMapAttr boundMap,
2090 Operation::operand_range boundOperands,
2091 const char *prefix, OpAsmPrinter &p) {
2092 AffineMap map = boundMap.getValue();
2093
2094 // Check if this bound should be printed using custom assembly form.
2095 // The decision to restrict printing custom assembly form to trivial cases
2096 // comes from the will to roundtrip MLIR binary -> text -> binary in a
2097 // lossless way.
2098 // Therefore, custom assembly form parsing and printing is only supported for
2099 // zero-operand constant maps and single symbol operand identity maps.
2100 if (map.getNumResults() == 1) {
2101 AffineExpr expr = map.getResult(idx: 0);
2102
2103 // Print constant bound.
2104 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2105 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2106 p << constExpr.getValue();
2107 return;
2108 }
2109 }
2110
2111 // Print bound that consists of a single SSA symbol if the map is over a
2112 // single symbol.
2113 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2114 if (dyn_cast<AffineSymbolExpr>(Val&: expr)) {
2115 p.printOperand(value: *boundOperands.begin());
2116 return;
2117 }
2118 }
2119 } else {
2120 // Map has multiple results. Print 'min' or 'max' prefix.
2121 p << prefix << ' ';
2122 }
2123
2124 // Print the map and its operands.
2125 p << boundMap;
2126 printDimAndSymbolList(begin: boundOperands.begin(), end: boundOperands.end(),
2127 numDims: map.getNumDims(), printer&: p);
2128}
2129
2130unsigned AffineForOp::getNumIterOperands() {
2131 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2132 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2133
2134 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2135}
2136
2137std::optional<MutableArrayRef<OpOperand>>
2138AffineForOp::getYieldedValuesMutable() {
2139 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2140}
2141
2142void AffineForOp::print(OpAsmPrinter &p) {
2143 p << ' ';
2144 p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2145 /*omitType=*/true);
2146 p << " = ";
2147 printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2148 p << " to ";
2149 printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2150
2151 if (getStepAsInt() != 1)
2152 p << " step " << getStepAsInt();
2153
2154 bool printBlockTerminators = false;
2155 if (getNumIterOperands() > 0) {
2156 p << " iter_args(";
2157 auto regionArgs = getRegionIterArgs();
2158 auto operands = getInits();
2159
2160 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2161 p << std::get<0>(it) << " = " << std::get<1>(it);
2162 });
2163 p << ") -> (" << getResultTypes() << ")";
2164 printBlockTerminators = true;
2165 }
2166
2167 p << ' ';
2168 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2169 printBlockTerminators);
2170 p.printOptionalAttrDict(
2171 (*this)->getAttrs(),
2172 /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2173 getUpperBoundMapAttrName(getOperation()->getName()),
2174 getStepAttrName(getOperation()->getName()),
2175 getOperandSegmentSizeAttr()});
2176}
2177
2178/// Fold the constant bounds of a loop.
2179static LogicalResult foldLoopBounds(AffineForOp forOp) {
2180 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2181 // Check to see if each of the operands is the result of a constant. If
2182 // so, get the value. If not, ignore it.
2183 SmallVector<Attribute, 8> operandConstants;
2184 auto boundOperands =
2185 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2186 for (auto operand : boundOperands) {
2187 Attribute operandCst;
2188 matchPattern(operand, m_Constant(&operandCst));
2189 operandConstants.push_back(operandCst);
2190 }
2191
2192 AffineMap boundMap =
2193 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2194 assert(boundMap.getNumResults() >= 1 &&
2195 "bound maps should have at least one result");
2196 SmallVector<Attribute, 4> foldedResults;
2197 if (failed(result: boundMap.constantFold(operandConstants, results&: foldedResults)))
2198 return failure();
2199
2200 // Compute the max or min as applicable over the results.
2201 assert(!foldedResults.empty() && "bounds should have at least one result");
2202 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2203 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2204 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2205 maxOrMin = lower ? llvm::APIntOps::smax(A: maxOrMin, B: foldedResult)
2206 : llvm::APIntOps::smin(A: maxOrMin, B: foldedResult);
2207 }
2208 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2209 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2210 return success();
2211 };
2212
2213 // Try to fold the lower bound.
2214 bool folded = false;
2215 if (!forOp.hasConstantLowerBound())
2216 folded |= succeeded(result: foldLowerOrUpperBound(/*lower=*/true));
2217
2218 // Try to fold the upper bound.
2219 if (!forOp.hasConstantUpperBound())
2220 folded |= succeeded(result: foldLowerOrUpperBound(/*lower=*/false));
2221 return success(isSuccess: folded);
2222}
2223
/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  // Keep the incoming maps so we can detect whether anything changed.
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  composeAffineMapAndOperands(&lbMap, &lbOperands);
  canonicalizeMapAndOperands(&lbMap, &lbOperands);
  // A lower bound is the max of its map results; an upper bound the min.
  // NOTE(review): the ub simplification below runs before the ub map is
  // composed/canonicalized — presumably intentional; confirm against
  // simplifyMinOrMaxExprWithOperands's requirements.
  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
  lbMap = removeDuplicateExprs(lbMap);

  composeAffineMapAndOperands(&ubMap, &ubOperands);
  canonicalizeMapAndOperands(&ubMap, &ubOperands);
  ubMap = removeDuplicateExprs(ubMap);

  // Any canonicalization change always leads to updated map(s).
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  // Only write back the bound(s) that actually changed.
  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);
  return success();
}
2254
2255namespace {
2256/// Returns constant trip count in trivial cases.
2257static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2258 int64_t step = forOp.getStepAsInt();
2259 if (!forOp.hasConstantBounds() || step <= 0)
2260 return std::nullopt;
2261 int64_t lb = forOp.getConstantLowerBound();
2262 int64_t ub = forOp.getConstantUpperBound();
2263 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2264}
2265
2266/// This is a pattern to fold trivially empty loop bodies.
2267/// TODO: This should be moved into the folding hook.
2268struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2269 using OpRewritePattern<AffineForOp>::OpRewritePattern;
2270
2271 LogicalResult matchAndRewrite(AffineForOp forOp,
2272 PatternRewriter &rewriter) const override {
2273 // Check that the body only contains a yield.
2274 if (!llvm::hasSingleElement(*forOp.getBody()))
2275 return failure();
2276 if (forOp.getNumResults() == 0)
2277 return success();
2278 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2279 if (tripCount && *tripCount == 0) {
2280 // The initial values of the iteration arguments would be the op's
2281 // results.
2282 rewriter.replaceOp(forOp, forOp.getInits());
2283 return success();
2284 }
2285 SmallVector<Value, 4> replacements;
2286 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2287 auto iterArgs = forOp.getRegionIterArgs();
2288 bool hasValDefinedOutsideLoop = false;
2289 bool iterArgsNotInOrder = false;
2290 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2291 Value val = yieldOp.getOperand(i);
2292 auto *iterArgIt = llvm::find(iterArgs, val);
2293 if (iterArgIt == iterArgs.end()) {
2294 // `val` is defined outside of the loop.
2295 assert(forOp.isDefinedOutsideOfLoop(val) &&
2296 "must be defined outside of the loop");
2297 hasValDefinedOutsideLoop = true;
2298 replacements.push_back(Elt: val);
2299 } else {
2300 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2301 if (pos != i)
2302 iterArgsNotInOrder = true;
2303 replacements.push_back(Elt: forOp.getInits()[pos]);
2304 }
2305 }
2306 // Bail out when the trip count is unknown and the loop returns any value
2307 // defined outside of the loop or any iterArg out of order.
2308 if (!tripCount.has_value() &&
2309 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2310 return failure();
2311 // Bail out when the loop iterates more than once and it returns any iterArg
2312 // out of order.
2313 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2314 return failure();
2315 rewriter.replaceOp(forOp, replacements);
2316 return success();
2317 }
2318};
2319} // namespace
2320
void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Folding of trivially-empty loop bodies is done via a pattern for now;
  // see the TODO on AffineForEmptyLoopFolder about moving it to the fold hook.
  results.add<AffineForEmptyLoopFolder>(context);
}

OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert((point.isParent() || point == getRegion()) && "invalid region point");

  // The initial operands map to the loop arguments after the induction
  // variable or are forwarded to the results when the trip count is zero.
  return getInits();
}
2333
/// RegionBranchOpInterface: report where control may flow from `point`
/// (the parent op or the loop body), refined by the trivially-known constant
/// trip count when one is available.
void AffineForOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  assert((point.isParent() || point == getRegion()) && "expected loop region");
  // The loop may typically branch back to its body or to the parent operation.
  // If the predecessor is the parent op and the trip count is known to be at
  // least one, branch into the body using the iterator arguments. And in cases
  // we know the trip count is zero, it can only branch back to its parent.
  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
  if (point.isParent() && tripCount.has_value()) {
    if (tripCount.value() > 0) {
      regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
      return;
    }
    if (tripCount.value() == 0) {
      regions.push_back(RegionSuccessor(getResults()));
      return;
    }
  }

  // From the loop body, if the trip count is one, we can only branch back to
  // the parent.
  if (!point.isParent() && tripCount && *tripCount == 1) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // In all other cases, the loop may branch back to itself or the parent
  // operation.
  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
  regions.push_back(RegionSuccessor(getResults()));
}
2365
2366/// Returns true if the affine.for has zero iterations in trivial cases.
2367static bool hasTrivialZeroTripCount(AffineForOp op) {
2368 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2369 return tripCount && *tripCount == 0;
2370}
2371
/// Fold hook: simplifies the loop bounds in place and, when the loop trivially
/// has zero iterations and produces results, forwards the iter_arg init values
/// as the op's results.
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
  // In-place simplification of the bound maps/operands counts as a fold.
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
    // The initial values of the loop-carried variables (iter_args) are the
    // results of the op. But this must be avoided for an affine.for op that
    // does not return any results. Since ops that do not return results cannot
    // be folded away, we would enter an infinite loop of folds on the same
    // affine.for op.
    results.assign(getInits().begin(), getInits().end());
    folded = true;
  }
  return success(folded);
}
2387
/// Returns the lower bound as an AffineBound (bound map plus its operands).
AffineBound AffineForOp::getLowerBound() {
  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
}

/// Returns the upper bound as an AffineBound (bound map plus its operands).
AffineBound AffineForOp::getUpperBound() {
  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
}

/// Replaces the lower bound with `map` applied to `lbOperands`.
/// The map must have at least one result.
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getLowerBoundOperandsMutable().assign(lbOperands);
  setLowerBoundMap(map);
}

/// Replaces the upper bound with `map` applied to `ubOperands`.
/// The map must have at least one result.
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getUpperBoundOperandsMutable().assign(ubOperands);
  setUpperBoundMap(map);
}

/// Returns true if the lower bound map is a single constant expression.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

/// Returns true if the upper bound map is a single constant expression.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

/// Returns the constant lower bound; valid only if hasConstantLowerBound().
int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

/// Returns the constant upper bound; valid only if hasConstantUpperBound().
int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

/// Sets the lower bound to the constant `value` (operand-less constant map).
void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

/// Sets the upper bound to the constant `value` (operand-less constant map).
void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}

/// Returns the operands controlling the loop: the lower-bound operands
/// followed immediately by the upper-bound operands (iter_arg inits excluded).
AffineForOp::operand_range AffineForOp::getControlOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
                               getUpperBoundOperands().size()};
}
2438
2439bool AffineForOp::matchingBoundOperandList() {
2440 auto lbMap = getLowerBoundMap();
2441 auto ubMap = getUpperBoundMap();
2442 if (lbMap.getNumDims() != ubMap.getNumDims() ||
2443 lbMap.getNumSymbols() != ubMap.getNumSymbols())
2444 return false;
2445
2446 unsigned numOperands = lbMap.getNumInputs();
2447 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2448 // Compare Value 's.
2449 if (getOperand(i) != getOperand(numOperands + i))
2450 return false;
2451 }
2452 return true;
2453}
2454
/// LoopLikeOpInterface: an affine.for has exactly one loop region.
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }

/// LoopLikeOpInterface: the single induction variable of this loop.
std::optional<Value> AffineForOp::getSingleInductionVar() {
  return getInductionVar();
}

/// LoopLikeOpInterface: the lower bound as a constant attribute, or
/// std::nullopt when the bound is not a single constant.
std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
  if (!hasConstantLowerBound())
    return std::nullopt;
  OpBuilder b(getContext());
  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
}

/// LoopLikeOpInterface: the step, which is always a known constant, as an
/// attribute.
std::optional<OpFoldResult> AffineForOp::getSingleStep() {
  OpBuilder b(getContext());
  return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
}

/// LoopLikeOpInterface: the upper bound as a constant attribute, or
/// std::nullopt when the bound is not a single constant.
std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
  if (!hasConstantUpperBound())
    return std::nullopt;
  OpBuilder b(getContext());
  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
}
2479
/// Replaces this loop with a new affine.for carrying `newInitOperands` as
/// extra iter_args; `newYieldValuesFn` produces the values yielded for them.
/// When `replaceInitOperandUsesInLoop` is set, uses of `newInitOperands`
/// inside the loop body are rewired to the new region iter_args. The original
/// op's results are replaced by the leading results of the new loop.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
      getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);

  // Generate the new yield values and append them to the affine.yield
  // operation.
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    // Nested guard: restore the insertion point after emitting the yields.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
2531
2532Speculation::Speculatability AffineForOp::getSpeculatability() {
2533 // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2534 // Start and End.
2535 //
2536 // For Step != 1, the loop may not terminate. We can add more smarts here if
2537 // needed.
2538 return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2539 : Speculation::NotSpeculatable;
2540}
2541
/// Returns true if the provided value is the induction variable of a
/// AffineForOp.
bool mlir::affine::isAffineForInductionVar(Value val) {
  return getForInductionVarOwner(val) != AffineForOp();
}

/// Returns true if the provided value is an induction variable of an
/// AffineParallelOp.
bool mlir::affine::isAffineParallelInductionVar(Value val) {
  return getAffineParallelInductionVarOwner(val) != nullptr;
}

/// Returns true if the provided value is an induction variable of either an
/// affine.for or an affine.parallel op.
bool mlir::affine::isAffineInductionVar(Value val) {
  return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
}
2555
2556AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
2557 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
2558 if (!ivArg || !ivArg.getOwner())
2559 return AffineForOp();
2560 auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
2561 if (auto forOp = dyn_cast<AffineForOp>(containingInst))
2562 // Check to make sure `val` is the induction variable, not an iter_arg.
2563 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2564 return AffineForOp();
2565}
2566
2567AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
2568 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
2569 if (!ivArg || !ivArg.getOwner())
2570 return nullptr;
2571 Operation *containingOp = ivArg.getOwner()->getParentOp();
2572 auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2573 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2574 return parallelOp;
2575 return nullptr;
2576}
2577
2578/// Extracts the induction variables from a list of AffineForOps and returns
2579/// them.
2580void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2581 SmallVectorImpl<Value> *ivs) {
2582 ivs->reserve(N: forInsts.size());
2583 for (auto forInst : forInsts)
2584 ivs->push_back(forInst.getInductionVar());
2585}
2586
2587void mlir::affine::extractInductionVars(ArrayRef<mlir::Operation *> affineOps,
2588 SmallVectorImpl<mlir::Value> &ivs) {
2589 ivs.reserve(N: affineOps.size());
2590 for (Operation *op : affineOps) {
2591 // Add constraints from forOp's bounds.
2592 if (auto forOp = dyn_cast<AffineForOp>(op))
2593 ivs.push_back(Elt: forOp.getInductionVar());
2594 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2595 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2596 ivs.push_back(Elt: parallelOp.getBody()->getArgument(i));
2597 }
2598}
2599
2600/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2601/// operations.
2602template <typename BoundListTy, typename LoopCreatorTy>
2603static void buildAffineLoopNestImpl(
2604 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2605 ArrayRef<int64_t> steps,
2606 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2607 LoopCreatorTy &&loopCreatorFn) {
2608 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2609 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2610
2611 // If there are no loops to be constructed, construct the body anyway.
2612 OpBuilder::InsertionGuard guard(builder);
2613 if (lbs.empty()) {
2614 if (bodyBuilderFn)
2615 bodyBuilderFn(builder, loc, ValueRange());
2616 return;
2617 }
2618
2619 // Create the loops iteratively and store the induction variables.
2620 SmallVector<Value, 4> ivs;
2621 ivs.reserve(N: lbs.size());
2622 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2623 // Callback for creating the loop body, always creates the terminator.
2624 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2625 ValueRange iterArgs) {
2626 ivs.push_back(Elt: iv);
2627 // In the innermost loop, call the body builder.
2628 if (i == e - 1 && bodyBuilderFn) {
2629 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2630 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2631 }
2632 nestedBuilder.create<AffineYieldOp>(nestedLoc);
2633 };
2634
2635 // Delegate actual loop creation to the callback in order to dispatch
2636 // between constant- and variable-bound loops.
2637 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2638 builder.setInsertionPointToStart(loop.getBody());
2639 }
2640}
2641
/// Creates an affine loop from the bounds known to be constants.
static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
                             int64_t ub, int64_t step,
                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
  // Constant bounds need no bound operands; use the integer-bound builder.
  return builder.create<AffineForOp>(loc, lb, ub, step,
                                     /*iterArgs=*/std::nullopt, bodyBuilderFn);
}
2650
2651/// Creates an affine loop from the bounds that may or may not be constants.
2652static AffineForOp
2653buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2654 int64_t step,
2655 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2656 std::optional<int64_t> lbConst = getConstantIntValue(ofr: lb);
2657 std::optional<int64_t> ubConst = getConstantIntValue(ofr: ub);
2658 if (lbConst && ubConst)
2659 return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2660 ubConst.value(), step, bodyBuilderFn);
2661 return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2662 builder.getDimIdentityMap(), step,
2663 /*iterArgs=*/std::nullopt, bodyBuilderFn);
2664}
2665
/// Builds a perfect nest of affine.for loops with the given constant bounds
/// and steps; `bodyBuilderFn` fills in the innermost loop body.
void mlir::affine::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromConstants);
}

/// Overload taking SSA-value bounds; bounds that are actually constant are
/// lowered to the constant form on a per-loop basis.
void mlir::affine::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromValues);
}
2681
2682//===----------------------------------------------------------------------===//
2683// AffineIfOp
2684//===----------------------------------------------------------------------===//
2685
2686namespace {
2687/// Remove else blocks that have nothing other than a zero value yield.
2688struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2689 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2690
2691 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2692 PatternRewriter &rewriter) const override {
2693 if (ifOp.getElseRegion().empty() ||
2694 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2695 return failure();
2696
2697 rewriter.startOpModification(op: ifOp);
2698 rewriter.eraseBlock(block: ifOp.getElseBlock());
2699 rewriter.finalizeOpModification(op: ifOp);
2700 return success();
2701 }
2702};
2703
2704/// Removes affine.if cond if the condition is always true or false in certain
2705/// trivial cases. Promotes the then/else block in the parent operation block.
2706struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2707 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2708
2709 LogicalResult matchAndRewrite(AffineIfOp op,
2710 PatternRewriter &rewriter) const override {
2711
2712 auto isTriviallyFalse = [](IntegerSet iSet) {
2713 return iSet.isEmptyIntegerSet();
2714 };
2715
2716 auto isTriviallyTrue = [](IntegerSet iSet) {
2717 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2718 iSet.getConstraint(idx: 0) == 0);
2719 };
2720
2721 IntegerSet affineIfConditions = op.getIntegerSet();
2722 Block *blockToMove;
2723 if (isTriviallyFalse(affineIfConditions)) {
2724 // The absence, or equivalently, the emptiness of the else region need not
2725 // be checked when affine.if is returning results because if an affine.if
2726 // operation is returning results, it always has a non-empty else region.
2727 if (op.getNumResults() == 0 && !op.hasElse()) {
2728 // If the else region is absent, or equivalently, empty, remove the
2729 // affine.if operation (which is not returning any results).
2730 rewriter.eraseOp(op: op);
2731 return success();
2732 }
2733 blockToMove = op.getElseBlock();
2734 } else if (isTriviallyTrue(affineIfConditions)) {
2735 blockToMove = op.getThenBlock();
2736 } else {
2737 return failure();
2738 }
2739 Operation *blockToMoveTerminator = blockToMove->getTerminator();
2740 // Promote the "blockToMove" block to the parent operation block between the
2741 // prologue and epilogue of "op".
2742 rewriter.inlineBlockBefore(blockToMove, op);
2743 // Replace the "op" operation with the operands of the
2744 // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2745 // the affine.yield operation present in the "blockToMove" block. It has no
2746 // operands when affine.if is not returning results and therefore, in that
2747 // case, replaceOp just erases "op". When affine.if is not returning
2748 // results, the affine.yield operation can be omitted. It gets inserted
2749 // implicitly.
2750 rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2751 // Erase the "blockToMoveTerminator" operation since it is now in the parent
2752 // operation block, which already has its own terminator.
2753 rewriter.eraseOp(op: blockToMoveTerminator);
2754 return success();
2755 }
2756};
2757} // namespace
2758
2759/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2760/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // If the predecessor is an AffineIfOp, then branching into both `then` and
  // `else` region is valid.
  if (point.isParent()) {
    regions.reserve(2);
    regions.push_back(
        RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
    // If the "else" region is empty, branch back into parent.
    if (getElseRegion().empty()) {
      regions.push_back(getResults());
    } else {
      regions.push_back(
          RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
    }
    return;
  }

  // If the predecessor is the `else`/`then` region, then branching into parent
  // op is valid.
  regions.push_back(RegionSuccessor(getResults()));
}
2783
2784LogicalResult AffineIfOp::verify() {
2785 // Verify that we have a condition attribute.
2786 // FIXME: This should be specified in the arguments list in ODS.
2787 auto conditionAttr =
2788 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2789 if (!conditionAttr)
2790 return emitOpError("requires an integer set attribute named 'condition'");
2791
2792 // Verify that there are enough operands for the condition.
2793 IntegerSet condition = conditionAttr.getValue();
2794 if (getNumOperands() != condition.getNumInputs())
2795 return emitOpError("operand count and condition integer set dimension and "
2796 "symbol count must match");
2797
2798 // Verify that the operands are valid dimension/symbols.
2799 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2800 condition.getNumDims())))
2801 return failure();
2802
2803 return success();
2804}
2805
/// Parses an affine.if op: an IntegerSetAttr condition followed by its
/// dim-and-symbol operand list, an optional result type list, the `then`
/// region, an optional `else` region, and a trailing attribute dictionary.
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

/// Prints an affine.if op in the custom form accepted by the parser above.
/// Block terminators are elided when the op produces no results.
void AffineIfOp::print(OpAsmPrinter &p) {
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  p << " " << conditionAttr;
  printDimAndSymbolList(operand_begin(), operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());

  // Print the 'else' regions if it has any blocks.
  auto &elseRegion = this->getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/getNumResults());
  }

  // Print the attribute list.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/getConditionAttrStrName());
}
2881
/// Returns the integer set governing the condition.
IntegerSet AffineIfOp::getIntegerSet() {
  return (*this)
      ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
      .getValue();
}

/// Replaces the condition's integer set; the operands are left untouched.
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
}

/// Replaces both the condition's integer set and all of its operands.
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  (*this)->setOperands(operands);
}

/// Builds an affine.if with condition `set` applied to `args`. Ops that
/// produce results must have an else region (hence the assertion); blocks
/// are created for the requested regions, with implicit terminators added
/// only in the result-less case.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);
  OpBuilder::InsertionGuard guard(builder);

  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));

  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    builder.createBlock(elseRegion);
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}

/// Result-less convenience overload of the builder above.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
                    withElseRegion);
}
2925
2926/// Compose any affine.apply ops feeding into `operands` of the integer set
2927/// `set` by composing the maps of such affine.apply ops with the integer
2928/// set constraints.
2929static void composeSetAndOperands(IntegerSet &set,
2930 SmallVectorImpl<Value> &operands) {
2931 // We will simply reuse the API of the map composition by viewing the LHSs of
2932 // the equalities and inequalities of `set` as the affine exprs of an affine
2933 // map. Convert to equivalent map, compose, and convert back to set.
2934 auto map = AffineMap::get(dimCount: set.getNumDims(), symbolCount: set.getNumSymbols(),
2935 results: set.getConstraints(), context: set.getContext());
2936 // Check if any composition is possible.
2937 if (llvm::none_of(Range&: operands,
2938 P: [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2939 return;
2940
2941 composeAffineMapAndOperands(map: &map, operands: &operands);
2942 set = IntegerSet::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), constraints: map.getResults(),
2943 eqFlags: set.getEqFlags());
2944}
2945
/// Canonicalize an affine if op's conditional (integer set + operands).
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  // First compose any feeding affine.apply ops into the set, then
  // canonicalize the resulting set/operand pair.
  composeSetAndOperands(set, operands);
  canonicalizeSetAndOperands(&set, &operands);

  // Check if the canonicalization or composition led to any change.
  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
    return failure();

  setConditional(set, operands);
  return success();
}

void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
}
2965
2966//===----------------------------------------------------------------------===//
2967// AffineLoadOp
2968//===----------------------------------------------------------------------===//
2969
/// Builds an affine.load where `operands` is the memref followed by the map
/// operands; the result type is the memref's element type.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
  result.types.push_back(memrefType.getElementType());
}

/// Builds an affine.load of `memref` indexed by `map` applied to
/// `mapOperands`.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(memrefType.getElementType());
}

/// Builds an affine.load that accesses `memref` through an identity map over
/// `indices`.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, memref, map, indices);
}

/// Parses "affine.load %memref[<affine map of ssa ids>] {attrs} : type".
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}

/// Prints the affine.load in the custom form accepted by the parser above.
void AffineLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3031
3032/// Verify common indexing invariants of affine.load, affine.store,
3033/// affine.vector_load and affine.vector_store.
3034static LogicalResult
3035verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3036 Operation::operand_range mapOperands,
3037 MemRefType memrefType, unsigned numIndexOperands) {
3038 AffineMap map = mapAttr.getValue();
3039 if (map.getNumResults() != memrefType.getRank())
3040 return op->emitOpError(message: "affine map num results must equal memref rank");
3041 if (map.getNumInputs() != numIndexOperands)
3042 return op->emitOpError(message: "expects as many subscripts as affine map inputs");
3043
3044 Region *scope = getAffineScope(op);
3045 for (auto idx : mapOperands) {
3046 if (!idx.getType().isIndex())
3047 return op->emitOpError(message: "index to load must have 'index' type");
3048 if (!isValidAffineIndexOperand(value: idx, region: scope))
3049 return op->emitOpError(
3050 message: "index must be a valid dimension or symbol identifier");
3051 }
3052
3053 return success();
3054}
3055
/// Verifies that the result type matches the memref element type and that the
/// access map/operands satisfy the common affine indexing invariants.
LogicalResult AffineLoadOp::verify() {
  auto memrefType = getMemRefType();
  if (getType() != memrefType.getElementType())
    return emitOpError("result type must match element type of memref");

  if (failed(verifyMemoryOpIndexing(
          getOperation(),
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();

  return success();
}

void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Reuses the generic affine-map simplification pattern (defined earlier in
  // this file) instantiated for loads.
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}

/// Folds loads through memref.cast, and folds loads of constant
/// memref.global data when the accessed indices are statically known.
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
3110
3111//===----------------------------------------------------------------------===//
3112// AffineStoreOp
3113//===----------------------------------------------------------------------===//
3114
3115void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3116 Value valueToStore, Value memref, AffineMap map,
3117 ValueRange mapOperands) {
3118 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3119 result.addOperands(valueToStore);
3120 result.addOperands(memref);
3121 result.addOperands(mapOperands);
3122 result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3123}
3124
3125// Use identity map.
3126void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3127 Value valueToStore, Value memref,
3128 ValueRange indices) {
3129 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3130 int64_t rank = memrefType.getRank();
3131 // Create identity map for memrefs with at least one dimension or () -> ()
3132 // for zero-dimensional memrefs.
3133 auto map =
3134 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3135 build(builder, result, valueToStore, memref, map, indices);
3136}
3137
/// Parses an affine.store op of the form:
///   <value-to-store> , <memref> [ affine-map-of-ssa-ids ] attr-dict
///     : memref-type
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // The || chain relies on short-circuiting: parsing stops at the first
  // failing step. Operand resolution order must match the op's operand order
  // (stored value, memref, map operands).
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
3158
3159void AffineStoreOp::print(OpAsmPrinter &p) {
3160 p << " " << getValueToStore();
3161 p << ", " << getMemRef() << '[';
3162 if (AffineMapAttr mapAttr =
3163 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3164 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3165 p << ']';
3166 p.printOptionalAttrDict((*this)->getAttrs(),
3167 /*elidedAttrs=*/{getMapAttrStrName()});
3168 p << " : " << getMemRefType();
3169}
3170
3171LogicalResult AffineStoreOp::verify() {
3172 // The value to store must have the same type as memref element type.
3173 auto memrefType = getMemRefType();
3174 if (getValueToStore().getType() != memrefType.getElementType())
3175 return emitOpError(
3176 "value to store must have the same type as memref element type");
3177
3178 if (failed(verifyMemoryOpIndexing(
3179 getOperation(),
3180 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3181 getMapOperands(), memrefType,
3182 /*numIndexOperands=*/getNumOperands() - 2)))
3183 return failure();
3184
3185 return success();
3186}
3187
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Register map/operand simplification for affine.store.
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
3192
/// Folds a memref.cast feeding the memref operand into this op.
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return memref::foldMemRefCast(*this, getValueToStore());
}
3198
3199//===----------------------------------------------------------------------===//
3200// AffineMinMaxOpBase
3201//===----------------------------------------------------------------------===//
3202
3203template <typename T>
3204static LogicalResult verifyAffineMinMaxOp(T op) {
3205 // Verify that operand count matches affine map dimension and symbol count.
3206 if (op.getNumOperands() !=
3207 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3208 return op.emitOpError(
3209 "operand count and affine map dimension and symbol count must match");
3210 return success();
3211}
3212
3213template <typename T>
3214static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3215 p << ' ' << op->getAttr(T::getMapAttrStrName());
3216 auto operands = op.getOperands();
3217 unsigned numDims = op.getMap().getNumDims();
3218 p << '(' << operands.take_front(numDims) << ')';
3219
3220 if (operands.size() != numDims)
3221 p << '[' << operands.drop_front(numDims) << ']';
3222 p.printOptionalAttrDict(attrs: op->getAttrs(),
3223 /*elidedAttrs=*/{T::getMapAttrStrName()});
3224}
3225
3226template <typename T>
3227static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3228 OperationState &result) {
3229 auto &builder = parser.getBuilder();
3230 auto indexType = builder.getIndexType();
3231 SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
3232 SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
3233 AffineMapAttr mapAttr;
3234 return failure(
3235 parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3236 result.attributes) ||
3237 parser.parseOperandList(result&: dimInfos, delimiter: OpAsmParser::Delimiter::Paren) ||
3238 parser.parseOperandList(result&: symInfos,
3239 delimiter: OpAsmParser::Delimiter::OptionalSquare) ||
3240 parser.parseOptionalAttrDict(result&: result.attributes) ||
3241 parser.resolveOperands(dimInfos, indexType, result.operands) ||
3242 parser.resolveOperands(symInfos, indexType, result.operands) ||
3243 parser.addTypeToList(type: indexType, result&: result.types));
3244}
3245
3246/// Fold an affine min or max operation with the given operands. The operand
3247/// list may contain nulls, which are interpreted as the operand not being a
3248/// constant.
3249template <typename T>
3250static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
3251 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3252 "expected affine min or max op");
3253
3254 // Fold the affine map.
3255 // TODO: Fold more cases:
3256 // min(some_affine, some_affine + constant, ...), etc.
3257 SmallVector<int64_t, 2> results;
3258 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3259
3260 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3261 return op.getOperand(0);
3262
3263 // If some of the map results are not constant, try changing the map in-place.
3264 if (results.empty()) {
3265 // If the map is the same, report that folding did not happen.
3266 if (foldedMap == op.getMap())
3267 return {};
3268 op->setAttr("map", AffineMapAttr::get(foldedMap));
3269 return op.getResult();
3270 }
3271
3272 // Otherwise, completely fold the op into a constant.
3273 auto resultIt = std::is_same<T, AffineMinOp>::value
3274 ? llvm::min_element(Range&: results)
3275 : llvm::max_element(Range&: results);
3276 if (resultIt == results.end())
3277 return {};
3278 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3279}
3280
3281/// Remove duplicated expressions in affine min/max ops.
3282template <typename T>
3283struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
3284 using OpRewritePattern<T>::OpRewritePattern;
3285
3286 LogicalResult matchAndRewrite(T affineOp,
3287 PatternRewriter &rewriter) const override {
3288 AffineMap oldMap = affineOp.getAffineMap();
3289
3290 SmallVector<AffineExpr, 4> newExprs;
3291 for (AffineExpr expr : oldMap.getResults()) {
3292 // This is a linear scan over newExprs, but it should be fine given that
3293 // we typically just have a few expressions per op.
3294 if (!llvm::is_contained(Range&: newExprs, Element: expr))
3295 newExprs.push_back(Elt: expr);
3296 }
3297
3298 if (newExprs.size() == oldMap.getNumResults())
3299 return failure();
3300
3301 auto newMap = AffineMap::get(dimCount: oldMap.getNumDims(), symbolCount: oldMap.getNumSymbols(),
3302 results: newExprs, context: rewriter.getContext());
3303 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3304
3305 return success();
3306 }
3307};
3308
3309/// Merge an affine min/max op to its consumers if its consumer is also an
3310/// affine min/max op.
3311///
3312/// This pattern requires the producer affine min/max op is bound to a
3313/// dimension/symbol that is used as a standalone expression in the consumer
3314/// affine op's map.
3315///
3316/// For example, a pattern like the following:
3317///
3318/// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3319/// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3320///
3321/// Can be turned into:
3322///
3323/// %1 = affine.min affine_map<
3324/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3325template <typename T>
3326struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
3327 using OpRewritePattern<T>::OpRewritePattern;
3328
3329 LogicalResult matchAndRewrite(T affineOp,
3330 PatternRewriter &rewriter) const override {
3331 AffineMap oldMap = affineOp.getAffineMap();
3332 ValueRange dimOperands =
3333 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3334 ValueRange symOperands =
3335 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3336
3337 auto newDimOperands = llvm::to_vector<8>(Range&: dimOperands);
3338 auto newSymOperands = llvm::to_vector<8>(Range&: symOperands);
3339 SmallVector<AffineExpr, 4> newExprs;
3340 SmallVector<T, 4> producerOps;
3341
3342 // Go over each expression to see whether it's a single dimension/symbol
3343 // with the corresponding operand which is the result of another affine
3344 // min/max op. If So it can be merged into this affine op.
3345 for (AffineExpr expr : oldMap.getResults()) {
3346 if (auto symExpr = dyn_cast<AffineSymbolExpr>(Val&: expr)) {
3347 Value symValue = symOperands[symExpr.getPosition()];
3348 if (auto producerOp = symValue.getDefiningOp<T>()) {
3349 producerOps.push_back(producerOp);
3350 continue;
3351 }
3352 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
3353 Value dimValue = dimOperands[dimExpr.getPosition()];
3354 if (auto producerOp = dimValue.getDefiningOp<T>()) {
3355 producerOps.push_back(producerOp);
3356 continue;
3357 }
3358 }
3359 // For the above cases we will remove the expression by merging the
3360 // producer affine min/max's affine expressions. Otherwise we need to
3361 // keep the existing expression.
3362 newExprs.push_back(Elt: expr);
3363 }
3364
3365 if (producerOps.empty())
3366 return failure();
3367
3368 unsigned numUsedDims = oldMap.getNumDims();
3369 unsigned numUsedSyms = oldMap.getNumSymbols();
3370
3371 // Now go over all producer affine ops and merge their expressions.
3372 for (T producerOp : producerOps) {
3373 AffineMap producerMap = producerOp.getAffineMap();
3374 unsigned numProducerDims = producerMap.getNumDims();
3375 unsigned numProducerSyms = producerMap.getNumSymbols();
3376
3377 // Collect all dimension/symbol values.
3378 ValueRange dimValues =
3379 producerOp.getMapOperands().take_front(numProducerDims);
3380 ValueRange symValues =
3381 producerOp.getMapOperands().take_back(numProducerSyms);
3382 newDimOperands.append(in_start: dimValues.begin(), in_end: dimValues.end());
3383 newSymOperands.append(in_start: symValues.begin(), in_end: symValues.end());
3384
3385 // For expressions we need to shift to avoid overlap.
3386 for (AffineExpr expr : producerMap.getResults()) {
3387 newExprs.push_back(Elt: expr.shiftDims(numDims: numProducerDims, shift: numUsedDims)
3388 .shiftSymbols(numSymbols: numProducerSyms, shift: numUsedSyms));
3389 }
3390
3391 numUsedDims += numProducerDims;
3392 numUsedSyms += numProducerSyms;
3393 }
3394
3395 auto newMap = AffineMap::get(dimCount: numUsedDims, symbolCount: numUsedSyms, results: newExprs,
3396 context: rewriter.getContext());
3397 auto newOperands =
3398 llvm::to_vector<8>(Range: llvm::concat<Value>(Ranges&: newDimOperands, Ranges&: newSymOperands));
3399 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3400
3401 return success();
3402 }
3403};
3404
3405/// Canonicalize the result expression order of an affine map and return success
3406/// if the order changed.
3407///
3408/// The function flattens the map's affine expressions to coefficient arrays and
3409/// sorts them in lexicographic order. A coefficient array contains a multiplier
3410/// for every dimension/symbol and a constant term. The canonicalization fails
3411/// if a result expression is not pure or if the flattening requires local
3412/// variables that, unlike dimensions and symbols, have no global order.
3413static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3414 SmallVector<SmallVector<int64_t>> flattenedExprs;
3415 for (const AffineExpr &resultExpr : map.getResults()) {
3416 // Fail if the expression is not pure.
3417 if (!resultExpr.isPureAffine())
3418 return failure();
3419
3420 SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3421 auto flattenResult = flattener.walkPostOrder(expr: resultExpr);
3422 if (failed(result: flattenResult))
3423 return failure();
3424
3425 // Fail if the flattened expression has local variables.
3426 if (flattener.operandExprStack.back().size() !=
3427 map.getNumDims() + map.getNumSymbols() + 1)
3428 return failure();
3429
3430 flattenedExprs.emplace_back(Args: flattener.operandExprStack.back().begin(),
3431 Args: flattener.operandExprStack.back().end());
3432 }
3433
3434 // Fail if sorting is not necessary.
3435 if (llvm::is_sorted(Range&: flattenedExprs))
3436 return failure();
3437
3438 // Reorder the result expressions according to their flattened form.
3439 SmallVector<unsigned> resultPermutation =
3440 llvm::to_vector(Range: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults()));
3441 llvm::sort(C&: resultPermutation, Comp: [&](unsigned lhs, unsigned rhs) {
3442 return flattenedExprs[lhs] < flattenedExprs[rhs];
3443 });
3444 SmallVector<AffineExpr> newExprs;
3445 for (unsigned idx : resultPermutation)
3446 newExprs.push_back(Elt: map.getResult(idx));
3447
3448 map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newExprs,
3449 context: map.getContext());
3450 return success();
3451}
3452
3453/// Canonicalize the affine map result expression order of an affine min/max
3454/// operation.
3455///
3456/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3457/// expressions and replaces the operation if the order changed.
3458///
3459/// For example, the following operation:
3460///
3461/// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3462///
3463/// Turns into:
3464///
3465/// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3466template <typename T>
3467struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3468 using OpRewritePattern<T>::OpRewritePattern;
3469
3470 LogicalResult matchAndRewrite(T affineOp,
3471 PatternRewriter &rewriter) const override {
3472 AffineMap map = affineOp.getAffineMap();
3473 if (failed(result: canonicalizeMapExprAndTermOrder(map)))
3474 return failure();
3475 rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3476 return success();
3477 }
3478};
3479
3480template <typename T>
3481struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3482 using OpRewritePattern<T>::OpRewritePattern;
3483
3484 LogicalResult matchAndRewrite(T affineOp,
3485 PatternRewriter &rewriter) const override {
3486 if (affineOp.getMap().getNumResults() != 1)
3487 return failure();
3488 rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3489 affineOp.getOperands());
3490 return success();
3491 }
3492};
3493
3494//===----------------------------------------------------------------------===//
3495// AffineMinOp
3496//===----------------------------------------------------------------------===//
3497//
3498// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3499//
3500
/// Delegates to the shared affine min/max folder.
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3504
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // Single-result lowering, duplicate-expression removal, producer merging,
  // map/operand simplification, and result-order canonicalization.
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
               MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
      context);
}
3513
/// Verifies operand count against the map's dim + symbol count.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3515
/// Delegates to the shared affine min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3519
/// Delegates to the shared affine min/max printer.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3521
3522//===----------------------------------------------------------------------===//
3523// AffineMaxOp
3524//===----------------------------------------------------------------------===//
3525//
3526// %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3527//
3528
/// Delegates to the shared affine min/max folder.
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3532
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // Same pattern set as AffineMinOp, instantiated for affine.max.
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
               DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
               MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
      context);
}
3541
/// Verifies operand count against the map's dim + symbol count.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3543
/// Delegates to the shared affine min/max parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3547
/// Delegates to the shared affine min/max printer.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3549
3550//===----------------------------------------------------------------------===//
3551// AffinePrefetchOp
3552//===----------------------------------------------------------------------===//
3553
3554//
3555// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3556//
3557ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3558 OperationState &result) {
3559 auto &builder = parser.getBuilder();
3560 auto indexTy = builder.getIndexType();
3561
3562 MemRefType type;
3563 OpAsmParser::UnresolvedOperand memrefInfo;
3564 IntegerAttr hintInfo;
3565 auto i32Type = parser.getBuilder().getIntegerType(32);
3566 StringRef readOrWrite, cacheType;
3567
3568 AffineMapAttr mapAttr;
3569 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3570 if (parser.parseOperand(memrefInfo) ||
3571 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3572 AffinePrefetchOp::getMapAttrStrName(),
3573 result.attributes) ||
3574 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3575 parser.parseComma() || parser.parseKeyword("locality") ||
3576 parser.parseLess() ||
3577 parser.parseAttribute(hintInfo, i32Type,
3578 AffinePrefetchOp::getLocalityHintAttrStrName(),
3579 result.attributes) ||
3580 parser.parseGreater() || parser.parseComma() ||
3581 parser.parseKeyword(&cacheType) ||
3582 parser.parseOptionalAttrDict(result.attributes) ||
3583 parser.parseColonType(type) ||
3584 parser.resolveOperand(memrefInfo, type, result.operands) ||
3585 parser.resolveOperands(mapOperands, indexTy, result.operands))
3586 return failure();
3587
3588 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
3589 return parser.emitError(parser.getNameLoc(),
3590 "rw specifier has to be 'read' or 'write'");
3591 result.addAttribute(
3592 AffinePrefetchOp::getIsWriteAttrStrName(),
3593 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
3594
3595 if (!cacheType.equals("data") && !cacheType.equals("instr"))
3596 return parser.emitError(parser.getNameLoc(),
3597 "cache type has to be 'data' or 'instr'");
3598
3599 result.addAttribute(
3600 AffinePrefetchOp::getIsDataCacheAttrStrName(),
3601 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
3602
3603 return success();
3604}
3605
3606void AffinePrefetchOp::print(OpAsmPrinter &p) {
3607 p << " " << getMemref() << '[';
3608 AffineMapAttr mapAttr =
3609 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3610 if (mapAttr)
3611 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3612 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3613 << "locality<" << getLocalityHint() << ">, "
3614 << (getIsDataCache() ? "data" : "instr");
3615 p.printOptionalAttrDict(
3616 (*this)->getAttrs(),
3617 /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3618 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3619 p << " : " << getMemRefType();
3620}
3621
3622LogicalResult AffinePrefetchOp::verify() {
3623 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3624 if (mapAttr) {
3625 AffineMap map = mapAttr.getValue();
3626 if (map.getNumResults() != getMemRefType().getRank())
3627 return emitOpError("affine.prefetch affine map num results must equal"
3628 " memref rank");
3629 if (map.getNumInputs() + 1 != getNumOperands())
3630 return emitOpError("too few operands");
3631 } else {
3632 if (getNumOperands() != 1)
3633 return emitOpError("too few operands");
3634 }
3635
3636 Region *scope = getAffineScope(*this);
3637 for (auto idx : getMapOperands()) {
3638 if (!isValidAffineIndexOperand(idx, scope))
3639 return emitOpError(
3640 "index must be a valid dimension or symbol identifier");
3641 }
3642 return success();
3643}
3644
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  // Registers map/operand simplification for affine.prefetch.
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3650
/// Folds a memref.cast feeding the memref operand into this op.
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return memref::foldMemRefCast(*this);
}
3656
3657//===----------------------------------------------------------------------===//
3658// AffineParallelOp
3659//===----------------------------------------------------------------------===//
3660
3661void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3662 TypeRange resultTypes,
3663 ArrayRef<arith::AtomicRMWKind> reductions,
3664 ArrayRef<int64_t> ranges) {
3665 SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3666 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3667 return builder.getConstantAffineMap(value);
3668 }));
3669 SmallVector<int64_t> steps(ranges.size(), 1);
3670 build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3671 /*ubArgs=*/{}, steps);
3672}
3673
/// Builds an affine.parallel from per-dimension lower/upper bound maps and
/// their shared operand lists, reduction kinds, and constant steps. All lower
/// (resp. upper) bound maps must have identical dim/symbol counts so they can
/// be concatenated into a single lower (resp. upper) bounds map, with the
/// per-dimension result counts recorded in the groups attributes.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrStrName(),
                      builder.getArrayAttr(reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    if (maps.empty())
      return AffineMap::get(builder.getContext());
    SmallVector<AffineExpr> exprs;
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    for (AffineMap m : maps) {
      llvm::append_range(exprs, m.getResults());
      // Record how many results each original map contributed.
      groups.push_back(m.getNumResults());
    }
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
                          maps[0].getContext());
  };

  // Set up the bounds.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(lbGroups));
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(ubGroups));
  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  Block *body = builder.createBlock(bodyRegion);

  // Add all the block arguments.
  // One index-typed induction variable per dimension (steps.size()).
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(IndexType::get(builder.getContext()), result.location);
  // Only add an implicit terminator when the op produces no results.
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
3755
/// Returns the single region holding the parallel loop body.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}
3759
/// Number of parallel dimensions, one per entry in the steps attribute.
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3761
/// Lower-bound operands come first in the operand list, one per input of the
/// lower bounds map.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}
3765
/// Upper-bound operands are everything after the lower-bound operands.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
3769
3770AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3771 auto values = getLowerBoundsGroups().getValues<int32_t>();
3772 unsigned start = 0;
3773 for (unsigned i = 0; i < pos; ++i)
3774 start += values[i];
3775 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3776}
3777
3778AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3779 auto values = getUpperBoundsGroups().getValues<int32_t>();
3780 unsigned start = 0;
3781 for (unsigned i = 0; i < pos; ++i)
3782 start += values[i];
3783 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3784}
3785
/// Bundles the lower bounds map with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
3789
/// Bundles the upper bounds map with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
3793
3794std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3795 if (hasMinMaxBounds())
3796 return std::nullopt;
3797
3798 // Try to convert all the ranges to constant expressions.
3799 SmallVector<int64_t, 8> out;
3800 AffineValueMap rangesValueMap;
3801 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3802 &rangesValueMap);
3803 out.reserve(rangesValueMap.getNumResults());
3804 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3805 auto expr = rangesValueMap.getResult(i);
3806 auto cst = dyn_cast<AffineConstantExpr>(expr);
3807 if (!cst)
3808 return std::nullopt;
3809 out.push_back(cst.getValue());
3810 }
3811 return out;
3812}
3813
/// Returns the single block of the loop body region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3815
/// Returns a builder positioned just before the body block's last operation
/// (the terminator).
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
3819
3820void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3821 assert(lbOperands.size() == map.getNumInputs() &&
3822 "operands to map must match number of inputs");
3823
3824 auto ubOperands = getUpperBoundsOperands();
3825
3826 SmallVector<Value, 4> newOperands(lbOperands);
3827 newOperands.append(ubOperands.begin(), ubOperands.end());
3828 (*this)->setOperands(newOperands);
3829
3830 setLowerBoundsMapAttr(AffineMapAttr::get(map));
3831}
3832
3833void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3834 assert(ubOperands.size() == map.getNumInputs() &&
3835 "operands to map must match number of inputs");
3836
3837 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3838 newOperands.append(ubOperands.begin(), ubOperands.end());
3839 (*this)->setOperands(newOperands);
3840
3841 setUpperBoundsMapAttr(AffineMapAttr::get(map));
3842}
3843
/// Replaces the steps attribute with `newSteps`.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
3847
3848// check whether resultType match op or not in affine.parallel
3849static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3850 arith::AtomicRMWKind op) {
3851 switch (op) {
3852 case arith::AtomicRMWKind::addf:
3853 return isa<FloatType>(Val: resultType);
3854 case arith::AtomicRMWKind::addi:
3855 return isa<IntegerType>(Val: resultType);
3856 case arith::AtomicRMWKind::assign:
3857 return true;
3858 case arith::AtomicRMWKind::mulf:
3859 return isa<FloatType>(Val: resultType);
3860 case arith::AtomicRMWKind::muli:
3861 return isa<IntegerType>(Val: resultType);
3862 case arith::AtomicRMWKind::maximumf:
3863 return isa<FloatType>(Val: resultType);
3864 case arith::AtomicRMWKind::minimumf:
3865 return isa<FloatType>(Val: resultType);
3866 case arith::AtomicRMWKind::maxs: {
3867 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3868 return intType && intType.isSigned();
3869 }
3870 case arith::AtomicRMWKind::mins: {
3871 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3872 return intType && intType.isSigned();
3873 }
3874 case arith::AtomicRMWKind::maxu: {
3875 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3876 return intType && intType.isUnsigned();
3877 }
3878 case arith::AtomicRMWKind::minu: {
3879 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3880 return intType && intType.isUnsigned();
3881 }
3882 case arith::AtomicRMWKind::ori:
3883 return isa<IntegerType>(Val: resultType);
3884 case arith::AtomicRMWKind::andi:
3885 return isa<IntegerType>(Val: resultType);
3886 default:
3887 return false;
3888 }
3889}
3890
3891LogicalResult AffineParallelOp::verify() {
3892 auto numDims = getNumDims();
3893 if (getLowerBoundsGroups().getNumElements() != numDims ||
3894 getUpperBoundsGroups().getNumElements() != numDims ||
3895 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3896 return emitOpError() << "the number of region arguments ("
3897 << getBody()->getNumArguments()
3898 << ") and the number of map groups for lower ("
3899 << getLowerBoundsGroups().getNumElements()
3900 << ") and upper bound ("
3901 << getUpperBoundsGroups().getNumElements()
3902 << "), and the number of steps (" << getSteps().size()
3903 << ") must all match";
3904 }
3905
3906 unsigned expectedNumLBResults = 0;
3907 for (APInt v : getLowerBoundsGroups())
3908 expectedNumLBResults += v.getZExtValue();
3909 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3910 return emitOpError() << "expected lower bounds map to have "
3911 << expectedNumLBResults << " results";
3912 unsigned expectedNumUBResults = 0;
3913 for (APInt v : getUpperBoundsGroups())
3914 expectedNumUBResults += v.getZExtValue();
3915 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3916 return emitOpError() << "expected upper bounds map to have "
3917 << expectedNumUBResults << " results";
3918
3919 if (getReductions().size() != getNumResults())
3920 return emitOpError("a reduction must be specified for each output");
3921
3922 // Verify reduction ops are all valid and each result type matches reduction
3923 // ops
3924 for (auto it : llvm::enumerate((getReductions()))) {
3925 Attribute attr = it.value();
3926 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3927 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3928 return emitOpError("invalid reduction attribute");
3929 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3930 if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3931 return emitOpError("result type cannot match reduction attribute");
3932 }
3933
3934 // Verify that the bound operands are valid dimension/symbols.
3935 /// Lower bounds.
3936 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3937 getLowerBoundsMap().getNumDims())))
3938 return failure();
3939 /// Upper bounds.
3940 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3941 getUpperBoundsMap().getNumDims())))
3942 return failure();
3943 return success();
3944}
3945
3946LogicalResult AffineValueMap::canonicalize() {
3947 SmallVector<Value, 4> newOperands{operands};
3948 auto newMap = getAffineMap();
3949 composeAffineMapAndOperands(map: &newMap, operands: &newOperands);
3950 if (newMap == getAffineMap() && newOperands == operands)
3951 return failure();
3952 reset(map: newMap, operands: newOperands);
3953 return success();
3954}
3955
3956/// Canonicalize the bounds of the given loop.
3957static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3958 AffineValueMap lb = op.getLowerBoundsValueMap();
3959 bool lbCanonicalized = succeeded(result: lb.canonicalize());
3960
3961 AffineValueMap ub = op.getUpperBoundsValueMap();
3962 bool ubCanonicalized = succeeded(result: ub.canonicalize());
3963
3964 // Any canonicalization change always leads to updated map(s).
3965 if (!lbCanonicalized && !ubCanonicalized)
3966 return failure();
3967
3968 if (lbCanonicalized)
3969 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3970 if (ubCanonicalized)
3971 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3972
3973 return success();
3974}
3975
/// Folding an affine.parallel op only canonicalizes its bound maps in place;
/// it never produces constant results, so `results` is left untouched.
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return canonicalizeLoopBounds(*this);
}
3980
3981/// Prints a lower(upper) bound of an affine parallel loop with max(min)
3982/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3983/// identifies which of the those expressions form max/min groups. `operands`
3984/// are the SSA values of dimensions and symbols and `keyword` is either "min"
3985/// or "max".
3986static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3987 DenseIntElementsAttr group, ValueRange operands,
3988 StringRef keyword) {
3989 AffineMap map = mapAttr.getValue();
3990 unsigned numDims = map.getNumDims();
3991 ValueRange dimOperands = operands.take_front(n: numDims);
3992 ValueRange symOperands = operands.drop_front(n: numDims);
3993 unsigned start = 0;
3994 for (llvm::APInt groupSize : group) {
3995 if (start != 0)
3996 p << ", ";
3997
3998 unsigned size = groupSize.getZExtValue();
3999 if (size == 1) {
4000 p.printAffineExprOfSSAIds(expr: map.getResult(idx: start), dimOperands, symOperands);
4001 ++start;
4002 } else {
4003 p << keyword << '(';
4004 AffineMap submap = map.getSliceMap(start, length: size);
4005 p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4006 p << ')';
4007 start += size;
4008 }
4009 }
4010}
4011
4012void AffineParallelOp::print(OpAsmPrinter &p) {
4013 p << " (" << getBody()->getArguments() << ") = (";
4014 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4015 getLowerBoundsOperands(), "max");
4016 p << ") to (";
4017 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4018 getUpperBoundsOperands(), "min");
4019 p << ')';
4020 SmallVector<int64_t, 8> steps = getSteps();
4021 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4022 if (!elideSteps) {
4023 p << " step (";
4024 llvm::interleaveComma(steps, p);
4025 p << ')';
4026 }
4027 if (getNumResults()) {
4028 p << " reduce (";
4029 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4030 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4031 llvm::cast<IntegerAttr>(attr).getInt());
4032 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4033 });
4034 p << ") -> (" << getResultTypes() << ")";
4035 }
4036
4037 p << ' ';
4038 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4039 /*printBlockTerminators=*/getNumResults());
4040 p.printOptionalAttrDict(
4041 (*this)->getAttrs(),
4042 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4043 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4044 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4045 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4046 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4047 AffineParallelOp::getStepsAttrStrName()});
4048}
4049
4050/// Given a list of lists of parsed operands, populates `uniqueOperands` with
/// unique operands. Also populates `replacements` with affine expressions of
4052/// `kind` that can be used to update affine maps previously accepting a
4053/// `operands` to accept `uniqueOperands` instead.
4054static ParseResult deduplicateAndResolveOperands(
4055 OpAsmParser &parser,
4056 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4057 SmallVectorImpl<Value> &uniqueOperands,
4058 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4059 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4060 "expected operands to be dim or symbol expression");
4061
4062 Type indexType = parser.getBuilder().getIndexType();
4063 for (const auto &list : operands) {
4064 SmallVector<Value> valueOperands;
4065 if (parser.resolveOperands(operands: list, type: indexType, result&: valueOperands))
4066 return failure();
4067 for (Value operand : valueOperands) {
4068 unsigned pos = std::distance(first: uniqueOperands.begin(),
4069 last: llvm::find(Range&: uniqueOperands, Val: operand));
4070 if (pos == uniqueOperands.size())
4071 uniqueOperands.push_back(Elt: operand);
4072 replacements.push_back(
4073 Elt: kind == AffineExprKind::DimId
4074 ? getAffineDimExpr(position: pos, context: parser.getContext())
4075 : getAffineSymbolExpr(position: pos, context: parser.getContext()));
4076 }
4077 }
4078 return success();
4079}
4080
namespace {
// Identifies whether a parsed group of parallel-bound expressions is wrapped
// in `min` (upper bounds) or `max` (lower bounds).
enum class MinMaxKind { Min, Max };
} // namespace
4084
4085/// Parses an affine map that can contain a min/max for groups of its results,
4086/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4087/// `result` attributes with the map (flat list of expressions) and the grouping
4088/// (list of integers that specify how many expressions to put into each
4089/// min/max) attributes. Deduplicates repeated operands.
4090///
4091/// parallel-bound ::= `(` parallel-group-list `)`
4092/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4093/// parallel-group ::= simple-group | min-max-group
4094/// simple-group ::= expr-of-ssa-ids
4095/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4096/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4097///
4098/// Examples:
4099/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4100/// (%0, max(%1 - 2 * %2))
4101static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4102 OperationState &result,
4103 MinMaxKind kind) {
4104 // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4105 // see: https://reviews.llvm.org/D134227#3821753
4106 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4107
4108 StringRef mapName = kind == MinMaxKind::Min
4109 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4110 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4111 StringRef groupsName =
4112 kind == MinMaxKind::Min
4113 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4114 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4115
4116 if (failed(result: parser.parseLParen()))
4117 return failure();
4118
4119 if (succeeded(result: parser.parseOptionalRParen())) {
4120 result.addAttribute(
4121 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4122 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr(values: {}));
4123 return success();
4124 }
4125
4126 SmallVector<AffineExpr> flatExprs;
4127 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4128 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4129 SmallVector<int32_t> numMapsPerGroup;
4130 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4131 auto parseOperands = [&]() {
4132 if (succeeded(result: parser.parseOptionalKeyword(
4133 keyword: kind == MinMaxKind::Min ? "min" : "max"))) {
4134 mapOperands.clear();
4135 AffineMapAttr map;
4136 if (failed(parser.parseAffineMapOfSSAIds(operands&: mapOperands, map&: map, attrName: tmpAttrStrName,
4137 attrs&: result.attributes,
4138 delimiter: OpAsmParser::Delimiter::Paren)))
4139 return failure();
4140 result.attributes.erase(name: tmpAttrStrName);
4141 llvm::append_range(flatExprs, map.getValue().getResults());
4142 auto operandsRef = llvm::ArrayRef(mapOperands);
4143 auto dimsRef = operandsRef.take_front(N: map.getValue().getNumDims());
4144 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(),
4145 dimsRef.end());
4146 auto symsRef = operandsRef.drop_front(N: map.getValue().getNumDims());
4147 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(),
4148 symsRef.end());
4149 flatDimOperands.append(map.getValue().getNumResults(), dims);
4150 flatSymOperands.append(map.getValue().getNumResults(), syms);
4151 numMapsPerGroup.push_back(Elt: map.getValue().getNumResults());
4152 } else {
4153 if (failed(result: parser.parseAffineExprOfSSAIds(dimOperands&: flatDimOperands.emplace_back(),
4154 symbOperands&: flatSymOperands.emplace_back(),
4155 expr&: flatExprs.emplace_back())))
4156 return failure();
4157 numMapsPerGroup.push_back(Elt: 1);
4158 }
4159 return success();
4160 };
4161 if (parser.parseCommaSeparatedList(parseElementFn: parseOperands) || parser.parseRParen())
4162 return failure();
4163
4164 unsigned totalNumDims = 0;
4165 unsigned totalNumSyms = 0;
4166 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4167 unsigned numDims = flatDimOperands[i].size();
4168 unsigned numSyms = flatSymOperands[i].size();
4169 flatExprs[i] = flatExprs[i]
4170 .shiftDims(numDims, shift: totalNumDims)
4171 .shiftSymbols(numSymbols: numSyms, shift: totalNumSyms);
4172 totalNumDims += numDims;
4173 totalNumSyms += numSyms;
4174 }
4175
4176 // Deduplicate map operands.
4177 SmallVector<Value> dimOperands, symOperands;
4178 SmallVector<AffineExpr> dimRplacements, symRepacements;
4179 if (deduplicateAndResolveOperands(parser, operands: flatDimOperands, uniqueOperands&: dimOperands,
4180 replacements&: dimRplacements, kind: AffineExprKind::DimId) ||
4181 deduplicateAndResolveOperands(parser, operands: flatSymOperands, uniqueOperands&: symOperands,
4182 replacements&: symRepacements, kind: AffineExprKind::SymbolId))
4183 return failure();
4184
4185 result.operands.append(in_start: dimOperands.begin(), in_end: dimOperands.end());
4186 result.operands.append(in_start: symOperands.begin(), in_end: symOperands.end());
4187
4188 Builder &builder = parser.getBuilder();
4189 auto flatMap = AffineMap::get(dimCount: totalNumDims, symbolCount: totalNumSyms, results: flatExprs,
4190 context: parser.getContext());
4191 flatMap = flatMap.replaceDimsAndSymbols(
4192 dimReplacements: dimRplacements, symReplacements: symRepacements, numResultDims: dimOperands.size(), numResultSyms: symOperands.size());
4193
4194 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4195 result.addAttribute(groupsName, builder.getI32TensorAttr(values: numMapsPerGroup));
4196 return success();
4197}
4198
4199//
4200// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4201// `to` parallel-bound steps? region attr-dict?
4202// steps ::= `steps` `(` integer-literals `)`
4203//
/// Parses affine.parallel per the grammar above: parenthesized IVs, `=`,
/// max-style lower bounds, `to`, min-style upper bounds, then optional
/// `step`, `reduce`, and result-type clauses followed by the body region.
ParseResult AffineParallelOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::Argument, 4> ivs;
  // `(ivs) = (lower-bounds) to (upper-bounds)`; lower bounds may contain
  // `max` groups and upper bounds `min` groups.
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
      parser.parseEqual() ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
      parser.parseKeyword("to") ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
    return failure();

  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
  if (failed(parser.parseOptionalKeyword("step"))) {
    // No `step` clause: every dimension defaults to step 1.
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  } else {
    // Steps are parsed as an affine map whose results must all be constants.
    if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
                                      AffineParallelOp::getStepsAttrStrName(),
                                      stepsAttrs,
                                      OpAsmParser::Delimiter::Paren))
      return failure();

    // Convert steps from an AffineMap into an I64ArrayAttr.
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(result);
      if (!constExpr)
        return parser.emitError(parser.getNameLoc(),
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    }
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  }

  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
  // quoted strings are a member of the enum AtomicRMWKind.
  SmallVector<Attribute, 4> reductions;
  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
    if (parser.parseLParen())
      return failure();
    auto parseAttributes = [&]() -> ParseResult {
      // Parse a single quoted string via the attribute parsing, and then
      // verify it is a member of the enum and convert to its integer
      // representation.
      StringAttr attrVal;
      NamedAttrList attrStorage;
      auto loc = parser.getCurrentLocation();
      if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
                                attrStorage))
        return failure();
      std::optional<arith::AtomicRMWKind> reduction =
          arith::symbolizeAtomicRMWKind(attrVal.getValue());
      if (!reduction)
        return parser.emitError(loc, "invalid reduction value: ") << attrVal;
      // Reductions are stored as the integer code of the AtomicRMWKind.
      reductions.push_back(
          builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
      // While we keep getting commas, keep parsing.
      return success();
    };
    if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
      return failure();
  }
  // The reductions attribute is always present, possibly empty.
  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
                      builder.getArrayAttr(reductions));

  // Parse return types of reductions (if any)
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body.
  Region *body = result.addRegion();
  // All induction variables have index type.
  for (auto &iv : ivs)
    iv.type = indexType;
  if (parser.parseRegion(*body, ivs) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Add a terminator if none was parsed.
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
4291
4292//===----------------------------------------------------------------------===//
4293// AffineYieldOp
4294//===----------------------------------------------------------------------===//
4295
4296LogicalResult AffineYieldOp::verify() {
4297 auto *parentOp = (*this)->getParentOp();
4298 auto results = parentOp->getResults();
4299 auto operands = getOperands();
4300
4301 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4302 return emitOpError() << "only terminates affine.if/for/parallel regions";
4303 if (parentOp->getNumResults() != getNumOperands())
4304 return emitOpError() << "parent of yield must have same number of "
4305 "results as the yield operands";
4306 for (auto it : llvm::zip(results, operands)) {
4307 if (std::get<0>(it).getType() != std::get<1>(it).getType())
4308 return emitOpError() << "types mismatch between yield op and its parent";
4309 }
4310
4311 return success();
4312}
4313
4314//===----------------------------------------------------------------------===//
4315// AffineVectorLoadOp
4316//===----------------------------------------------------------------------===//
4317
4318void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4319 VectorType resultType, AffineMap map,
4320 ValueRange operands) {
4321 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4322 result.addOperands(operands);
4323 if (map)
4324 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4325 result.types.push_back(resultType);
4326}
4327
4328void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4329 VectorType resultType, Value memref,
4330 AffineMap map, ValueRange mapOperands) {
4331 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4332 result.addOperands(memref);
4333 result.addOperands(mapOperands);
4334 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4335 result.types.push_back(resultType);
4336}
4337
4338void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4339 VectorType resultType, Value memref,
4340 ValueRange indices) {
4341 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4342 int64_t rank = memrefType.getRank();
4343 // Create identity map for memrefs with at least one dimension or () -> ()
4344 // for zero-dimensional memrefs.
4345 auto map =
4346 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4347 build(builder, result, resultType, memref, map, indices);
4348}
4349
/// Registers the SimplifyAffineOp pattern, specialized for
/// affine.vector_load, with the canonicalization driver.
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
4354
4355ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4356 OperationState &result) {
4357 auto &builder = parser.getBuilder();
4358 auto indexTy = builder.getIndexType();
4359
4360 MemRefType memrefType;
4361 VectorType resultType;
4362 OpAsmParser::UnresolvedOperand memrefInfo;
4363 AffineMapAttr mapAttr;
4364 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4365 return failure(
4366 parser.parseOperand(memrefInfo) ||
4367 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4368 AffineVectorLoadOp::getMapAttrStrName(),
4369 result.attributes) ||
4370 parser.parseOptionalAttrDict(result.attributes) ||
4371 parser.parseColonType(memrefType) || parser.parseComma() ||
4372 parser.parseType(resultType) ||
4373 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4374 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4375 parser.addTypeToList(resultType, result.types));
4376}
4377
4378void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4379 p << " " << getMemRef() << '[';
4380 if (AffineMapAttr mapAttr =
4381 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4382 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4383 p << ']';
4384 p.printOptionalAttrDict((*this)->getAttrs(),
4385 /*elidedAttrs=*/{getMapAttrStrName()});
4386 p << " : " << getMemRefType() << ", " << getType();
4387}
4388
4389/// Verify common invariants of affine.vector_load and affine.vector_store.
4390static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4391 VectorType vectorType) {
4392 // Check that memref and vector element types match.
4393 if (memrefType.getElementType() != vectorType.getElementType())
4394 return op->emitOpError(
4395 message: "requires memref and vector types of the same elemental type");
4396 return success();
4397}
4398
4399LogicalResult AffineVectorLoadOp::verify() {
4400 MemRefType memrefType = getMemRefType();
4401 if (failed(verifyMemoryOpIndexing(
4402 getOperation(),
4403 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4404 getMapOperands(), memrefType,
4405 /*numIndexOperands=*/getNumOperands() - 1)))
4406 return failure();
4407
4408 if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4409 return failure();
4410
4411 return success();
4412}
4413
4414//===----------------------------------------------------------------------===//
4415// AffineVectorStoreOp
4416//===----------------------------------------------------------------------===//
4417
4418void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4419 Value valueToStore, Value memref, AffineMap map,
4420 ValueRange mapOperands) {
4421 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4422 result.addOperands(valueToStore);
4423 result.addOperands(memref);
4424 result.addOperands(mapOperands);
4425 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4426}
4427
4428// Use identity map.
4429void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4430 Value valueToStore, Value memref,
4431 ValueRange indices) {
4432 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4433 int64_t rank = memrefType.getRank();
4434 // Create identity map for memrefs with at least one dimension or () -> ()
4435 // for zero-dimensional memrefs.
4436 auto map =
4437 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4438 build(builder, result, valueToStore, memref, map, indices);
4439}
/// Registers the SimplifyAffineOp pattern, specialized for
/// affine.vector_store, with the canonicalization driver.
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
4444
4445ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4446 OperationState &result) {
4447 auto indexTy = parser.getBuilder().getIndexType();
4448
4449 MemRefType memrefType;
4450 VectorType resultType;
4451 OpAsmParser::UnresolvedOperand storeValueInfo;
4452 OpAsmParser::UnresolvedOperand memrefInfo;
4453 AffineMapAttr mapAttr;
4454 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4455 return failure(
4456 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4457 parser.parseOperand(memrefInfo) ||
4458 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4459 AffineVectorStoreOp::getMapAttrStrName(),
4460 result.attributes) ||
4461 parser.parseOptionalAttrDict(result.attributes) ||
4462 parser.parseColonType(memrefType) || parser.parseComma() ||
4463 parser.parseType(resultType) ||
4464 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4465 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4466 parser.resolveOperands(mapOperands, indexTy, result.operands));
4467}
4468
4469void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4470 p << " " << getValueToStore();
4471 p << ", " << getMemRef() << '[';
4472 if (AffineMapAttr mapAttr =
4473 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4474 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4475 p << ']';
4476 p.printOptionalAttrDict((*this)->getAttrs(),
4477 /*elidedAttrs=*/{getMapAttrStrName()});
4478 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4479}
4480
4481LogicalResult AffineVectorStoreOp::verify() {
4482 MemRefType memrefType = getMemRefType();
4483 if (failed(verifyMemoryOpIndexing(
4484 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4485 getMapOperands(), memrefType,
4486 /*numIndexOperands=*/getNumOperands() - 2)))
4487 return failure();
4488
4489 if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4490 return failure();
4491
4492 return success();
4493}
4494
4495//===----------------------------------------------------------------------===//
4496// DelinearizeIndexOp
4497//===----------------------------------------------------------------------===//
4498
4499LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4500 MLIRContext *context, std::optional<::mlir::Location> location,
4501 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4502 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
4503 AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4504 regions);
4505 inferredReturnTypes.assign(adaptor.getBasis().size(),
4506 IndexType::get(context));
4507 return success();
4508}
4509
4510void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
4511 Value linearIndex,
4512 ArrayRef<OpFoldResult> basis) {
4513 result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
4514 result.addOperands(linearIndex);
4515 SmallVector<Value> basisValues =
4516 llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
4517 std::optional<int64_t> staticDim = getConstantIntValue(ofr);
4518 if (staticDim.has_value())
4519 return builder.create<arith::ConstantIndexOp>(result.location,
4520 *staticDim);
4521 return llvm::dyn_cast_if_present<Value>(ofr);
4522 });
4523 result.addOperands(basisValues);
4524}
4525
4526LogicalResult AffineDelinearizeIndexOp::verify() {
4527 if (getBasis().empty())
4528 return emitOpError("basis should not be empty");
4529 if (getNumResults() != getBasis().size())
4530 return emitOpError("should return an index for each basis element");
4531 return success();
4532}
4533
4534//===----------------------------------------------------------------------===//
4535// TableGen'd op method definitions
4536//===----------------------------------------------------------------------===//
4537
4538#define GET_OP_CLASSES
4539#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
4540

source code of mlir/lib/Dialect/Affine/IR/AffineOps.cpp