1//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Dialect/UB/IR/UBOps.h"
13#include "mlir/Dialect/Utils/StaticValueUtils.h"
14#include "mlir/IR/AffineExpr.h"
15#include "mlir/IR/AffineExprVisitor.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/IntegerSet.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/IR/OpDefinition.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/IR/Value.h"
22#include "mlir/Interfaces/ShapedOpInterfaces.h"
23#include "mlir/Interfaces/ValueBoundsOpInterface.h"
24#include "mlir/Transforms/InliningUtils.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/MathExtras.h"
32#include <numeric>
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::affine;
37
38using llvm::divideCeilSigned;
39using llvm::divideFloorSigned;
40using llvm::mod;
41
42#define DEBUG_TYPE "affine-ops"
43#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
44
45#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
46
47/// A utility function to check if a value is defined at the top level of
48/// `region` or is an argument of `region`. A value of index type defined at the
49/// top level of a `AffineScope` region is always a valid symbol for all
50/// uses in that region.
51bool mlir::affine::isTopLevelValue(Value value, Region *region) {
52 if (auto arg = llvm::dyn_cast<BlockArgument>(Val&: value))
53 return arg.getParentRegion() == region;
54 return value.getDefiningOp()->getParentRegion() == region;
55}
56
57/// Checks if `value` known to be a legal affine dimension or symbol in `src`
58/// region remains legal if the operation that uses it is inlined into `dest`
59/// with the given value mapping. `legalityCheck` is either `isValidDim` or
60/// `isValidSymbol`, depending on the value being required to remain a valid
61/// dimension or symbol.
62static bool
63remainsLegalAfterInline(Value value, Region *src, Region *dest,
64 const IRMapping &mapping,
65 function_ref<bool(Value, Region *)> legalityCheck) {
66 // If the value is a valid dimension for any other reason than being
67 // a top-level value, it will remain valid: constants get inlined
68 // with the function, transitive affine applies also get inlined and
69 // will be checked themselves, etc.
70 if (!isTopLevelValue(value, region: src))
71 return true;
72
73 // If it's a top-level value because it's a block operand, i.e. a
74 // function argument, check whether the value replacing it after
75 // inlining is a valid dimension in the new region.
76 if (llvm::isa<BlockArgument>(Val: value))
77 return legalityCheck(mapping.lookup(from: value), dest);
78
79 // If it's a top-level value because it's defined in the region,
80 // it can only be inlined if the defining op is a constant or a
81 // `dim`, which can appear anywhere and be valid, since the defining
82 // op won't be top-level anymore after inlining.
83 Attribute operandCst;
84 bool isDimLikeOp = isa<ShapedDimOpInterface>(Val: value.getDefiningOp());
85 return matchPattern(op: value.getDefiningOp(), pattern: m_Constant(bind_value: &operandCst)) ||
86 isDimLikeOp;
87}
88
89/// Checks if all values known to be legal affine dimensions or symbols in `src`
90/// remain so if their respective users are inlined into `dest`.
91static bool
92remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
93 const IRMapping &mapping,
94 function_ref<bool(Value, Region *)> legalityCheck) {
95 return llvm::all_of(Range&: values, P: [&](Value v) {
96 return remainsLegalAfterInline(value: v, src, dest, mapping, legalityCheck);
97 });
98}
99
100/// Checks if an affine read or write operation remains legal after inlining
101/// from `src` to `dest`.
102template <typename OpTy>
103static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
104 const IRMapping &mapping) {
105 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
106 AffineWriteOpInterface>::value,
107 "only ops with affine read/write interface are supported");
108
109 AffineMap map = op.getAffineMap();
110 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
111 ValueRange symbolOperands =
112 op.getMapOperands().take_back(map.getNumSymbols());
113 if (!remainsLegalAfterInline(
114 values: dimOperands, src, dest, mapping,
115 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidDim)))
116 return false;
117 if (!remainsLegalAfterInline(
118 values: symbolOperands, src, dest, mapping,
119 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
120 return false;
121 return true;
122}
123
124/// Checks if an affine apply operation remains legal after inlining from `src`
125/// to `dest`.
126// Use "unused attribute" marker to silence clang-tidy warning stemming from
127// the inability to see through "llvm::TypeSwitch".
128template <>
129bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
130 Region *src, Region *dest,
131 const IRMapping &mapping) {
132 // If it's a valid dimension, we need to check that it remains so.
133 if (isValidDim(value: op.getResult(), region: src))
134 return remainsLegalAfterInline(
135 values: op.getMapOperands(), src, dest, mapping,
136 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidDim));
137
138 // Otherwise it must be a valid symbol, check that it remains so.
139 return remainsLegalAfterInline(
140 values: op.getMapOperands(), src, dest, mapping,
141 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidSymbol));
142}
143
144//===----------------------------------------------------------------------===//
145// AffineDialect Interfaces
146//===----------------------------------------------------------------------===//
147
namespace {
/// This class defines the interface for handling inlining with affine
/// operations.
struct AffineInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current dialect.
  /// 'wouldBeCloned' is set if the region is cloned into its new location
  /// rather than moved, indicating there may be other users.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    // We can inline into affine loops and conditionals if this doesn't break
    // affine value categorization rules.
    Operation *destOp = dest->getParentOp();
    if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(Val: destOp))
      return false;

    // Multi-block regions cannot be inlined into affine constructs, all of
    // which require single-block regions.
    if (!llvm::hasSingleElement(C&: *src))
      return false;

    // Side-effecting operations that the affine dialect cannot understand
    // should not be inlined.
    Block &srcBlock = src->front();
    for (Operation &op : srcBlock) {
      // Ops with no side effects are fine: they cannot invalidate affine
      // value categorization regardless of where they end up.
      if (auto iface = dyn_cast<MemoryEffectOpInterface>(Val&: op)) {
        if (iface.hasNoEffect())
          continue;
      }

      // Assuming the inlined region is valid, we only need to check if the
      // inlining would change it. Affine applies and affine reads/writes are
      // the only side-effecting ops we know how to reason about here.
      bool remainsValid =
          llvm::TypeSwitch<Operation *, bool>(&op)
              .Case<AffineApplyOp, AffineReadOpInterface,
                    AffineWriteOpInterface>(caseFn: [&](auto op) {
                return remainsLegalAfterInline(op, src, dest, valueMapping);
              })
              .Default(defaultFn: [](Operation *) {
                // Conservatively disallow inlining ops we cannot reason about.
                return false;
              });

      if (!remainsValid)
        return false;
    }

    return true;
  }

  /// Returns true if the given operation 'op', that is registered to this
  /// dialect, can be inlined into the given region, false otherwise.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    // Always allow inlining affine operations into a region that is marked as
    // affine scope, or into affine loops and conditionals. There are some edge
    // cases when inlining *into* affine structures, but that is handled in the
    // other 'isLegalToInline' hook above.
    Operation *parentOp = region->getParentOp();
    return parentOp->hasTrait<OpTrait::AffineScope>() ||
           isa<AffineForOp, AffineParallelOp, AffineIfOp>(Val: parentOp);
  }

  /// Affine regions should be analyzed recursively.
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
} // namespace
222
223//===----------------------------------------------------------------------===//
224// AffineDialect
225//===----------------------------------------------------------------------===//
226
void AffineDialect::initialize() {
  // Register the two hand-written DMA ops along with every ODS-generated
  // affine op pulled in via GET_OP_LIST.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  // Allow affine constructs to participate in inlining.
  addInterfaces<AffineInlinerInterface>();
  // Promise ValueBoundsOpInterface implementations that are registered
  // lazily elsewhere to avoid a hard dependency here.
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
                            AffineMinOp>();
}
236
237/// Materialize a single constant operation from a given attribute value with
238/// the desired resultant type.
239Operation *AffineDialect::materializeConstant(OpBuilder &builder,
240 Attribute value, Type type,
241 Location loc) {
242 if (auto poison = dyn_cast<ub::PoisonAttr>(Val&: value))
243 return builder.create<ub::PoisonOp>(location: loc, args&: type, args&: poison);
244 return arith::ConstantOp::materialize(builder, value, type, loc);
245}
246
247/// A utility function to check if a value is defined at the top level of an
248/// op with trait `AffineScope`. If the value is defined in an unlinked region,
249/// conservatively assume it is not top-level. A value of index type defined at
250/// the top level is always a valid symbol.
251bool mlir::affine::isTopLevelValue(Value value) {
252 if (auto arg = llvm::dyn_cast<BlockArgument>(Val&: value)) {
253 // The block owning the argument may be unlinked, e.g. when the surrounding
254 // region has not yet been attached to an Op, at which point the parent Op
255 // is null.
256 Operation *parentOp = arg.getOwner()->getParentOp();
257 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
258 }
259 // The defining Op may live in an unlinked block so its parent Op may be null.
260 Operation *parentOp = value.getDefiningOp()->getParentOp();
261 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
262}
263
264/// Returns the closest region enclosing `op` that is held by an operation with
265/// trait `AffineScope`; `nullptr` if there is no such region.
266Region *mlir::affine::getAffineScope(Operation *op) {
267 auto *curOp = op;
268 while (auto *parentOp = curOp->getParentOp()) {
269 if (parentOp->hasTrait<OpTrait::AffineScope>())
270 return curOp->getParentRegion();
271 curOp = parentOp;
272 }
273 return nullptr;
274}
275
276Region *mlir::affine::getAffineAnalysisScope(Operation *op) {
277 Operation *curOp = op;
278 while (auto *parentOp = curOp->getParentOp()) {
279 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(Val: parentOp))
280 return curOp->getParentRegion();
281 curOp = parentOp;
282 }
283 return nullptr;
284}
285
286// A Value can be used as a dimension id iff it meets one of the following
287// conditions:
288// *) It is valid as a symbol.
289// *) It is an induction variable.
290// *) It is the result of affine apply operation with dimension id arguments.
291bool mlir::affine::isValidDim(Value value) {
292 // The value must be an index type.
293 if (!value.getType().isIndex())
294 return false;
295
296 if (auto *defOp = value.getDefiningOp())
297 return isValidDim(value, region: getAffineScope(op: defOp));
298
299 // This value has to be a block argument for an op that has the
300 // `AffineScope` trait or an induction var of an affine.for or
301 // affine.parallel.
302 if (isAffineInductionVar(val: value))
303 return true;
304 auto *parentOp = llvm::cast<BlockArgument>(Val&: value).getOwner()->getParentOp();
305 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
306}
307
308// Value can be used as a dimension id iff it meets one of the following
309// conditions:
310// *) It is valid as a symbol.
311// *) It is an induction variable.
312// *) It is the result of an affine apply operation with dimension id operands.
313// *) It is the result of a more specialized index transformation (ex.
314// delinearize_index or linearize_index) with dimension id operands.
315bool mlir::affine::isValidDim(Value value, Region *region) {
316 // The value must be an index type.
317 if (!value.getType().isIndex())
318 return false;
319
320 // All valid symbols are okay.
321 if (isValidSymbol(value, region))
322 return true;
323
324 auto *op = value.getDefiningOp();
325 if (!op) {
326 // This value has to be an induction var for an affine.for or an
327 // affine.parallel.
328 return isAffineInductionVar(val: value);
329 }
330
331 // Affine apply operation is ok if all of its operands are ok.
332 if (auto applyOp = dyn_cast<AffineApplyOp>(Val: op))
333 return applyOp.isValidDim(region);
334 // delinearize_index and linearize_index are special forms of apply
335 // and so are valid dimensions if all their arguments are valid dimensions.
336 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(Val: op))
337 return llvm::all_of(Range: op->getOperands(),
338 P: [&](Value arg) { return ::isValidDim(value: arg, region); });
339 // The dim op is okay if its operand memref/tensor is defined at the top
340 // level.
341 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(Val: op))
342 return isTopLevelValue(value: dimOp.getShapedValue());
343 return false;
344}
345
346/// Returns true if the 'index' dimension of the `memref` defined by
347/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
348/// for `region`.
template <typename AnyMemRefDefOp>
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
                                    Region *region) {
  MemRefType memRefType = memrefDefOp.getType();

  // Dimension index is out of bounds.
  if (index >= memRefType.getRank()) {
    return false;
  }

  // Statically shaped dimensions are always valid symbols.
  if (!memRefType.isDynamicDim(idx: index))
    return true;
  // Get the position of the dimension among dynamic dimensions.
  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
  // The dynamic size operand feeding this dimension must itself be a valid
  // symbol for `region`.
  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
                       region);
}
367
368/// Returns true if the result of the dim op is a valid symbol for `region`.
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
  // The dim op is okay if its source is defined at the top level.
  if (isTopLevelValue(value: dimOp.getShapedValue()))
    return true;

  // Conservatively handle remaining BlockArguments as non-valid symbols.
  // E.g. scf.for iterArgs.
  if (llvm::isa<BlockArgument>(Val: dimOp.getShapedValue()))
    return false;

  // The dim op is also okay if its operand memref is a view/subview whose
  // corresponding size is a valid symbol.
  std::optional<int64_t> index = getConstantIntValue(ofr: dimOp.getDimension());

  // Be conservative if we can't understand the dimension.
  if (!index.has_value())
    return false;

  // Skip over all memref.cast ops (if any), chasing the cast chain back to
  // the op that actually produces the memref.
  Operation *op = dimOp.getShapedValue().getDefiningOp();
  while (auto castOp = dyn_cast<memref::CastOp>(Val: op)) {
    // Bail on unranked memrefs.
    if (isa<UnrankedMemRefType>(Val: castOp.getSource().getType()))
      return false;
    // The cast source may be a block argument (no defining op): give up.
    op = castOp.getSource().getDefiningOp();
    if (!op)
      return false;
  }

  // Only view/subview/alloc expose dynamic sizes we can reason about; any
  // other producer is conservatively rejected.
  int64_t i = index.value();
  return TypeSwitch<Operation *, bool>(op)
      .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
          caseFn: [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
      .Default(defaultFn: [](Operation *) { return false; });
}
404
405// A value can be used as a symbol (at all its use sites) iff it meets one of
406// the following conditions:
407// *) It is a constant.
408// *) Its defining op or block arg appearance is immediately enclosed by an op
409// with `AffineScope` trait.
410// *) It is the result of an affine.apply operation with symbol operands.
411// *) It is a result of the dim op on a memref whose corresponding size is a
412// valid symbol.
413bool mlir::affine::isValidSymbol(Value value) {
414 if (!value)
415 return false;
416
417 // The value must be an index type.
418 if (!value.getType().isIndex())
419 return false;
420
421 // Check that the value is a top level value.
422 if (isTopLevelValue(value))
423 return true;
424
425 if (auto *defOp = value.getDefiningOp())
426 return isValidSymbol(value, region: getAffineScope(op: defOp));
427
428 return false;
429}
430
431/// A value can be used as a symbol for `region` iff it meets one of the
432/// following conditions:
433/// *) It is a constant.
434/// *) It is a result of a `Pure` operation whose operands are valid symbolic
435/// *) identifiers.
436/// *) It is a result of the dim op on a memref whose corresponding size is
437/// a valid symbol.
438/// *) It is defined at the top level of 'region' or is its argument.
439/// *) It dominates `region`'s parent op.
440/// If `region` is null, conservatively assume the symbol definition scope does
441/// not exist and only accept the values that would be symbols regardless of
442/// the surrounding region structure, i.e. the first three cases above.
bool mlir::affine::isValidSymbol(Value value, Region *region) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  // A top-level value is a valid symbol.
  if (region && ::isTopLevelValue(value, region))
    return true;

  auto *defOp = value.getDefiningOp();
  if (!defOp) {
    // A block argument that is not a top-level value is a valid symbol if it
    // dominates region's parent op. Stop the upward recursion at ops that
    // are isolated from above, since values defined outside such ops cannot
    // be referenced inside them.
    Operation *regionOp = region ? region->getParentOp() : nullptr;
    if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
      if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
        return isValidSymbol(value, region: parentOpRegion);
    return false;
  }

  // Constant operation is ok.
  Attribute operandCst;
  if (matchPattern(op: defOp, pattern: m_Constant(bind_value: &operandCst)))
    return true;

  // A `Pure` operation whose operands are all valid symbolic identifiers is
  // itself a valid symbol.
  if (isPure(op: defOp) && llvm::all_of(Range: defOp->getOperands(), P: [&](Value operand) {
        return affine::isValidSymbol(value: operand, region);
      })) {
    return true;
  }

  // Dim op results could be valid symbols at any level.
  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(Val: defOp))
    return isDimOpValidSymbol(dimOp, region);

  // Check for values dominating `region`'s parent op. As above, only recurse
  // outward when the parent op is not isolated from above.
  Operation *regionOp = region ? region->getParentOp() : nullptr;
  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
    if (auto *parentRegion = region->getParentOp()->getParentRegion())
      return isValidSymbol(value, region: parentRegion);

  return false;
}
487
488// Returns true if 'value' is a valid index to an affine operation (e.g.
489// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
490// `region` provides the polyhedral symbol scope. Returns false otherwise.
491static bool isValidAffineIndexOperand(Value value, Region *region) {
492 return isValidDim(value, region) || isValidSymbol(value, region);
493}
494
495/// Prints dimension and symbol list.
496static void printDimAndSymbolList(Operation::operand_iterator begin,
497 Operation::operand_iterator end,
498 unsigned numDims, OpAsmPrinter &printer) {
499 OperandRange operands(begin, end);
500 printer << '(' << operands.take_front(n: numDims) << ')';
501 if (operands.size() > numDims)
502 printer << '[' << operands.drop_front(n: numDims) << ']';
503}
504
505/// Parses dimension and symbol list and returns true if parsing failed.
506ParseResult mlir::affine::parseDimAndSymbolList(
507 OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
508 SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
509 if (parser.parseOperandList(result&: opInfos, delimiter: OpAsmParser::Delimiter::Paren))
510 return failure();
511 // Store number of dimensions for validation by caller.
512 numDims = opInfos.size();
513
514 // Parse the optional symbol operands.
515 auto indexTy = parser.getBuilder().getIndexType();
516 return failure(IsFailure: parser.parseOperandList(
517 result&: opInfos, delimiter: OpAsmParser::Delimiter::OptionalSquare) ||
518 parser.resolveOperands(operands&: opInfos, type: indexTy, result&: operands));
519}
520
521/// Utility function to verify that a set of operands are valid dimension and
522/// symbol identifiers. The operands should be laid out such that the dimension
523/// operands are before the symbol operands. This function returns failure if
524/// there was an invalid operand. An operation is provided to emit any necessary
525/// errors.
526template <typename OpTy>
527static LogicalResult
528verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
529 unsigned numDims) {
530 unsigned opIt = 0;
531 for (auto operand : operands) {
532 if (opIt++ < numDims) {
533 if (!isValidDim(operand, getAffineScope(op)))
534 return op.emitOpError("operand cannot be used as a dimension id");
535 } else if (!isValidSymbol(operand, getAffineScope(op))) {
536 return op.emitOpError("operand cannot be used as a symbol");
537 }
538 }
539 return success();
540}
541
542//===----------------------------------------------------------------------===//
543// AffineApplyOp
544//===----------------------------------------------------------------------===//
545
546AffineValueMap AffineApplyOp::getAffineValueMap() {
547 return AffineValueMap(getAffineMap(), getOperands(), getResult());
548}
549
ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  // Parse the affine map attribute, the parenthesized dims plus optional
  // bracketed symbols, and any trailing attribute dictionary.
  AffineMapAttr mapAttr;
  unsigned numDims;
  if (parser.parseAttribute(result&: mapAttr, attrName: "map", attrs&: result.attributes) ||
      parseDimAndSymbolList(parser, operands&: result.operands, numDims) ||
      parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();
  auto map = mapAttr.getValue();

  // The operand list must supply exactly one value per map dim and symbol.
  if (map.getNumDims() != numDims ||
      numDims + map.getNumSymbols() != result.operands.size()) {
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "dimension or symbol index mismatch");
  }

  // One index-typed result per map result.
  result.types.append(NumInputs: map.getNumResults(), Elt: indexTy);
  return success();
}
571
572void AffineApplyOp::print(OpAsmPrinter &p) {
573 p << " " << getMapAttr();
574 printDimAndSymbolList(begin: operand_begin(), end: operand_end(),
575 numDims: getAffineMap().getNumDims(), printer&: p);
576 p.printOptionalAttrDict(attrs: (*this)->getAttrs(), /*elidedAttrs=*/{"map"});
577}
578
LogicalResult AffineApplyOp::verify() {
  // Check input and output dimensions match.
  AffineMap affineMap = getMap();

  // Verify that operand count matches affine map dimension and symbol count.
  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
    return emitOpError(
        message: "operand count and affine map dimension and symbol count must match");

  // Verify that the map only produces one result.
  if (affineMap.getNumResults() != 1)
    return emitOpError(message: "mapping must produce one value");

  // Do not allow valid dims to be used in symbol positions. We do allow
  // affine.apply to use operands for values that may neither qualify as affine
  // dims or affine symbols due to usage outside of affine ops, analyses, etc.
  // Only the operands past the map's dims occupy symbol positions.
  Region *region = getAffineScope(op: *this);
  for (Value operand : getMapOperands().drop_front(n: affineMap.getNumDims())) {
    if (::isValidDim(value: operand, region) && !::isValidSymbol(value: operand, region))
      return emitError(message: "dimensional operand cannot be used as a symbol");
  }

  return success();
}
603
604// The result of the affine apply operation can be used as a dimension id if all
605// its operands are valid dimension ids.
606bool AffineApplyOp::isValidDim() {
607 return llvm::all_of(Range: getOperands(),
608 P: [](Value op) { return affine::isValidDim(value: op); });
609}
610
611// The result of the affine apply operation can be used as a dimension id if all
612// its operands are valid dimension ids with the parent operation of `region`
613// defining the polyhedral scope for symbols.
614bool AffineApplyOp::isValidDim(Region *region) {
615 return llvm::all_of(Range: getOperands(),
616 P: [&](Value op) { return ::isValidDim(value: op, region); });
617}
618
619// The result of the affine apply operation can be used as a symbol if all its
620// operands are symbols.
621bool AffineApplyOp::isValidSymbol() {
622 return llvm::all_of(Range: getOperands(),
623 P: [](Value op) { return affine::isValidSymbol(value: op); });
624}
625
626// The result of the affine apply operation can be used as a symbol in `region`
627// if all its operands are symbols in `region`.
628bool AffineApplyOp::isValidSymbol(Region *region) {
629 return llvm::all_of(Range: getOperands(), P: [&](Value operand) {
630 return affine::isValidSymbol(value: operand, region);
631 });
632}
633
OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
  auto map = getAffineMap();

  // Fold dims and symbols to existing values: a map that is a bare dim or
  // symbol expression simply forwards the corresponding operand.
  auto expr = map.getResult(idx: 0);
  if (auto dim = dyn_cast<AffineDimExpr>(Val&: expr))
    return getOperand(i: dim.getPosition());
  if (auto sym = dyn_cast<AffineSymbolExpr>(Val&: expr))
    return getOperand(i: map.getNumDims() + sym.getPosition());

  // Otherwise, default to folding the map with whatever constant operands
  // the adaptor provides. Poison in any operand poisons the result.
  SmallVector<Attribute, 1> result;
  bool hasPoison = false;
  auto foldResult =
      map.constantFold(operandConstants: adaptor.getMapOperands(), results&: result, hasPoison: &hasPoison);
  if (hasPoison)
    return ub::PoisonAttr::get(ctx: getContext());
  if (failed(Result: foldResult))
    return {};
  return result[0];
}
655
656/// Returns the largest known divisor of `e`. Exploits information from the
657/// values in `operands`.
/// Returns the largest known divisor of `e`. Exploits information from the
/// values in `operands`.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
  // This method isn't aware of `operands`.
  int64_t div = e.getLargestKnownDivisor();

  // We now make use of operands for the case `e` is a dim expression.
  // TODO: More powerful simplification would have to modify
  // getLargestKnownDivisor to take `operands` and exploit that information as
  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
  // can't be part of the IR library but of the `Analysis` library. The IR
  // library can only really depend on simple O(1) checks.
  auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e);
  // If it's not a dim expr, `div` is the best we have.
  if (!dimExpr)
    return div;

  // We simply exploit information from loop IVs.
  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
  // desired simplifications are expected to be part of other
  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
  // LoopAnalysis library.
  Value operand = operands[dimExpr.getPosition()];
  int64_t operandDivisor = 1;
  // TODO: With the right accessors, this can be extended to
  // LoopLikeOpInterface.
  if (AffineForOp forOp = getForInductionVarOwner(val: operand)) {
    if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
      // IV starting at zero takes values that are multiples of the step.
      operandDivisor = forOp.getStepAsInt();
    } else {
      // Otherwise the IV is lb + k*step, so gcd(divisor(lb), step) divides it.
      uint64_t lbLargestKnownDivisor =
          forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
      operandDivisor = std::gcd(m: lbLargestKnownDivisor, n: forOp.getStepAsInt());
    }
  }
  return operandDivisor;
}
693
694/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
695/// being an affine dim expression or a constant.
/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
/// being an affine dim expression or a constant.
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
                                   int64_t k) {
  // Constants can be checked directly.
  if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) {
    int64_t constVal = constExpr.getValue();
    return constVal >= 0 && constVal < k;
  }
  auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e);
  if (!dimExpr)
    return false;
  Value operand = operands[dimExpr.getPosition()];
  // TODO: With the right accessors, this can be extended to
  // LoopLikeOpInterface.
  if (AffineForOp forOp = getForInductionVarOwner(val: operand)) {
    // The IV ranges over [lb, ub) with constant bounds, so lb >= 0 and
    // ub <= k imply 0 <= iv < k.
    if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
        forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
      return true;
    }
  }

  // We don't consider other cases like `operand` being defined by a constant or
  // an affine.apply op since such cases will already be handled by other
  // patterns and propagation of loop IVs or constant would happen.
  return false;
}
720
721/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
722/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
723/// expression is in that form.
/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
/// expression is in that form.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
                           AffineExpr &quotientTimesDiv, AffineExpr &rem) {
  // The decomposition only applies to additions.
  auto bin = dyn_cast<AffineBinaryOpExpr>(Val&: e);
  if (!bin || bin.getKind() != AffineExprKind::Add)
    return false;

  // Try both orderings: either summand may play the role of d*e_1, with the
  // other acting as the remainder e_2 bounded by the divisor.
  AffineExpr llhs = bin.getLHS();
  AffineExpr rlhs = bin.getRHS();
  div = getLargestKnownDivisor(e: llhs, operands);
  if (isNonNegativeBoundedBy(e: rlhs, operands, k: div)) {
    quotientTimesDiv = llhs;
    rem = rlhs;
    return true;
  }
  div = getLargestKnownDivisor(e: rlhs, operands);
  if (isNonNegativeBoundedBy(e: llhs, operands, k: div)) {
    quotientTimesDiv = rlhs;
    rem = llhs;
    return true;
  }
  return false;
}
746
747/// Gets the constant lower bound on an `iv`.
748static std::optional<int64_t> getLowerBound(Value iv) {
749 AffineForOp forOp = getForInductionVarOwner(val: iv);
750 if (forOp && forOp.hasConstantLowerBound())
751 return forOp.getConstantLowerBound();
752 return std::nullopt;
753}
754
755/// Gets the constant upper bound on an affine.for `iv`.
/// Gets the constant upper bound on an affine.for `iv`.
static std::optional<int64_t> getUpperBound(Value iv) {
  AffineForOp forOp = getForInductionVarOwner(val: iv);
  if (!forOp || !forOp.hasConstantUpperBound())
    return std::nullopt;

  // If its lower bound is also known, we can get a more precise bound
  // whenever the step is not one: the last value actually taken by the IV is
  // lb + floor((ub - 1 - lb) / step) * step, computed below without division
  // by subtracting the residue from ub - 1.
  if (forOp.hasConstantLowerBound()) {
    return forOp.getConstantUpperBound() - 1 -
           (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
               forOp.getStepAsInt();
  }
  // Without a known lower bound, ub - 1 is the best inclusive bound.
  return forOp.getConstantUpperBound() - 1;
}
770
771/// Determine a constant upper bound for `expr` if one exists while exploiting
772/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
773/// is guaranteed to be less than or equal to it.
774static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
775 unsigned numSymbols,
776 ArrayRef<Value> operands) {
777 // Get the constant lower or upper bounds on the operands.
778 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
779 constLowerBounds.reserve(N: operands.size());
780 constUpperBounds.reserve(N: operands.size());
781 for (Value operand : operands) {
782 constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
783 constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
784 }
785
786 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr))
787 return constExpr.getValue();
788
789 return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
790 constUpperBounds,
791 /*isUpper=*/true);
792}
793
794/// Determine a constant lower bound for `expr` if one exists while exploiting
795/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
796/// is guaranteed to be less than or equal to it.
797static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
798 unsigned numSymbols,
799 ArrayRef<Value> operands) {
800 // Get the constant lower or upper bounds on the operands.
801 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
802 constLowerBounds.reserve(N: operands.size());
803 constUpperBounds.reserve(N: operands.size());
804 for (Value operand : operands) {
805 constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
806 constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
807 }
808
809 std::optional<int64_t> lowerBound;
810 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) {
811 lowerBound = constExpr.getValue();
812 } else {
813 lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
814 constLowerBounds, constUpperBounds,
815 /*isUpper=*/false);
816 }
817 return lowerBound;
818}
819
/// Simplify `expr` while exploiting information from the values in `operands`.
/// Recursively simplifies subexpressions first, then applies bound-based
/// rewrites to floordiv/ceildiv/mod expressions whose RHS is a positive
/// constant. All other expressions are left untouched.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(expr&: lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(expr&: rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(kind: binExpr.getKind(), lhs, rhs);

  // Rebuilding the expression above may have changed its kind (e.g. folded to
  // a constant), so re-inspect before simplifying further.
  binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(expr: lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(expr: lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value if lhs is bounded in a range `c` that
    // has the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        divideFloorSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal) ==
            divideFloorSigned(Numerator: lhsUbConstVal, Denominator: rhsConstVal)) {
      expr = getAffineConstantExpr(
          constant: divideFloorSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        divideCeilSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal) ==
            divideCeilSigned(Numerator: lhsUbConstVal, Denominator: rhsConstVal)) {
      expr = getAffineConstantExpr(constant: divideCeilSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal),
                                   context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(e: lhs, operands, div&: divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(other: rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs % c = 0.
  if ((isNonNegativeBoundedBy(e: lhs, operands, k: rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(e: lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(constant: 0, context: expr.getContext());
  }
}
918
/// Simplify the expressions in `map` while making use of lower or upper bounds
/// of its operands. If `isMax` is true, the map is to be treated as a max of
/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
/// d1) can be simplified to (8) if the operands are respectively lower bounded
/// by 2 and 0 (the second expression can't be lower than 8).
static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
                                             ArrayRef<Value> operands,
                                             bool isMax) {
  // Can't simplify.
  if (operands.empty())
    return;

  // Get the upper or lower bound on an affine.for op IV using its range.
  // Get the constant lower or upper bounds on the operands.
  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
  constLowerBounds.reserve(N: operands.size());
  constUpperBounds.reserve(N: operands.size());
  for (Value operand : operands) {
    constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
    constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
  }

  // We will compute the lower and upper bounds on each of the expressions
  // Then, we will check (depending on max or min) as to whether a specific
  // bound is redundant by checking if its highest (in case of max) and its
  // lowest (in the case of min) value is already lower than (or higher than)
  // the lower bound (or upper bound in the case of min) of another bound.
  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
  lowerBounds.reserve(N: map.getNumResults());
  upperBounds.reserve(N: map.getNumResults());
  for (AffineExpr e : map.getResults()) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) {
      lowerBounds.push_back(Elt: constExpr.getValue());
      upperBounds.push_back(Elt: constExpr.getValue());
    } else {
      // A nullopt entry means no constant bound could be determined for the
      // expression; such expressions are always kept below.
      lowerBounds.push_back(
          Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
                                 constLowerBounds, constUpperBounds,
                                 /*isUpper=*/false));
      upperBounds.push_back(
          Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
                                 constLowerBounds, constUpperBounds,
                                 /*isUpper=*/true));
    }
  }

  // Collect expressions that are not redundant.
  SmallVector<AffineExpr, 4> irredundantExprs;
  for (auto exprEn : llvm::enumerate(First: map.getResults())) {
    AffineExpr e = exprEn.value();
    unsigned i = exprEn.index();
    // Some expressions can be turned into constants.
    if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
      e = getAffineConstantExpr(constant: *lowerBounds[i], context: e.getContext());

    // Check if the expression is redundant.
    if (isMax) {
      // An expression with no known upper bound can never be proven redundant.
      if (!upperBounds[i]) {
        irredundantExprs.push_back(Elt: e);
        continue;
      }
      // If there exists another expression such that its lower bound is greater
      // than this expression's upper bound, it's redundant.
      if (!llvm::any_of(Range: llvm::enumerate(First&: lowerBounds), P: [&](const auto &en) {
            auto otherLowerBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherLowerBound)
              return false;
            if (*otherLowerBound > *upperBounds[i])
              return true;
            if (*otherLowerBound < *upperBounds[i])
              return false;
            // Equality case. When both expressions are considered redundant, we
            // don't want to get both of them. We keep the one that appears
            // first.
            if (upperBounds[pos] && lowerBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherLowerBound == *upperBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(Elt: e);
    } else {
      // An expression with no known lower bound can never be proven redundant.
      if (!lowerBounds[i]) {
        irredundantExprs.push_back(Elt: e);
        continue;
      }
      // Likewise for the `min` case. Use the complement of the condition above.
      if (!llvm::any_of(Range: llvm::enumerate(First&: upperBounds), P: [&](const auto &en) {
            auto otherUpperBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherUpperBound)
              return false;
            if (*otherUpperBound < *lowerBounds[i])
              return true;
            if (*otherUpperBound > *lowerBounds[i])
              return false;
            // Equality tie-break mirrors the max case: keep the expression
            // that appears first.
            if (lowerBounds[pos] && upperBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherUpperBound == lowerBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(Elt: e);
    }
  }

  // Create the map without the redundant expressions.
  map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: irredundantExprs,
                       context: map.getContext());
}
1030
1031/// Simplify the map while exploiting information on the values in `operands`.
1032// Use "unused attribute" marker to silence warning stemming from the inability
1033// to see through the template expansion.
1034static void LLVM_ATTRIBUTE_UNUSED
1035simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
1036 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1037 SmallVector<AffineExpr> newResults;
1038 newResults.reserve(N: map.getNumResults());
1039 for (AffineExpr expr : map.getResults()) {
1040 simplifyExprAndOperands(expr, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
1041 operands);
1042 newResults.push_back(Elt: expr);
1043 }
1044 map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newResults,
1045 context: map.getContext());
1046}
1047
/// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
/// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
/// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
/// ```
/// dimOrSym.ceildiv(x_k)
/// (dimOrSym + x_k - 1).floordiv(x_k)
/// ```
/// by `1` for all `k` in `1, ..., n`. This is possible because `x / x_k <= 1`.
///
///
/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
/// `minOp` is positive.
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
                                                           AffineExpr dimOrSym,
                                                           AffineMap *map,
                                                           ValueRange dims,
                                                           ValueRange syms) {
  AffineMap affineMinMap = minOp.getAffineMap();

  LLVM_DEBUG({
    DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
  });

  // Check the value is positive.
  for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
    // Compare each expression in the minimum against 0.
    if (!ValueBoundsConstraintSet::compare(
            lhs: getAsIndexOpFoldResult(ctx: minOp.getContext(), val: 0),
            cmp: ValueBoundsConstraintSet::ComparisonOperator::LT,
            rhs: ValueBoundsConstraintSet::Variable(affineMinMap.getSliceMap(start: i, length: 1),
                                                minOp.getOperands())))
      return failure();
  }

  /// Convert affine symbols and dimensions in minOp to symbols or dimensions in
  /// the apply op affine map. Operands of minOp that do not appear in
  /// `dims`/`syms` are recorded as unmapped.
  DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
  SmallVector<unsigned> unmappedDims, unmappedSyms;
  for (auto [i, dim] : llvm::enumerate(First: minOp.getDimOperands())) {
    auto it = llvm::find(Range&: dims, Val: dim);
    if (it == dims.end()) {
      unmappedDims.push_back(Elt: i);
      continue;
    }
    dimSymConversionTable[getAffineDimExpr(position: i, context: minOp.getContext())] =
        getAffineDimExpr(position: it.getIndex(), context: minOp.getContext());
  }
  for (auto [i, sym] : llvm::enumerate(First: minOp.getSymbolOperands())) {
    auto it = llvm::find(Range&: syms, Val: sym);
    if (it == syms.end()) {
      unmappedSyms.push_back(Elt: i);
      continue;
    }
    dimSymConversionTable[getAffineSymbolExpr(position: i, context: minOp.getContext())] =
        getAffineSymbolExpr(position: it.getIndex(), context: minOp.getContext());
  }

  // Create the replacement map.
  DenseMap<AffineExpr, AffineExpr> repl;
  AffineExpr c1 = getAffineConstantExpr(constant: 1, context: minOp.getContext());
  for (AffineExpr expr : affineMinMap.getResults()) {
    // If we cannot express the result in terms of the apply map symbols and
    // syms then continue.
    if (llvm::any_of(Range&: unmappedDims,
                     P: [&](unsigned i) { return expr.isFunctionOfDim(position: i); }) ||
        llvm::any_of(Range&: unmappedSyms,
                     P: [&](unsigned i) { return expr.isFunctionOfSymbol(position: i); }))
      continue;

    AffineExpr convertedExpr = expr.replace(map: dimSymConversionTable);

    // dimOrSym.ceilDiv(expr) -> 1
    repl[dimOrSym.ceilDiv(other: convertedExpr)] = c1;
    // (dimOrSym + expr - 1).floorDiv(expr) -> 1
    repl[(dimOrSym + convertedExpr - 1).floorDiv(other: convertedExpr)] = c1;
  }
  // Success only if the replacement actually changed the map.
  AffineMap initialMap = *map;
  *map = initialMap.replace(map: repl, numResultDims: initialMap.getNumDims(),
                            numResultSyms: initialMap.getNumSymbols());
  return success(IsSuccess: *map != initialMap);
}
1129
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// If `replaceAffineMin` is set to true, additionally triggers more expensive
/// replacements involving affine_min operations.
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms,
                                     bool replaceAffineMin) {
  MLIRContext *ctx = map->getContext();
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  // A null entry means this position was already replaced in a prior call.
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (!v)
    return failure();

  if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
    AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(position: pos, context: ctx)
                                           : getAffineSymbolExpr(position: pos, context: ctx);
    return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
                                                 syms);
  }

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
                                     affineApply.getMapOperands().end());
  // Canonicalize the map to promote dims to symbols when possible. This is to
  // avoid generating invalid maps.
  canonicalizeMapAndOperands(map: &composeMap, operands: &composeOperands);
  // Shift the producer's dims/symbols past the current ones so appending its
  // operands below cannot collide with existing positions.
  AffineExpr replacementExpr =
      composeMap.shiftDims(shift: dims.size()).shiftSymbols(shift: syms.size()).getResult(idx: 0);
  ValueRange composeDims =
      ArrayRef<Value>(composeOperands).take_front(N: composeMap.getNumDims());
  ValueRange composeSyms =
      ArrayRef<Value>(composeOperands).take_back(N: composeMap.getNumSymbols());
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(position: pos, context: ctx)
                                          : getAffineSymbolExpr(position: pos, context: ctx);

  // Append the dims and symbols where relevant and perform the replacement.
  dims.append(in_start: composeDims.begin(), in_end: composeDims.end());
  syms.append(in_start: composeSyms.begin(), in_end: composeSyms.end());
  *map = map->replace(expr: toReplace, replacement: replacementExpr, numResultDims: dims.size(), numResultSyms: syms.size());

  return success();
}
1194
1195/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1196/// iteratively. Perform canonicalization of map and operands as well as
1197/// AffineMap simplification. `map` and `operands` are mutated in place.
1198static void composeAffineMapAndOperands(AffineMap *map,
1199 SmallVectorImpl<Value> *operands,
1200 bool composeAffineMin = false) {
1201 if (map->getNumResults() == 0) {
1202 canonicalizeMapAndOperands(map, operands);
1203 *map = simplifyAffineMap(map: *map);
1204 return;
1205 }
1206
1207 MLIRContext *ctx = map->getContext();
1208 SmallVector<Value, 4> dims(operands->begin(),
1209 operands->begin() + map->getNumDims());
1210 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1211 operands->end());
1212
1213 // Iterate over dims and symbols coming from AffineApplyOp and replace until
1214 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1215 // and `syms` can only increase by construction.
1216 // The implementation uses a `while` loop to support the case of symbols
1217 // that may be constructed from dims ;this may be overkill.
1218 while (true) {
1219 bool changed = false;
1220 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1221 if ((changed |=
1222 succeeded(Result: replaceDimOrSym(map, dimOrSymbolPosition: pos, dims, syms, replaceAffineMin: composeAffineMin))))
1223 break;
1224 if (!changed)
1225 break;
1226 }
1227
1228 // Clear operands so we can fill them anew.
1229 operands->clear();
1230
1231 // At this point we may have introduced null operands, prune them out before
1232 // canonicalizing map and operands.
1233 unsigned nDims = 0, nSyms = 0;
1234 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1235 dimReplacements.reserve(N: dims.size());
1236 symReplacements.reserve(N: syms.size());
1237 for (auto *container : {&dims, &syms}) {
1238 bool isDim = (container == &dims);
1239 auto &repls = isDim ? dimReplacements : symReplacements;
1240 for (const auto &en : llvm::enumerate(First&: *container)) {
1241 Value v = en.value();
1242 if (!v) {
1243 assert(isDim ? !map->isFunctionOfDim(en.index())
1244 : !map->isFunctionOfSymbol(en.index()) &&
1245 "map is function of unexpected expr@pos");
1246 repls.push_back(Elt: getAffineConstantExpr(constant: 0, context: ctx));
1247 continue;
1248 }
1249 repls.push_back(Elt: isDim ? getAffineDimExpr(position: nDims++, context: ctx)
1250 : getAffineSymbolExpr(position: nSyms++, context: ctx));
1251 operands->push_back(Elt: v);
1252 }
1253 }
1254 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, numResultDims: nDims,
1255 numResultSyms: nSyms);
1256
1257 // Canonicalize and simplify before returning.
1258 canonicalizeMapAndOperands(map, operands);
1259 *map = simplifyAffineMap(map: *map);
1260}
1261
1262void mlir::affine::fullyComposeAffineMapAndOperands(
1263 AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
1264 while (llvm::any_of(Range&: *operands, P: [](Value v) {
1265 return isa_and_nonnull<AffineApplyOp>(Val: v.getDefiningOp());
1266 })) {
1267 composeAffineMapAndOperands(map, operands, composeAffineMin);
1268 }
1269 // Additional trailing step for AffineMinOps in case no chains of AffineApply.
1270 if (composeAffineMin && llvm::any_of(Range&: *operands, P: [](Value v) {
1271 return isa_and_nonnull<AffineMinOp>(Val: v.getDefiningOp());
1272 })) {
1273 composeAffineMapAndOperands(map, operands, composeAffineMin);
1274 }
1275}
1276
1277AffineApplyOp
1278mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1279 ArrayRef<OpFoldResult> operands,
1280 bool composeAffineMin) {
1281 SmallVector<Value> valueOperands;
1282 map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands);
1283 composeAffineMapAndOperands(map: &map, operands: &valueOperands, composeAffineMin);
1284 assert(map);
1285 return b.create<AffineApplyOp>(location: loc, args&: map, args&: valueOperands);
1286}
1287
1288AffineApplyOp
1289mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1290 ArrayRef<OpFoldResult> operands,
1291 bool composeAffineMin) {
1292 return makeComposedAffineApply(
1293 b, loc,
1294 map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{e}, context: b.getContext())
1295 .front(),
1296 operands, composeAffineMin);
1297}
1298
/// Composes the given affine map with the given list of operands, pulling in
/// the maps from any affine.apply operations that supply the operands.
/// Works on multi-result maps by composing each single-result submap
/// independently and re-merging the results.
static void composeMultiResultAffineMap(AffineMap &map,
                                        SmallVectorImpl<Value> &operands,
                                        bool composeAffineMin = false) {
  // Compose and canonicalize each expression in the map individually because
  // composition only applies to single-result maps, collecting potentially
  // duplicate operands in a single list with shifted dimensions and symbols.
  SmallVector<Value> dims, symbols;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(Begin: 0, End: map.getNumResults())) {
    // Each submap composes against a fresh copy of the original operands.
    SmallVector<Value> submapOperands(operands.begin(), operands.end());
    AffineMap submap = map.getSubMap(resultPos: {i});
    fullyComposeAffineMapAndOperands(map: &submap, operands: &submapOperands,
                                     composeAffineMin);
    canonicalizeMapAndOperands(map: &submap, operands: &submapOperands);
    // Shift this submap's dims/symbols past those accumulated so far, then
    // append its operands to the combined lists.
    unsigned numNewDims = submap.getNumDims();
    submap = submap.shiftDims(shift: dims.size()).shiftSymbols(shift: symbols.size());
    llvm::append_range(C&: dims,
                       R: ArrayRef<Value>(submapOperands).take_front(N: numNewDims));
    llvm::append_range(C&: symbols,
                       R: ArrayRef<Value>(submapOperands).drop_front(N: numNewDims));
    exprs.push_back(Elt: submap.getResult(idx: 0));
  }

  // Canonicalize the map created from composed expressions to deduplicate the
  // dimension and symbol operands.
  operands = llvm::to_vector(Range: llvm::concat<Value>(Ranges&: dims, Ranges&: symbols));
  map = AffineMap::get(dimCount: dims.size(), symbolCount: symbols.size(), results: exprs, context: map.getContext());
  canonicalizeMapAndOperands(map: &map, operands: &operands);
}
1330
/// Builds a composed affine.apply and immediately tries to constant-fold it.
/// Returns an Attribute if folding succeeded, otherwise the materialized op's
/// result. The op is only reported to `b`'s listener when it survives folding.
OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
    bool composeAffineMin) {
  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());

  // Create op.
  AffineApplyOp applyOp =
      makeComposedAffineApply(b&: newBuilder, loc, map, operands, composeAffineMin);

  // Get constant operands.
  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(value: applyOp->getOperand(idx: i), pattern: m_Constant(bind_value: &constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(Result: applyOp->fold(operands: constOperands, results&: foldResults)) ||
      foldResults.empty()) {
    // Folding failed: keep the op and notify the original builder's listener
    // now that we know the op will stay in the IR.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(op: applyOp, /*previous=*/{});
    return applyOp.getResult();
  }

  // Folding succeeded: the op is no longer needed.
  applyOp->erase();
  return llvm::getSingleElement(C&: foldResults);
}
1364
1365OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
1366 OpBuilder &b, Location loc, AffineExpr expr,
1367 ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
1368 return makeComposedFoldedAffineApply(
1369 b, loc,
1370 map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{expr}, context: b.getContext())
1371 .front(),
1372 operands, composeAffineMin);
1373}
1374
1375SmallVector<OpFoldResult>
1376mlir::affine::makeComposedFoldedMultiResultAffineApply(
1377 OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
1378 bool composeAffineMin) {
1379 return llvm::map_to_vector(
1380 C: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults()), F: [&](unsigned i) {
1381 return makeComposedFoldedAffineApply(b, loc, map: map.getSubMap(resultPos: {i}),
1382 operands, composeAffineMin);
1383 });
1384}
1385
1386template <typename OpTy>
1387static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
1388 ArrayRef<OpFoldResult> operands) {
1389 SmallVector<Value> valueOperands;
1390 map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands);
1391 composeMultiResultAffineMap(map, operands&: valueOperands);
1392 return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1393}
1394
/// Builds an affine.min op whose map and operands have been composed with any
/// affine.apply producers. Thin public wrapper over makeComposedMinMax.
AffineMinOp
mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
                                    ArrayRef<OpFoldResult> operands) {
  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
}
1400
/// Builds a composed min/max op (per `OpTy`) and immediately tries to
/// constant-fold it, mirroring makeComposedFoldedAffineApply: the op is only
/// reported to `b`'s listener if it survives folding.
template <typename OpTy>
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ArrayRef<OpFoldResult> operands) {
  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());

  // Create op.
  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);

  // Get constant operands.
  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(minMaxOp->getOperand(i), m_Constant(bind_value: &constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
    // Folding failed: keep the op and notify the original builder's listener.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(op: minMaxOp, /*previous=*/{});
    return minMaxOp.getResult();
  }

  // Folding succeeded: drop the op and return the folded value.
  minMaxOp->erase();
  return llvm::getSingleElement(C&: foldResults);
}
1432
/// Builds a composed affine.min op and tries to constant-fold it; see
/// makeComposedFoldedMinMax for the folding/listener behavior.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
}
1439
/// Builds a composed affine.max op and tries to constant-fold it; see
/// makeComposedFoldedMinMax for the folding/listener behavior.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
}
1446
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols. The promoted
// symbols are appended after the map/set's existing symbols, and `operands`
// is reordered to match.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(N: operands->size());
  // Operands promoted from dim to symbol; appended after all remaining dim
  // and original symbol operands below.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(N: operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol(value: (*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back(Elt: (*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back(Elt: (*operands)[i]);
      }
    } else {
      // Existing symbol operands keep their relative order.
      resultOperands.push_back(Elt: (*operands)[i]);
    }
  }

  resultOperands.append(in_start: remappedSymbols.begin(), in_end: remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
      dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}
1490
/// A valid affine dimension may appear as a symbol in affine.apply operations.
/// Given an application of `operands` to an affine map or integer set
/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
/// dims, but not valid symbols into actual dims. Without such a legalization,
/// the affine.apply will be invalid. This method is the exact inverse of
/// canonicalizePromotedSymbols.
template <class MapOrSet>
static void legalizeDemotedDims(MapOrSet &mapOrSet,
                                SmallVectorImpl<Value> &operands) {
  if (!mapOrSet || operands.empty())
    return;

  unsigned numOperands = operands.size();

  assert(mapOrSet.getNumInputs() == numOperands &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet.getContext();
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(N: numOperands);
  // Operands demoted from symbol to dim; appended after the original dim
  // operands, ahead of the remaining symbol operands.
  SmallVector<Value, 8> remappedDims;
  remappedDims.reserve(N: numOperands);
  SmallVector<Value, 8> symOperands;
  symOperands.reserve(N: mapOrSet.getNumSymbols());
  unsigned nextSym = 0;
  unsigned nextDim = 0;
  unsigned oldNumDims = mapOrSet.getNumDims();
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
  // Dim operands are untouched; copy them through directly.
  resultOperands.assign(in_start: operands.begin(), in_end: operands.begin() + oldNumDims);
  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
    if (operands[i] && isValidDim(value: operands[i]) && !isValidSymbol(value: operands[i])) {
      // This is a valid dim that appears as a symbol, legalize it.
      symRemapping[i - oldNumDims] =
          getAffineDimExpr(oldNumDims + nextDim++, context);
      remappedDims.push_back(Elt: operands[i]);
    } else {
      symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
      symOperands.push_back(Elt: operands[i]);
    }
  }

  append_range(C&: resultOperands, R&: remappedDims);
  append_range(C&: resultOperands, R&: symOperands);
  operands = resultOperands;
  mapOrSet = mapOrSet.replaceDimsAndSymbols(
      /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);

  assert(mapOrSet.getNumInputs() == operands.size() &&
         "map/set inputs must match number of operands");
}
1541
// Canonicalizes a map/set and its operands in-place:
//  1. promotes symbol-valid dim operands to symbols and demotes dim-valid
//     symbol operands to dims (via the two helpers above);
//  2. drops operands whose dim/symbol position is unused by any expression;
//  3. folds constant symbol operands into constant expressions;
//  4. deduplicates repeated dim and symbol operands.
// Works for either an affine map or an integer set.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Compact used dim positions, reusing one position per distinct operand.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  // Same compaction for symbols, additionally folding constants.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1620
/// Public entry point: canonicalize an affine map together with its operand
/// list (see canonicalizeMapOrSetAndOperands for the transformations applied).
void mlir::affine::canonicalizeMapAndOperands(
    AffineMap *map, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
1625
/// Public entry point: canonicalize an integer set together with its operand
/// list (see canonicalizeMapOrSetAndOperands for the transformations applied).
void mlir::affine::canonicalizeSetAndOperands(
    IntegerSet *set, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
1630
namespace {
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
/// maps that supply results into them.
///
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef<Value> mapOperands) const;

  /// Compose producer affine.apply ops into the op's map, canonicalize and
  /// simplify the result, and rewrite only if something actually changed.
  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(
        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
        "expected");
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector<Value, 8> resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    canonicalizeMapAndOperands(&map, &resultOperands);
    simplifyMapWithOperands(map, resultOperands);
    // Fail (i.e. no rewrite) when the map and operands are unchanged, so the
    // pattern driver reaches a fixed point.
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, resultOperands);
    return success();
  }
};

// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // Prefetch carries extra unit attributes/hints that must be preserved.
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
      prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
      prefetch.getLocalityHint(), prefetch.getIsDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
      mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
      mapOperands);
}

// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // namespace
1717
/// Register the map-composition/simplification pattern for affine.apply.
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
1722
1723//===----------------------------------------------------------------------===//
1724// AffineDmaStartOp
1725//===----------------------------------------------------------------------===//
1726
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start. The operand layout — src memref, src indices,
/// dst memref, dst indices, tag memref, tag indices, num elements, optional
/// (stride, elements-per-stride) — must stay in this exact order; the op's
/// accessors compute operand indices from the three map attributes.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // The stride pair is optional; both values are appended together or not at
  // all (verifyInvariantsImpl checks the two legal operand counts).
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
1749
1750void AffineDmaStartOp::print(OpAsmPrinter &p) {
1751 p << " " << getSrcMemRef() << '[';
1752 p.printAffineMapOfSSAIds(mapAttr: getSrcMapAttr(), operands: getSrcIndices());
1753 p << "], " << getDstMemRef() << '[';
1754 p.printAffineMapOfSSAIds(mapAttr: getDstMapAttr(), operands: getDstIndices());
1755 p << "], " << getTagMemRef() << '[';
1756 p.printAffineMapOfSSAIds(mapAttr: getTagMapAttr(), operands: getTagIndices());
1757 p << "], " << getNumElements();
1758 if (isStrided()) {
1759 p << ", " << getStride();
1760 p << ", " << getNumElementsPerStride();
1761 }
1762 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1763 << getTagMemRefType();
1764}
1765
// Parse AffineDmaStartOp.
// Ex:
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
//     %stride, %num_elt_per_stride
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
  OpAsmParser::UnresolvedOperand tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) dst memref followed by its affine maps operands (in square brackets).
  // *) src memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Strided form requires exactly the (stride, elements-per-stride) pair.
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  // The trailing type list carries exactly src, dst, and tag memref types.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  // Resolution order defines the final operand layout; it must match the
  // order used by AffineDmaStartOp::build.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}
1847
1848LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1849 if (!llvm::isa<MemRefType>(Val: getOperand(i: getSrcMemRefOperandIndex()).getType()))
1850 return emitOpError(message: "expected DMA source to be of memref type");
1851 if (!llvm::isa<MemRefType>(Val: getOperand(i: getDstMemRefOperandIndex()).getType()))
1852 return emitOpError(message: "expected DMA destination to be of memref type");
1853 if (!llvm::isa<MemRefType>(Val: getOperand(i: getTagMemRefOperandIndex()).getType()))
1854 return emitOpError(message: "expected DMA tag to be of memref type");
1855
1856 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1857 getDstMap().getNumInputs() +
1858 getTagMap().getNumInputs();
1859 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1860 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1861 return emitOpError(message: "incorrect number of operands");
1862 }
1863
1864 Region *scope = getAffineScope(op: *this);
1865 for (auto idx : getSrcIndices()) {
1866 if (!idx.getType().isIndex())
1867 return emitOpError(message: "src index to dma_start must have 'index' type");
1868 if (!isValidAffineIndexOperand(value: idx, region: scope))
1869 return emitOpError(
1870 message: "src index must be a valid dimension or symbol identifier");
1871 }
1872 for (auto idx : getDstIndices()) {
1873 if (!idx.getType().isIndex())
1874 return emitOpError(message: "dst index to dma_start must have 'index' type");
1875 if (!isValidAffineIndexOperand(value: idx, region: scope))
1876 return emitOpError(
1877 message: "dst index must be a valid dimension or symbol identifier");
1878 }
1879 for (auto idx : getTagIndices()) {
1880 if (!idx.getType().isIndex())
1881 return emitOpError(message: "tag index to dma_start must have 'index' type");
1882 if (!isValidAffineIndexOperand(value: idx, region: scope))
1883 return emitOpError(
1884 message: "tag index must be a valid dimension or symbol identifier");
1885 }
1886 return success();
1887}
1888
/// Folds memref.cast feeding any memref operand into the op itself:
/// dma_start(memrefcast) -> dma_start.
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return memref::foldMemRefCast(*this);
}
1894
/// Declares the op's memory effects: reads from the source memref, writes to
/// the destination memref, and reads the tag memref (polled for completion).
void AffineDmaStartOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                       SideEffects::DefaultResource::get());
}
1905
1906//===----------------------------------------------------------------------===//
1907// AffineDmaWaitOp
1908//===----------------------------------------------------------------------===//
1909
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_wait. Operand order — tag memref, tag indices,
/// num elements — must match the op's accessors, which index relative to the
/// tag map attribute.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}
1919
1920void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1921 p << " " << getTagMemRef() << '[';
1922 SmallVector<Value, 2> operands(getTagIndices());
1923 p.printAffineMapOfSSAIds(mapAttr: getTagMapAttr(), operands);
1924 p << "], ";
1925 p.printOperand(value: getNumElements());
1926 p << " : " << getTagMemRef().getType();
1927}
1928
// Parse AffineDmaWaitOp.
// Eg:
//   affine.dma_wait %tag[%index], %num_elements
//     : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::UnresolvedOperand tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::UnresolvedOperand numElementsInfo;

  // Parse tag memref, its map operands, and dma size. Resolution order
  // (tag memref, indices, num elements) defines the operand layout and must
  // match AffineDmaWaitOp::build.
  if (parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (!llvm::isa<MemRefType>(type))
    return parser.emitError(parser.getNameLoc(),
                            "expected tag to be of memref type");

  // The number of parsed indices must agree with the tag map's input count.
  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "tag memref operand count != to map.numInputs");
  return success();
}
1964
1965LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1966 if (!llvm::isa<MemRefType>(Val: getOperand(i: 0).getType()))
1967 return emitOpError(message: "expected DMA tag to be of memref type");
1968 Region *scope = getAffineScope(op: *this);
1969 for (auto idx : getTagIndices()) {
1970 if (!idx.getType().isIndex())
1971 return emitOpError(message: "index to dma_wait must have 'index' type");
1972 if (!isValidAffineIndexOperand(value: idx, region: scope))
1973 return emitOpError(
1974 message: "index must be a valid dimension or symbol identifier");
1975 }
1976 return success();
1977}
1978
/// Folds memref.cast feeding the tag memref into the op itself:
/// dma_wait(memrefcast) -> dma_wait.
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return memref::foldMemRefCast(*this);
}
1984
/// Declares the op's memory effects: it reads the tag memref (polled until
/// the associated DMA transfer completes).
void AffineDmaWaitOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                       SideEffects::DefaultResource::get());
}
1991
1992//===----------------------------------------------------------------------===//
1993// AffineForOp
1994//===----------------------------------------------------------------------===//
1995
/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
/// bodyBuilder are empty/null, we include default terminator op.
void AffineForOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  // Restore the caller's insertion point after the body block is created.
  OpBuilder::InsertionGuard guard(builder);

  // Set variadic segment sizes so accessors can partition the flat operand
  // list into lower-bound, upper-bound, and iter-arg groups.
  result.addAttribute(
      getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
                                    static_cast<int32_t>(ubOperands.size()),
                                    static_cast<int32_t>(iterArgs.size())}));

  // The loop yields one result per loop-carried value.
  for (Value val : iterArgs)
    result.addTypes(val.getType());

  // Add an attribute for the step.
  result.addAttribute(getStepAttrName(result.name),
                      builder.getIntegerAttr(builder.getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundMapAttrName(result.name),
                      AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundMapAttrName(result.name),
                      AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  result.addOperands(iterArgs);
  // Create a region and a block for the body. The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion);
  Value inductionVar =
      bodyBlock->addArgument(builder.getIndexType(), result.location);
  // One block argument per loop-carried value, after the induction variable.
  for (Value val : iterArgs)
    bodyBlock->addArgument(val.getType(), val.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilder(builder, result.location, inductionVar,
                bodyBlock->getArguments().drop_front());
  }
}
2058
2059void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
2060 int64_t ub, int64_t step, ValueRange iterArgs,
2061 BodyBuilderFn bodyBuilder) {
2062 auto lbMap = AffineMap::getConstantMap(val: lb, context: builder.getContext());
2063 auto ubMap = AffineMap::getConstantMap(val: ub, context: builder.getContext());
2064 return build(builder, result, lbOperands: {}, lbMap, ubOperands: {}, ubMap, step, iterArgs,
2065 bodyBuilder);
2066}
2067
/// Verifies the loop's region and bound attributes: the body must lead with an
/// index-typed induction variable, bound operands must be valid affine
/// dims/symbols, each bound map must produce at least one result, and the
/// result / iter-operand / region-arg counts must agree.
LogicalResult AffineForOp::verifyRegions() {
  // Check that the body defines as single block argument for the induction
  // variable.
  auto *body = getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return emitOpError("expected body to have a single index argument for the "
                       "induction variable");

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bound.
  if (getLowerBoundMap().getNumInputs() > 0)
    if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
                                             getLowerBoundMap().getNumDims())))
      return failure();
  /// Upper bound.
  if (getUpperBoundMap().getNumInputs() > 0)
    if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
                                             getUpperBoundMap().getNumDims())))
      return failure();
  if (getLowerBoundMap().getNumResults() < 1)
    return emitOpError("expected lower bound map to have at least one result");
  if (getUpperBoundMap().getNumResults() < 1)
    return emitOpError("expected upper bound map to have at least one result");

  unsigned opNumResults = getNumResults();
  if (opNumResults == 0)
    return success();

  // If ForOp defines values, check that the number and types of the defined
  // values match ForOp initial iter operands and backedge basic block
  // arguments.
  if (getNumIterOperands() != opNumResults)
    return emitOpError(
        "mismatch between the number of loop-carried values and results");
  if (getNumRegionIterArgs() != opNumResults)
    return emitOpError(
        "mismatch between the number of basic block args and results");

  return success();
}
2108
2109/// Parse a for operation loop bounds.
2110static ParseResult parseBound(bool isLower, OperationState &result,
2111 OpAsmParser &p) {
2112 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2113 // the map has multiple results.
2114 bool failedToParsedMinMax =
2115 failed(Result: p.parseOptionalKeyword(keyword: isLower ? "max" : "min"));
2116
2117 auto &builder = p.getBuilder();
2118 auto boundAttrStrName =
2119 isLower ? AffineForOp::getLowerBoundMapAttrName(name: result.name)
2120 : AffineForOp::getUpperBoundMapAttrName(name: result.name);
2121
2122 // Parse ssa-id as identity map.
2123 SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
2124 if (p.parseOperandList(result&: boundOpInfos))
2125 return failure();
2126
2127 if (!boundOpInfos.empty()) {
2128 // Check that only one operand was parsed.
2129 if (boundOpInfos.size() > 1)
2130 return p.emitError(loc: p.getNameLoc(),
2131 message: "expected only one loop bound operand");
2132
2133 // TODO: improve error message when SSA value is not of index type.
2134 // Currently it is 'use of value ... expects different type than prior uses'
2135 if (p.resolveOperand(operand: boundOpInfos.front(), type: builder.getIndexType(),
2136 result&: result.operands))
2137 return failure();
2138
2139 // Create an identity map using symbol id. This representation is optimized
2140 // for storage. Analysis passes may expand it into a multi-dimensional map
2141 // if desired.
2142 AffineMap map = builder.getSymbolIdentityMap();
2143 result.addAttribute(name: boundAttrStrName, attr: AffineMapAttr::get(value: map));
2144 return success();
2145 }
2146
2147 // Get the attribute location.
2148 SMLoc attrLoc = p.getCurrentLocation();
2149
2150 Attribute boundAttr;
2151 if (p.parseAttribute(result&: boundAttr, type: builder.getIndexType(), attrName: boundAttrStrName,
2152 attrs&: result.attributes))
2153 return failure();
2154
2155 // Parse full form - affine map followed by dim and symbol list.
2156 if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(Val&: boundAttr)) {
2157 unsigned currentNumOperands = result.operands.size();
2158 unsigned numDims;
2159 if (parseDimAndSymbolList(parser&: p, operands&: result.operands, numDims))
2160 return failure();
2161
2162 auto map = affineMapAttr.getValue();
2163 if (map.getNumDims() != numDims)
2164 return p.emitError(
2165 loc: p.getNameLoc(),
2166 message: "dim operand count and affine map dim count must match");
2167
2168 unsigned numDimAndSymbolOperands =
2169 result.operands.size() - currentNumOperands;
2170 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2171 return p.emitError(
2172 loc: p.getNameLoc(),
2173 message: "symbol operand count and affine map symbol count must match");
2174
2175 // If the map has multiple results, make sure that we parsed the min/max
2176 // prefix.
2177 if (map.getNumResults() > 1 && failedToParsedMinMax) {
2178 if (isLower) {
2179 return p.emitError(loc: attrLoc, message: "lower loop bound affine map with "
2180 "multiple results requires 'max' prefix");
2181 }
2182 return p.emitError(loc: attrLoc, message: "upper loop bound affine map with multiple "
2183 "results requires 'min' prefix");
2184 }
2185 return success();
2186 }
2187
2188 // Parse custom assembly form.
2189 if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(Val&: boundAttr)) {
2190 result.attributes.pop_back();
2191 result.addAttribute(
2192 name: boundAttrStrName,
2193 attr: AffineMapAttr::get(value: builder.getConstantAffineMap(val: integerAttr.getInt())));
2194 return success();
2195 }
2196
2197 return p.emitError(
2198 loc: p.getNameLoc(),
2199 message: "expected valid affine map representation for loop bounds");
2200}
2201
/// Parses the custom assembly form of an affine.for op:
///   affine.for %iv = <lower-bound> to <upper-bound> (step <const>)?
///       (iter_args(%arg = %init, ...) -> (types))? { region } (attr-dict)?
/// Bound operand counts are tracked so operand segment sizes can be recorded.
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::Argument inductionVariable;
  // The induction variable is always of index type.
  inductionVariable.type = builder.getIndexType();
  // Parse the induction variable followed by '='.
  if (parser.parseArgument(result&: inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds. `parseBound` appends the bound's SSA operands to
  // `result.operands`; the deltas give the per-bound operand counts.
  int64_t numOperands = result.operands.size();
  if (parseBound(/*isLower=*/true, result, p&: parser))
    return failure();
  int64_t numLbOperands = result.operands.size() - numOperands;
  if (parser.parseKeyword(keyword: "to", msg: " between bounds"))
    return failure();
  numOperands = result.operands.size();
  if (parseBound(/*isLower=*/false, result, p&: parser))
    return failure();
  int64_t numUbOperands = result.operands.size() - numOperands;

  // Parse the optional loop step, we default to 1 if one is not present.
  if (parser.parseOptionalKeyword(keyword: "step")) {
    result.addAttribute(
        name: getStepAttrName(name: result.name),
        attr: builder.getIntegerAttr(type: builder.getIndexType(), /*value=*/1));
  } else {
    SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(result&: stepAttr, type: builder.getIndexType(),
                              attrName: getStepAttrName(name: result.name).data(),
                              attrs&: result.attributes))
      return failure();

    // Only negative steps are rejected here; other invalid values are left to
    // the verifier.
    if (stepAttr.getValue().isNegative())
      return parser.emitError(
          loc: stepLoc,
          message: "expected step to be representable as a positive signed integer");
  }

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Induction variable. It is always the first region argument.
  regionArgs.push_back(Elt: inductionVariable);

  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(lhs&: regionArgs, rhs&: operands) ||
        parser.parseArrowTypeList(result&: result.types))
      return failure();
    // Resolve input operands. Each iter_arg's type is taken from the
    // corresponding result type (drop_begin skips the induction variable).
    for (auto argOperandType :
         llvm::zip(t: llvm::drop_begin(RangeOrContainer&: regionArgs), u&: operands, args&: result.types)) {
      Type type = std::get<2>(t&: argOperandType);
      std::get<0>(t&: argOperandType).type = type;
      if (parser.resolveOperand(operand: std::get<1>(t&: argOperandType), type,
                                result&: result.operands))
        return failure();
    }
  }

  // Record how the flat operand list splits into lb/ub/init segments.
  result.addAttribute(
      name: getOperandSegmentSizeAttr(),
      attr: builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(numLbOperands),
                                    static_cast<int32_t>(numUbOperands),
                                    static_cast<int32_t>(operands.size())}));

  // Parse the body region.
  Region *body = result.addRegion();
  // regionArgs = induction variable + one per loop-carried value.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        loc: parser.getNameLoc(),
        message: "mismatch between the number of loop-carried values and results");
  if (parser.parseRegion(region&: *body, arguments: regionArgs))
    return failure();

  // Insert an implicit affine.yield if the body has no terminator.
  AffineForOp::ensureTerminator(region&: *body, builder, loc: result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result&: result.attributes);
}
2284
/// Prints a single loop bound given its map attribute, its SSA operands and
/// the 'min'/'max' prefix to use when the map has multiple results.
static void printBound(AffineMapAttr boundMap,
                       Operation::operand_range boundOperands,
                       const char *prefix, OpAsmPrinter &p) {
  AffineMap map = boundMap.getValue();

  // Check if this bound should be printed using custom assembly form.
  // The decision to restrict printing custom assembly form to trivial cases
  // comes from the will to roundtrip MLIR binary -> text -> binary in a
  // lossless way.
  // Therefore, custom assembly form parsing and printing is only supported for
  // zero-operand constant maps and single symbol operand identity maps.
  if (map.getNumResults() == 1) {
    AffineExpr expr = map.getResult(idx: 0);

    // Print constant bound, e.g. "%i = 0 to 42".
    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
      if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) {
        p << constExpr.getValue();
        return;
      }
    }

    // Print bound that consists of a single SSA symbol if the map is over a
    // single symbol, e.g. "%i = 0 to %n".
    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
      if (isa<AffineSymbolExpr>(Val: expr)) {
        p.printOperand(value: *boundOperands.begin());
        return;
      }
    }
  } else {
    // Map has multiple results. Print 'min' or 'max' prefix.
    p << prefix << ' ';
  }

  // Generic fallback: print the map and its operands.
  p << boundMap;
  printDimAndSymbolList(begin: boundOperands.begin(), end: boundOperands.end(),
                        numDims: map.getNumDims(), printer&: p);
}
2325
2326unsigned AffineForOp::getNumIterOperands() {
2327 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2328 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2329
2330 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2331}
2332
2333std::optional<MutableArrayRef<OpOperand>>
2334AffineForOp::getYieldedValuesMutable() {
2335 return cast<AffineYieldOp>(Val: getBody()->getTerminator()).getOperandsMutable();
2336}
2337
/// Prints the custom assembly form of affine.for; the inverse of parse().
void AffineForOp::print(OpAsmPrinter &p) {
  p << ' ';
  // Region argument 0 is the induction variable; its index type is implied.
  p.printRegionArgument(arg: getBody()->getArgument(i: 0), /*argAttrs=*/{},
                        /*omitType=*/true);
  p << " = ";
  // Multi-result lower bounds print with a 'max' prefix, upper bounds 'min'.
  printBound(boundMap: getLowerBoundMapAttr(), boundOperands: getLowerBoundOperands(), prefix: "max", p);
  p << " to ";
  printBound(boundMap: getUpperBoundMapAttr(), boundOperands: getUpperBoundOperands(), prefix: "min", p);

  // The default step of 1 is elided.
  if (getStepAsInt() != 1)
    p << " step " << getStepAsInt();

  bool printBlockTerminators = false;
  if (getNumIterOperands() > 0) {
    p << " iter_args(";
    auto regionArgs = getRegionIterArgs();
    auto operands = getInits();

    llvm::interleaveComma(c: llvm::zip(t&: regionArgs, u&: operands), os&: p, each_fn: [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it);
    });
    p << ") -> (" << getResultTypes() << ")";
    // With iter_args present, the yield carries values and must be printed.
    printBlockTerminators = true;
  }

  p << ' ';
  p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);
  // Elide attributes that the custom form already encodes.
  p.printOptionalAttrDict(
      attrs: (*this)->getAttrs(),
      /*elidedAttrs=*/{getLowerBoundMapAttrName(name: getOperation()->getName()),
                       getUpperBoundMapAttrName(name: getOperation()->getName()),
                       getStepAttrName(name: getOperation()->getName()),
                       getOperandSegmentSizeAttr()});
}
2373
/// Fold the constant bounds of a loop. A bound whose operands are all
/// constants is replaced by the single constant obtained by evaluating the
/// bound map and taking the max (lower bound) or min (upper bound) over its
/// results. Returns success if either bound was folded.
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant. If
    // so, get the value. If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      // Non-constant operands leave `operandCst` null; constantFold below
      // fails in that case, so the match result is deliberately unchecked.
      matchPattern(value: operand, pattern: m_Constant(bind_value: &operandCst));
      operandConstants.push_back(Elt: operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(Result: boundMap.constantFold(operandConstants, results&: foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = llvm::cast<IntegerAttr>(Val&: foldedResults[0]).getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = llvm::cast<IntegerAttr>(Val&: foldedResults[i]).getValue();
      // Lower bounds are effectively a 'max' over results, upper bounds 'min'.
      maxOrMin = lower ? llvm::APIntOps::smax(A: maxOrMin, B: foldedResult)
                       : llvm::APIntOps::smin(A: maxOrMin, B: foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(Result: foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(Result: foldLowerOrUpperBound(/*lower=*/false));
  return success(IsSuccess: folded);
}
2419
/// Canonicalize the bounds of the given loop: compose chained affine.apply
/// ops into the bound maps, canonicalize map/operand lists, simplify min/max
/// expressions, and deduplicate identical bound results. Returns success iff
/// either map changed.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  // Keep the originals to detect whether anything changed.
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  // NOTE(review): the ub min-simplification below runs before the ub map is
  // composed/canonicalized; this ordering is intentional here — confirm
  // against upstream before reordering.
  composeAffineMapAndOperands(map: &lbMap, operands: &lbOperands);
  canonicalizeMapAndOperands(map: &lbMap, operands: &lbOperands);
  // A multi-result lower bound is a 'max'; an upper bound is a 'min'.
  simplifyMinOrMaxExprWithOperands(map&: lbMap, operands: lbOperands, /*isMax=*/true);
  simplifyMinOrMaxExprWithOperands(map&: ubMap, operands: ubOperands, /*isMax=*/false);
  lbMap = removeDuplicateExprs(map: lbMap);

  composeAffineMapAndOperands(map: &ubMap, operands: &ubOperands);
  canonicalizeMapAndOperands(map: &ubMap, operands: &ubOperands);
  ubMap = removeDuplicateExprs(map: ubMap);

  // Any canonicalization change always leads to updated map(s).
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  if (lbMap != prevLbMap)
    forOp.setLowerBound(operands: lbOperands, map: lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(operands: ubOperands, map: ubMap);
  return success();
}
2450
namespace {
/// Returns constant trip count in trivial cases: both bounds constant and a
/// strictly positive step. Otherwise returns std::nullopt.
static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
  int64_t step = forOp.getStepAsInt();
  if (!forOp.hasConstantBounds() || step <= 0)
    return std::nullopt;
  int64_t lb = forOp.getConstantLowerBound();
  int64_t ub = forOp.getConstantUpperBound();
  // ceildiv(ub - lb, step); empty ranges yield a trip count of zero.
  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
}

/// This is a pattern to fold trivially empty loop bodies.
/// TODO: This should be moved into the folding hook.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Check that the body only contains a yield.
    if (!llvm::hasSingleElement(C&: *forOp.getBody()))
      return failure();
    // No results: an empty loop is side-effect free, simply erase it.
    if (forOp.getNumResults() == 0)
      return success();
    std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
    if (tripCount == 0) {
      // The initial values of the iteration arguments would be the op's
      // results.
      rewriter.replaceOp(op: forOp, newValues: forOp.getInits());
      return success();
    }
    // The loop may run; each result is replaced by tracing the yielded value
    // back to either an init value or a value defined outside the loop.
    SmallVector<Value, 4> replacements;
    auto yieldOp = cast<AffineYieldOp>(Val: forOp.getBody()->getTerminator());
    auto iterArgs = forOp.getRegionIterArgs();
    bool hasValDefinedOutsideLoop = false;
    bool iterArgsNotInOrder = false;
    for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
      Value val = yieldOp.getOperand(i);
      auto *iterArgIt = llvm::find(Range&: iterArgs, Val: val);
      // TODO: It should be possible to perform a replacement by computing the
      // last value of the IV based on the bounds and the step.
      if (val == forOp.getInductionVar())
        return failure();
      if (iterArgIt == iterArgs.end()) {
        // `val` is defined outside of the loop.
        assert(forOp.isDefinedOutsideOfLoop(val) &&
               "must be defined outside of the loop");
        hasValDefinedOutsideLoop = true;
        replacements.push_back(Elt: val);
      } else {
        // Yielding iterArg[pos] from position i permutes the carried values.
        unsigned pos = std::distance(first: iterArgs.begin(), last: iterArgIt);
        if (pos != i)
          iterArgsNotInOrder = true;
        replacements.push_back(Elt: forOp.getInits()[pos]);
      }
    }
    // Bail out when the trip count is unknown and the loop returns any value
    // defined outside of the loop or any iterArg out of order.
    if (!tripCount.has_value() &&
        (hasValDefinedOutsideLoop || iterArgsNotInOrder))
      return failure();
    // Bail out when the loop iterates more than once and it returns any iterArg
    // out of order.
    if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
      return failure();
    rewriter.replaceOp(op: forOp, newValues: replacements);
    return success();
  }
};
} // namespace
2520
2521void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2522 MLIRContext *context) {
2523 results.add<AffineForEmptyLoopFolder>(arg&: context);
2524}
2525
2526OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2527 assert((point.isParent() || point == getRegion()) && "invalid region point");
2528
2529 // The initial operands map to the loop arguments after the induction
2530 // variable or are forwarded to the results when the trip count is zero.
2531 return getInits();
2532}
2533
/// Computes the possible control-flow successors of the loop (body region
/// and/or the parent), refined by the trivially-known constant trip count.
void AffineForOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  assert((point.isParent() || point == getRegion()) && "expected loop region");
  // The loop may typically branch back to its body or to the parent operation.
  // If the predecessor is the parent op and the trip count is known to be at
  // least one, branch into the body using the iterator arguments. And in cases
  // we know the trip count is zero, it can only branch back to its parent.
  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp: *this);
  if (point.isParent() && tripCount.has_value()) {
    // Non-zero trip count: entry always goes into the body.
    if (tripCount.value() > 0) {
      regions.push_back(Elt: RegionSuccessor(&getRegion(), getRegionIterArgs()));
      return;
    }
    // Zero trip count: the body is never entered.
    if (tripCount.value() == 0) {
      regions.push_back(Elt: RegionSuccessor(getResults()));
      return;
    }
  }

  // From the loop body, if the trip count is one, we can only branch back to
  // the parent.
  if (!point.isParent() && tripCount == 1) {
    regions.push_back(Elt: RegionSuccessor(getResults()));
    return;
  }

  // In all other cases, the loop may branch back to itself or the parent
  // operation.
  regions.push_back(Elt: RegionSuccessor(&getRegion(), getRegionIterArgs()));
  regions.push_back(Elt: RegionSuccessor(getResults()));
}
2565
2566/// Returns true if the affine.for has zero iterations in trivial cases.
2567static bool hasTrivialZeroTripCount(AffineForOp op) {
2568 return getTrivialConstantTripCount(forOp: op) == 0;
2569}
2570
/// Folds an affine.for in place (bound folding/canonicalization) and, for a
/// result-producing zero-trip loop, folds its results to the init values.
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(Result: foldLoopBounds(forOp: *this));
  folded |= succeeded(Result: canonicalizeLoopBounds(forOp: *this));
  if (hasTrivialZeroTripCount(op: *this) && getNumResults() != 0) {
    // The initial values of the loop-carried variables (iter_args) are the
    // results of the op. But this must be avoided for an affine.for op that
    // does not return any results. Since ops that do not return results cannot
    // be folded away, we would enter an infinite loop of folds on the same
    // affine.for op.
    results.assign(in_start: getInits().begin(), in_end: getInits().end());
    folded = true;
  }
  return success(IsSuccess: folded);
}
2586
2587AffineBound AffineForOp::getLowerBound() {
2588 return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2589}
2590
2591AffineBound AffineForOp::getUpperBound() {
2592 return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2593}
2594
2595void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2596 assert(lbOperands.size() == map.getNumInputs());
2597 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2598 getLowerBoundOperandsMutable().assign(values: lbOperands);
2599 setLowerBoundMap(map);
2600}
2601
2602void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2603 assert(ubOperands.size() == map.getNumInputs());
2604 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2605 getUpperBoundOperandsMutable().assign(values: ubOperands);
2606 setUpperBoundMap(map);
2607}
2608
2609bool AffineForOp::hasConstantLowerBound() {
2610 return getLowerBoundMap().isSingleConstant();
2611}
2612
2613bool AffineForOp::hasConstantUpperBound() {
2614 return getUpperBoundMap().isSingleConstant();
2615}
2616
2617int64_t AffineForOp::getConstantLowerBound() {
2618 return getLowerBoundMap().getSingleConstantResult();
2619}
2620
2621int64_t AffineForOp::getConstantUpperBound() {
2622 return getUpperBoundMap().getSingleConstantResult();
2623}
2624
2625void AffineForOp::setConstantLowerBound(int64_t value) {
2626 setLowerBound(lbOperands: {}, map: AffineMap::getConstantMap(val: value, context: getContext()));
2627}
2628
2629void AffineForOp::setConstantUpperBound(int64_t value) {
2630 setUpperBound(ubOperands: {}, map: AffineMap::getConstantMap(val: value, context: getContext()));
2631}
2632
2633AffineForOp::operand_range AffineForOp::getControlOperands() {
2634 return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2635 getUpperBoundOperands().size()};
2636}
2637
2638bool AffineForOp::matchingBoundOperandList() {
2639 auto lbMap = getLowerBoundMap();
2640 auto ubMap = getUpperBoundMap();
2641 if (lbMap.getNumDims() != ubMap.getNumDims() ||
2642 lbMap.getNumSymbols() != ubMap.getNumSymbols())
2643 return false;
2644
2645 unsigned numOperands = lbMap.getNumInputs();
2646 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2647 // Compare Value 's.
2648 if (getOperand(i) != getOperand(i: numOperands + i))
2649 return false;
2650 }
2651 return true;
2652}
2653
2654SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2655
2656std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2657 return SmallVector<Value>{getInductionVar()};
2658}
2659
2660std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2661 if (!hasConstantLowerBound())
2662 return std::nullopt;
2663 OpBuilder b(getContext());
2664 return SmallVector<OpFoldResult>{
2665 OpFoldResult(b.getI64IntegerAttr(value: getConstantLowerBound()))};
2666}
2667
2668std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2669 OpBuilder b(getContext());
2670 return SmallVector<OpFoldResult>{
2671 OpFoldResult(b.getI64IntegerAttr(value: getStepAsInt()))};
2672}
2673
2674std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2675 if (!hasConstantUpperBound())
2676 return {};
2677 OpBuilder b(getContext());
2678 return SmallVector<OpFoldResult>{
2679 OpFoldResult(b.getI64IntegerAttr(value: getConstantUpperBound()))};
2680}
2681
/// Recreates this loop with extra iter_args/results. `newYieldValuesFn`
/// produces the values to yield for the new iter_args; when
/// `replaceInitOperandUsesInLoop` is set, in-loop uses of the new init values
/// are redirected to the new block arguments. The original op is erased.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(Range: getInits());
  inits.append(in_start: newInitOperands.begin(), in_end: newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
      location: getLoc(), args: getLowerBoundOperands(), args: getLowerBoundMap(),
      args: getUpperBoundOperands(), args: getUpperBoundMap(), args: getStepAsInt(), args&: inits);

  // Generate the new yield values and append them to the affine.yield
  // operation of the original body (moved into the new loop below).
  auto yieldOp = cast<AffineYieldOp>(Val: getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(N: newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(root: yieldOp, callable: [&]() {
      yieldOp.getOperandsMutable().append(values: newYieldedValues);
    });
  }

  // Move the loop body to the new op, mapping the old block arguments onto
  // the leading arguments of the new body.
  rewriter.mergeBlocks(source: getBody(), dest: newLoop.getBody(),
                       argValues: newLoop.getBody()->getArguments().take_front(
                           N: getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only for uses nested inside the new loop.
    for (auto it : llvm::zip(t&: newInitOperands, u&: newIterArgs)) {
      rewriter.replaceUsesWithIf(from: std::get<0>(t&: it), to: std::get<1>(t&: it),
                                 functor: [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(other: user);
                                 });
    }
  }

  // Replace the old loop; the new loop's leading results correspond 1:1 to
  // the old results.
  rewriter.replaceOp(op: getOperation(),
                     newValues: newLoop->getResults().take_front(n: getNumResults()));
  return cast<LoopLikeOpInterface>(Val: newLoop.getOperation());
}
2733
2734Speculation::Speculatability AffineForOp::getSpeculatability() {
2735 // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2736 // Start and End.
2737 //
2738 // For Step != 1, the loop may not terminate. We can add more smarts here if
2739 // needed.
2740 return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2741 : Speculation::NotSpeculatable;
2742}
2743
2744/// Returns true if the provided value is the induction variable of a
2745/// AffineForOp.
2746bool mlir::affine::isAffineForInductionVar(Value val) {
2747 return getForInductionVarOwner(val) != AffineForOp();
2748}
2749
2750bool mlir::affine::isAffineParallelInductionVar(Value val) {
2751 return getAffineParallelInductionVarOwner(val) != nullptr;
2752}
2753
2754bool mlir::affine::isAffineInductionVar(Value val) {
2755 return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
2756}
2757
2758AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
2759 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
2760 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2761 return AffineForOp();
2762 if (auto forOp =
2763 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2764 // Check to make sure `val` is the induction variable, not an iter_arg.
2765 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2766 return AffineForOp();
2767}
2768
2769AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
2770 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
2771 if (!ivArg || !ivArg.getOwner())
2772 return nullptr;
2773 Operation *containingOp = ivArg.getOwner()->getParentOp();
2774 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(Val: containingOp);
2775 if (parallelOp && llvm::is_contained(Range: parallelOp.getIVs(), Element: val))
2776 return parallelOp;
2777 return nullptr;
2778}
2779
2780/// Extracts the induction variables from a list of AffineForOps and returns
2781/// them.
2782void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2783 SmallVectorImpl<Value> *ivs) {
2784 ivs->reserve(N: forInsts.size());
2785 for (auto forInst : forInsts)
2786 ivs->push_back(Elt: forInst.getInductionVar());
2787}
2788
2789void mlir::affine::extractInductionVars(ArrayRef<mlir::Operation *> affineOps,
2790 SmallVectorImpl<mlir::Value> &ivs) {
2791 ivs.reserve(N: affineOps.size());
2792 for (Operation *op : affineOps) {
2793 // Add constraints from forOp's bounds.
2794 if (auto forOp = dyn_cast<AffineForOp>(Val: op))
2795 ivs.push_back(Elt: forOp.getInductionVar());
2796 else if (auto parallelOp = dyn_cast<AffineParallelOp>(Val: op))
2797 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2798 ivs.push_back(Elt: parallelOp.getBody()->getArgument(i));
2799 }
2800}
2801
/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations. Loops are created outermost-first; `bodyBuilderFn` (if any) is
/// invoked inside the innermost loop with all induction variables.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  SmallVector<Value, 4> ivs;
  ivs.reserve(N: lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    // Note: captures `ivs` and `i` by reference/value via [&]; it is invoked
    // synchronously by loopCreatorFn before the next iteration.
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(Elt: iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(location: nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    // Subsequent loops are created inside the one just built.
    builder.setInsertionPointToStart(loop.getBody());
  }
}
2843
2844/// Creates an affine loop from the bounds known to be constants.
2845static AffineForOp
2846buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
2847 int64_t ub, int64_t step,
2848 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2849 return builder.create<AffineForOp>(location: loc, args&: lb, args&: ub, args&: step,
2850 /*iterArgs=*/args: ValueRange(), args&: bodyBuilderFn);
2851}
2852
2853/// Creates an affine loop from the bounds that may or may not be constants.
2854static AffineForOp
2855buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2856 int64_t step,
2857 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2858 std::optional<int64_t> lbConst = getConstantIntValue(ofr: lb);
2859 std::optional<int64_t> ubConst = getConstantIntValue(ofr: ub);
2860 if (lbConst && ubConst)
2861 return buildAffineLoopFromConstants(builder, loc, lb: lbConst.value(),
2862 ub: ubConst.value(), step, bodyBuilderFn);
2863 return builder.create<AffineForOp>(location: loc, args&: lb, args: builder.getDimIdentityMap(), args&: ub,
2864 args: builder.getDimIdentityMap(), args&: step,
2865 /*iterArgs=*/args: ValueRange(), args&: bodyBuilderFn);
2866}
2867
2868void mlir::affine::buildAffineLoopNest(
2869 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2870 ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
2871 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2872 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2873 loopCreatorFn&: buildAffineLoopFromConstants);
2874}
2875
2876void mlir::affine::buildAffineLoopNest(
2877 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2878 ArrayRef<int64_t> steps,
2879 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2880 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2881 loopCreatorFn&: buildAffineLoopFromValues);
2882}
2883
2884//===----------------------------------------------------------------------===//
2885// AffineIfOp
2886//===----------------------------------------------------------------------===//
2887
namespace {
/// Remove else blocks that have nothing other than a zero value yield.
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp ifOp,
                                PatternRewriter &rewriter) const override {
    // Only fire on a present else region that holds just its terminator, and
    // only when the if produces no results (results need yields from else).
    if (ifOp.getElseRegion().empty() ||
        !llvm::hasSingleElement(C&: *ifOp.getElseBlock()) || ifOp.getNumResults())
      return failure();

    rewriter.startOpModification(op: ifOp);
    rewriter.eraseBlock(block: ifOp.getElseBlock());
    rewriter.finalizeOpModification(op: ifOp);
    return success();
  }
};

/// Removes affine.if cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {

    auto isTriviallyFalse = [](IntegerSet iSet) {
      return iSet.isEmptyIntegerSet();
    };

    // A single equality constraint "0 == 0" is always satisfied.
    auto isTriviallyTrue = [](IntegerSet iSet) {
      return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
              iSet.getConstraint(idx: 0) == 0);
    };

    IntegerSet affineIfConditions = op.getIntegerSet();
    Block *blockToMove;
    if (isTriviallyFalse(affineIfConditions)) {
      // The absence, or equivalently, the emptiness of the else region need not
      // be checked when affine.if is returning results because if an affine.if
      // operation is returning results, it always has a non-empty else region.
      if (op.getNumResults() == 0 && !op.hasElse()) {
        // If the else region is absent, or equivalently, empty, remove the
        // affine.if operation (which is not returning any results).
        rewriter.eraseOp(op);
        return success();
      }
      blockToMove = op.getElseBlock();
    } else if (isTriviallyTrue(affineIfConditions)) {
      blockToMove = op.getThenBlock();
    } else {
      return failure();
    }
    Operation *blockToMoveTerminator = blockToMove->getTerminator();
    // Promote the "blockToMove" block to the parent operation block between the
    // prologue and epilogue of "op".
    rewriter.inlineBlockBefore(source: blockToMove, op);
    // Replace the "op" operation with the operands of the
    // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
    // the affine.yield operation present in the "blockToMove" block. It has no
    // operands when affine.if is not returning results and therefore, in that
    // case, replaceOp just erases "op". When affine.if is not returning
    // results, the affine.yield operation can be omitted. It gets inserted
    // implicitly.
    rewriter.replaceOp(op, newValues: blockToMoveTerminator->getOperands());
    // Erase the "blockToMoveTerminator" operation since it is now in the parent
    // operation block, which already has its own terminator.
    rewriter.eraseOp(op: blockToMoveTerminator);
    return success();
  }
};
} // namespace
2960
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // If the predecessor is an AffineIfOp, then branching into both `then` and
  // `else` region is valid.
  if (point.isParent()) {
    regions.reserve(N: 2);
    regions.push_back(
        Elt: RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
    // If the "else" region is empty, branch back into parent.
    if (getElseRegion().empty()) {
      regions.push_back(Elt: getResults());
    } else {
      regions.push_back(
          Elt: RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
    }
    return;
  }

  // If the predecessor is the `else`/`then` region, then branching into parent
  // op is valid.
  regions.push_back(Elt: RegionSuccessor(getResults()));
}
2985
/// Verifies the 'condition' integer-set attribute and checks that the op's
/// operands are valid affine dimensions/symbols matching the set's inputs.
LogicalResult AffineIfOp::verify() {
  // Verify that we have a condition attribute.
  // FIXME: This should be specified in the arguments list in ODS.
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(name: getConditionAttrStrName());
  if (!conditionAttr)
    return emitOpError(message: "requires an integer set attribute named 'condition'");

  // Verify that there are enough operands for the condition.
  IntegerSet condition = conditionAttr.getValue();
  if (getNumOperands() != condition.getNumInputs())
    return emitOpError(message: "operand count and condition integer set dimension and "
                       "symbol count must match");

  // Verify that the operands are valid dimension/symbols.
  if (failed(Result: verifyDimAndSymbolIdentifiers(op&: *this, operands: getOperands(),
                                           numDims: condition.getNumDims())))
    return failure();

  return success();
}
3007
/// Parses an affine.if operation:
///   affine.if <integer-set-attr> (dims)[symbols] (-> types)?
///     { then-region } (else { else-region })? attr-dict?
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(result&: conditionAttr,
                            attrName: AffineIfOp::getConditionAttrStrName(),
                            attrs&: result.attributes) ||
      parseDimAndSymbolList(parser, operands&: result.operands, numDims))
    return failure();

  // Verify the condition operands: operand counts must line up with the
  // parsed integer set's dim/symbol counts before building the op.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        loc: parser.getNameLoc(),
        message: "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        loc: parser.getNameLoc(),
        message: "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result&: result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(N: 2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(region&: *thenRegion, arguments: {}, enableNameShadowing: {}))
    return failure();
  // Insert the implicit affine.yield terminator if the user omitted it.
  AffineIfOp::ensureTerminator(region&: *thenRegion, builder&: parser.getBuilder(),
                               loc: result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword(keyword: "else")) {
    if (parser.parseRegion(region&: *elseRegion, arguments: {}, enableNameShadowing: {}))
      return failure();
    AffineIfOp::ensureTerminator(region&: *elseRegion, builder&: parser.getBuilder(),
                                 loc: result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();

  return success();
}
3058
/// Prints an affine.if op; the inverse of AffineIfOp::parse above.
void AffineIfOp::print(OpAsmPrinter &p) {
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(name: getConditionAttrStrName());
  p << " " << conditionAttr;
  printDimAndSymbolList(begin: operand_begin(), end: operand_end(),
                        numDims: conditionAttr.getValue().getNumDims(), printer&: p);
  p.printOptionalArrowTypeList(types: getResultTypes());
  p << ' ';
  // Only print the implicit affine.yield terminator when the op returns
  // results (i.e. when the yield carries operands).
  p.printRegion(blocks&: getThenRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());

  // Print the 'else' region if it has any blocks.
  auto &elseRegion = this->getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(blocks&: elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/getNumResults());
  }

  // Print the attribute list, eliding the condition which was printed inline.
  p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
                          /*elidedAttrs=*/getConditionAttrStrName());
}
3083
3084IntegerSet AffineIfOp::getIntegerSet() {
3085 return (*this)
3086 ->getAttrOfType<IntegerSetAttr>(name: getConditionAttrStrName())
3087 .getValue();
3088}
3089
3090void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3091 (*this)->setAttr(name: getConditionAttrStrName(), value: IntegerSetAttr::get(value: newSet));
3092}
3093
3094void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
3095 setIntegerSet(set);
3096 (*this)->setOperands(operands);
3097}
3098
/// Builds an affine.if with optional result types. Results require an else
/// region (asserted below) since both branches must yield values.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);
  OpBuilder::InsertionGuard guard(builder);

  result.addTypes(newTypes&: resultTypes);
  result.addOperands(newOperands: args);
  result.addAttribute(name: getConditionAttrStrName(), attr: IntegerSetAttr::get(value: set));

  Region *thenRegion = result.addRegion();
  builder.createBlock(parent: thenRegion);
  // Only auto-insert the terminator when no results are returned; otherwise
  // the caller must create the yield with the proper operands.
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(region&: *thenRegion, builder, loc: result.location);

  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    builder.createBlock(parent: elseRegion);
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(region&: *elseRegion, builder, loc: result.location);
  }
}
3121
3122void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3123 IntegerSet set, ValueRange args, bool withElseRegion) {
3124 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3125 withElseRegion);
3126}
3127
3128/// Compose any affine.apply ops feeding into `operands` of the integer set
3129/// `set` by composing the maps of such affine.apply ops with the integer
3130/// set constraints.
3131static void composeSetAndOperands(IntegerSet &set,
3132 SmallVectorImpl<Value> &operands,
3133 bool composeAffineMin = false) {
3134 // We will simply reuse the API of the map composition by viewing the LHSs of
3135 // the equalities and inequalities of `set` as the affine exprs of an affine
3136 // map. Convert to equivalent map, compose, and convert back to set.
3137 auto map = AffineMap::get(dimCount: set.getNumDims(), symbolCount: set.getNumSymbols(),
3138 results: set.getConstraints(), context: set.getContext());
3139 // Check if any composition is possible.
3140 if (llvm::none_of(Range&: operands,
3141 P: [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3142 return;
3143
3144 composeAffineMapAndOperands(map: &map, operands: &operands, composeAffineMin);
3145 set = IntegerSet::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), constraints: map.getResults(),
3146 eqFlags: set.getEqFlags());
3147}
3148
/// Canonicalize an affine if op's conditional (integer set + operands).
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  // First pull producing affine.apply ops into the set, then canonicalize
  // (dedupe/drop unused dims and symbols, simplify constraints).
  composeSetAndOperands(set, operands);
  canonicalizeSetAndOperands(set: &set, operands: &operands);

  // Check if the canonicalization or composition led to any change.
  if (getIntegerSet() == set && llvm::equal(LRange&: operands, RRange: getOperands()))
    return failure();

  setConditional(set, operands);
  return success();
}
3163
/// Registers patterns that remove dead else regions and fold statically
/// true/false conditions.
void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(arg&: context);
}
3168
3169//===----------------------------------------------------------------------===//
3170// AffineLoadOp
3171//===----------------------------------------------------------------------===//
3172
/// Builds an affine.load from a combined operand list where operands[0] is
/// the memref and the rest are the map operands.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(newOperands: operands);
  if (map)
    result.addAttribute(name: getMapAttrStrName(), attr: AffineMapAttr::get(value: map));
  // The result type is the memref's element type.
  auto memrefType = llvm::cast<MemRefType>(Val: operands[0].getType());
  result.types.push_back(Elt: memrefType.getElementType());
}
3182
/// Builds an affine.load from a memref, an indexing map, and its operands.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(newOperands: memref);
  result.addOperands(newOperands: mapOperands);
  auto memrefType = llvm::cast<MemRefType>(Val: memref.getType());
  result.addAttribute(name: getMapAttrStrName(), attr: AffineMapAttr::get(value: map));
  result.types.push_back(Elt: memrefType.getElementType());
}
3192
3193void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3194 Value memref, ValueRange indices) {
3195 auto memrefType = llvm::cast<MemRefType>(Val: memref.getType());
3196 int64_t rank = memrefType.getRank();
3197 // Create identity map for memrefs with at least one dimension or () -> ()
3198 // for zero-dimensional memrefs.
3199 auto map =
3200 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3201 build(builder, result, memref, map, mapOperands: indices);
3202}
3203
/// Parses an affine.load:
///   affine.load %memref[affine-map-of-ssa-ids] attr-dict? : memref-type
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // All map operands resolve to 'index'; the result type is the memref's
  // element type.
  return failure(
      IsFailure: parser.parseOperand(result&: memrefInfo) ||
      parser.parseAffineMapOfSSAIds(operands&: mapOperands, map&: mapAttr,
                                    attrName: AffineLoadOp::getMapAttrStrName(),
                                    attrs&: result.attributes) ||
      parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.parseColonType(result&: type) ||
      parser.resolveOperand(operand: memrefInfo, type, result&: result.operands) ||
      parser.resolveOperands(operands&: mapOperands, type: indexTy, result&: result.operands) ||
      parser.addTypeToList(type: type.getElementType(), result&: result.types));
}
3223
/// Prints an affine.load; the inverse of AffineLoadOp::parse above.
void AffineLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, operands: getMapOperands());
  p << ']';
  // The map attribute is printed inline above, so elide it from the dict.
  p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3234
3235/// Verify common indexing invariants of affine.load, affine.store,
3236/// affine.vector_load and affine.vector_store.
3237template <typename AffineMemOpTy>
3238static LogicalResult
3239verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3240 Operation::operand_range mapOperands,
3241 MemRefType memrefType, unsigned numIndexOperands) {
3242 AffineMap map = mapAttr.getValue();
3243 if (map.getNumResults() != memrefType.getRank())
3244 return op->emitOpError("affine map num results must equal memref rank");
3245 if (map.getNumInputs() != numIndexOperands)
3246 return op->emitOpError("expects as many subscripts as affine map inputs");
3247
3248 for (auto idx : mapOperands) {
3249 if (!idx.getType().isIndex())
3250 return op->emitOpError("index to load must have 'index' type");
3251 }
3252 if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3253 return failure();
3254
3255 return success();
3256}
3257
/// Verifies the result type and the shared affine indexing invariants.
LogicalResult AffineLoadOp::verify() {
  auto memrefType = getMemRefType();
  if (getType() != memrefType.getElementType())
    return emitOpError(message: "result type must match element type of memref");

  // All operands except the memref (operand 0) are subscripts.
  if (failed(Result: verifyMemoryOpIndexing(
          op: *this, mapAttr: (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()),
          mapOperands: getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();

  return success();
}
3271
/// Registers the generic map/operand simplification pattern for affine.load.
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(arg&: context);
}
3276
/// Folds memref.cast feeding the load, and loads from constant
/// memref.global definitions when the accessed element is statically known.
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(Result: memref::foldMemRefCast(op: *this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      Val: SymbolTable::lookupSymbolIn(op: symbolTableOp, symbol: getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      llvm::dyn_cast_or_null<DenseElementsAttr>(Val: global.getConstantInitValue());
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(Val&: cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  // Convert the map's constant results to the unsigned index form expected by
  // DenseElementsAttr's indexed accessor.
  auto indices = llvm::to_vector<4>(
      Range: llvm::map_range(C: getAffineMap().getConstantResults(),
                      F: [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
3311
3312//===----------------------------------------------------------------------===//
3313// AffineStoreOp
3314//===----------------------------------------------------------------------===//
3315
/// Builds an affine.store of `valueToStore` into `memref` at the location
/// computed by `map` applied to `mapOperands`.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  // Operand order: stored value, memref, then the map operands.
  result.addOperands(newOperands: valueToStore);
  result.addOperands(newOperands: memref);
  result.addOperands(newOperands: mapOperands);
  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(value: map);
}
3325
3326// Use identity map.
3327void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3328 Value valueToStore, Value memref,
3329 ValueRange indices) {
3330 auto memrefType = llvm::cast<MemRefType>(Val: memref.getType());
3331 int64_t rank = memrefType.getRank();
3332 // Create identity map for memrefs with at least one dimension or () -> ()
3333 // for zero-dimensional memrefs.
3334 auto map =
3335 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3336 build(builder, result, valueToStore, memref, map, mapOperands: indices);
3337}
3338
/// Parses an affine.store:
///   affine.store %value, %memref[affine-map-of-ssa-ids] attr-dict?
///     : memref-type
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // The stored value resolves to the memref's element type; map operands
  // resolve to 'index'.
  return failure(IsFailure: parser.parseOperand(result&: storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(result&: memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     operands&: mapOperands, map&: mapAttr, attrName: AffineStoreOp::getMapAttrStrName(),
                     attrs&: result.attributes) ||
                 parser.parseOptionalAttrDict(result&: result.attributes) ||
                 parser.parseColonType(result&: type) ||
                 parser.resolveOperand(operand: storeValueInfo, type: type.getElementType(),
                                       result&: result.operands) ||
                 parser.resolveOperand(operand: memrefInfo, type, result&: result.operands) ||
                 parser.resolveOperands(operands&: mapOperands, type: indexTy, result&: result.operands));
}
3359
/// Prints an affine.store; the inverse of AffineStoreOp::parse above.
void AffineStoreOp::print(OpAsmPrinter &p) {
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, operands: getMapOperands());
  p << ']';
  // The map attribute is printed inline above, so elide it from the dict.
  p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3371
/// Verifies the stored value's type and the shared affine indexing
/// invariants.
LogicalResult AffineStoreOp::verify() {
  // The value to store must have the same type as memref element type.
  auto memrefType = getMemRefType();
  if (getValueToStore().getType() != memrefType.getElementType())
    return emitOpError(
        message: "value to store must have the same type as memref element type");

  // All operands except the stored value and the memref are subscripts.
  if (failed(Result: verifyMemoryOpIndexing(
          op: *this, mapAttr: (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()),
          mapOperands: getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 2)))
    return failure();

  return success();
}
3387
/// Registers the generic map/operand simplification pattern for affine.store.
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(arg&: context);
}
3392
/// Folds a memref.cast feeding the stored-into memref operand; `results`
/// stays empty since affine.store produces no values.
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return memref::foldMemRefCast(op: *this, inner: getValueToStore());
}
3398
3399//===----------------------------------------------------------------------===//
3400// AffineMinMaxOpBase
3401//===----------------------------------------------------------------------===//
3402
3403template <typename T>
3404static LogicalResult verifyAffineMinMaxOp(T op) {
3405 // Verify that operand count matches affine map dimension and symbol count.
3406 if (op.getNumOperands() !=
3407 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3408 return op.emitOpError(
3409 "operand count and affine map dimension and symbol count must match");
3410
3411 if (op.getMap().getNumResults() == 0)
3412 return op.emitOpError("affine map expect at least one result");
3413 return success();
3414}
3415
/// Shared printer for affine.min / affine.max:
///   <map-attr> (dim-operands) [symbol-operands]? attr-dict?
template <typename T>
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
  p << ' ' << op->getAttr(T::getMapAttrStrName());
  auto operands = op.getOperands();
  unsigned numDims = op.getMap().getNumDims();
  p << '(' << operands.take_front(numDims) << ')';

  // Print the symbol operand list only when it is non-empty.
  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
  p.printOptionalAttrDict(attrs: op->getAttrs(),
                          /*elidedAttrs=*/{T::getMapAttrStrName()});
}
3428
/// Shared parser for affine.min / affine.max; the inverse of
/// printAffineMinMaxOp above. All operands and the single result are 'index'.
template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
  AffineMapAttr mapAttr;
  return failure(
      parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
                            result.attributes) ||
      parser.parseOperandList(result&: dimInfos, delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.parseOperandList(result&: symInfos,
                              delimiter: OpAsmParser::Delimiter::OptionalSquare) ||
      parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.resolveOperands(operands&: dimInfos, type: indexType, result&: result.operands) ||
      parser.resolveOperands(operands&: symInfos, type: indexType, result&: result.operands) ||
      parser.addTypeToList(type: indexType, result&: result.types));
}
3448
/// Fold an affine min or max operation with the given operands. The operand
/// list may contain nulls, which are interpreted as the operand not being a
/// constant.
template <typename T>
static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  // Fold the affine map.
  // TODO: Fold more cases:
  // min(some_affine, some_affine + constant, ...), etc.
  SmallVector<int64_t, 2> results;
  auto foldedMap = op.getMap().partialConstantFold(operands, &results);

  // If the map folded to the single-symbol identity, the op just forwards its
  // sole operand.
  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
    return op.getOperand(0);

  // If some of the map results are not constant, try changing the map in-place.
  if (results.empty()) {
    // If the map is the same, report that folding did not happen.
    if (foldedMap == op.getMap())
      return {};
    op->setAttr("map", AffineMapAttr::get(value: foldedMap));
    return op.getResult();
  }

  // Otherwise, completely fold the op into a constant: pick the min or max of
  // the constant results depending on the op kind.
  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? llvm::min_element(Range&: results)
                      : llvm::max_element(Range&: results);
  if (resultIt == results.end())
    return {};
  return IntegerAttr::get(IndexType::get(context: op.getContext()), *resultIt);
}
3483
/// Remove duplicated expressions in affine min/max ops.
template <typename T>
struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();

    // Keep the first occurrence of each expression, preserving order.
    SmallVector<AffineExpr, 4> newExprs;
    for (AffineExpr expr : oldMap.getResults()) {
      // This is a linear scan over newExprs, but it should be fine given that
      // we typically just have a few expressions per op.
      if (!llvm::is_contained(Range&: newExprs, Element: expr))
        newExprs.push_back(Elt: expr);
    }

    // No duplicates found; nothing to rewrite.
    if (newExprs.size() == oldMap.getNumResults())
      return failure();

    auto newMap = AffineMap::get(dimCount: oldMap.getNumDims(), symbolCount: oldMap.getNumSymbols(),
                                 results: newExprs, context: rewriter.getContext());
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());

    return success();
  }
};
3511
/// Merge an affine min/max op to its consumers if its consumer is also an
/// affine min/max op.
///
/// This pattern requires the producer affine min/max op is bound to a
/// dimension/symbol that is used as a standalone expression in the consumer
/// affine op's map.
///
/// For example, a pattern like the following:
///
///   %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
///   %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
///
/// Can be turned into:
///
///   %1 = affine.min affine_map<
///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
template <typename T>
struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();
    // Map operands are laid out as dims followed by symbols.
    ValueRange dimOperands =
        affineOp.getMapOperands().take_front(oldMap.getNumDims());
    ValueRange symOperands =
        affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

    auto newDimOperands = llvm::to_vector<8>(Range&: dimOperands);
    auto newSymOperands = llvm::to_vector<8>(Range&: symOperands);
    SmallVector<AffineExpr, 4> newExprs;
    SmallVector<T, 4> producerOps;

    // Go over each expression to see whether it's a single dimension/symbol
    // with the corresponding operand which is the result of another affine
    // min/max op. If so, it can be merged into this affine op.
    for (AffineExpr expr : oldMap.getResults()) {
      if (auto symExpr = dyn_cast<AffineSymbolExpr>(Val&: expr)) {
        Value symValue = symOperands[symExpr.getPosition()];
        if (auto producerOp = symValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
        Value dimValue = dimOperands[dimExpr.getPosition()];
        if (auto producerOp = dimValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      }
      // For the above cases we will remove the expression by merging the
      // producer affine min/max's affine expressions. Otherwise we need to
      // keep the existing expression.
      newExprs.push_back(Elt: expr);
    }

    if (producerOps.empty())
      return failure();

    unsigned numUsedDims = oldMap.getNumDims();
    unsigned numUsedSyms = oldMap.getNumSymbols();

    // Now go over all producer affine ops and merge their expressions.
    for (T producerOp : producerOps) {
      AffineMap producerMap = producerOp.getAffineMap();
      unsigned numProducerDims = producerMap.getNumDims();
      unsigned numProducerSyms = producerMap.getNumSymbols();

      // Collect all dimension/symbol values.
      ValueRange dimValues =
          producerOp.getMapOperands().take_front(numProducerDims);
      ValueRange symValues =
          producerOp.getMapOperands().take_back(numProducerSyms);
      newDimOperands.append(in_start: dimValues.begin(), in_end: dimValues.end());
      newSymOperands.append(in_start: symValues.begin(), in_end: symValues.end());

      // For expressions we need to shift to avoid overlap: the producer's
      // dims/symbols are appended after those already accumulated.
      for (AffineExpr expr : producerMap.getResults()) {
        newExprs.push_back(Elt: expr.shiftDims(numDims: numProducerDims, shift: numUsedDims)
                               .shiftSymbols(numSymbols: numProducerSyms, shift: numUsedSyms));
      }

      numUsedDims += numProducerDims;
      numUsedSyms += numProducerSyms;
    }

    auto newMap = AffineMap::get(dimCount: numUsedDims, symbolCount: numUsedSyms, results: newExprs,
                                 context: rewriter.getContext());
    auto newOperands =
        llvm::to_vector<8>(Range: llvm::concat<Value>(Ranges&: newDimOperands, Ranges&: newSymOperands));
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);

    return success();
  }
};
3607
/// Canonicalize the result expression order of an affine map and return success
/// if the order changed.
///
/// The function flattens the map's affine expressions to coefficient arrays and
/// sorts them in lexicographic order. A coefficient array contains a multiplier
/// for every dimension/symbol and a constant term. The canonicalization fails
/// if a result expression is not pure or if the flattening requires local
/// variables that, unlike dimensions and symbols, have no global order.
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
  SmallVector<SmallVector<int64_t>> flattenedExprs;
  for (const AffineExpr &resultExpr : map.getResults()) {
    // Fail if the expression is not pure.
    if (!resultExpr.isPureAffine())
      return failure();

    SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
    auto flattenResult = flattener.walkPostOrder(expr: resultExpr);
    if (failed(Result: flattenResult))
      return failure();

    // Fail if the flattened expression has local variables: a pure flattening
    // yields exactly one coefficient per dim/symbol plus a constant term.
    if (flattener.operandExprStack.back().size() !=
        map.getNumDims() + map.getNumSymbols() + 1)
      return failure();

    flattenedExprs.emplace_back(Args: flattener.operandExprStack.back().begin(),
                                Args: flattener.operandExprStack.back().end());
  }

  // Fail if sorting is not necessary.
  if (llvm::is_sorted(Range&: flattenedExprs))
    return failure();

  // Reorder the result expressions according to their flattened form.
  SmallVector<unsigned> resultPermutation =
      llvm::to_vector(Range: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults()));
  llvm::sort(C&: resultPermutation, Comp: [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];
  });
  SmallVector<AffineExpr> newExprs;
  for (unsigned idx : resultPermutation)
    newExprs.push_back(Elt: map.getResult(idx));

  map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newExprs,
                       context: map.getContext());
  return success();
}
3655
/// Canonicalize the affine map result expression order of an affine min/max
/// operation.
///
/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
/// expressions and replaces the operation if the order changed.
///
/// For example, the following operation:
///
///   %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
///
/// Turns into:
///
///   %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
template <typename T>
struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    // `map` is updated in place by the helper; failure means the order was
    // already canonical (or could not be canonicalized).
    AffineMap map = affineOp.getAffineMap();
    if (failed(Result: canonicalizeMapExprAndTermOrder(map)))
      return failure();
    rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
    return success();
  }
};
3682
/// Rewrite an affine min/max op whose map has exactly one result into an
/// equivalent affine.apply, since min/max of a single expression is just that
/// expression.
template <typename T>
struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    if (affineOp.getMap().getNumResults() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
                                               affineOp.getOperands());
    return success();
  }
};
3696
3697//===----------------------------------------------------------------------===//
3698// AffineMinOp
3699//===----------------------------------------------------------------------===//
3700//
3701// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3702//
3703
/// Folds affine.min via the shared min/max folder.
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(op: *this, operands: adaptor.getOperands());
}
3707
/// Registers the shared min/max canonicalization patterns for affine.min.
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
               MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
      arg&: context);
}
3716
/// Verifies affine.min via the shared min/max verifier.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(op: *this); }
3718
/// Parses affine.min via the shared min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3722
/// Prints affine.min via the shared min/max printer.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, op: *this); }
3724
3725//===----------------------------------------------------------------------===//
3726// AffineMaxOp
3727//===----------------------------------------------------------------------===//
3728//
3729// %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3730//
3731
/// Folds affine.max via the shared min/max folder.
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(op: *this, operands: adaptor.getOperands());
}
3735
/// Registers the shared min/max canonicalization patterns for affine.max.
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
               DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
               MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
      arg&: context);
}
3744
// Verification is shared with affine.min via verifyAffineMinMaxOp.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(op: *this); }
3746
// Parsing is shared with affine.min via the templated parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3750
// Printing is shared with affine.min via printAffineMinMaxOp.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, op: *this); }
3752
3753//===----------------------------------------------------------------------===//
3754// AffinePrefetchOp
3755//===----------------------------------------------------------------------===//
3756
3757//
3758// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3759//
/// Parses an affine.prefetch op of the form shown in the example comment
/// above: memref subscripted by an affine map of SSA ids, followed by the
/// read/write specifier, the locality hint, and the cache-type keyword.
ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  IntegerAttr hintInfo;
  // The locality hint is stored as an i32 attribute.
  auto i32Type = parser.getBuilder().getIntegerType(width: 32);
  StringRef readOrWrite, cacheType;

  // Parse the memref, its affine subscript map, the rw/locality/cache-type
  // clauses, and finally resolve the memref and map operands.
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  if (parser.parseOperand(result&: memrefInfo) ||
      parser.parseAffineMapOfSSAIds(operands&: mapOperands, map&: mapAttr,
                                    attrName: AffinePrefetchOp::getMapAttrStrName(),
                                    attrs&: result.attributes) ||
      parser.parseComma() || parser.parseKeyword(keyword: &readOrWrite) ||
      parser.parseComma() || parser.parseKeyword(keyword: "locality") ||
      parser.parseLess() ||
      parser.parseAttribute(result&: hintInfo, type: i32Type,
                            attrName: AffinePrefetchOp::getLocalityHintAttrStrName(),
                            attrs&: result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(keyword: &cacheType) ||
      parser.parseOptionalAttrDict(result&: result.attributes) ||
      parser.parseColonType(result&: type) ||
      parser.resolveOperand(operand: memrefInfo, type, result&: result.operands) ||
      parser.resolveOperands(operands&: mapOperands, type: indexTy, result&: result.operands))
    return failure();

  // The rw specifier is stored as a boolean "isWrite" attribute.
  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "rw specifier has to be 'read' or 'write'");
  result.addAttribute(name: AffinePrefetchOp::getIsWriteAttrStrName(),
                      attr: parser.getBuilder().getBoolAttr(value: readOrWrite == "write"));

  // The cache type is stored as a boolean "isDataCache" attribute.
  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(loc: parser.getNameLoc(),
                            message: "cache type has to be 'data' or 'instr'");

  result.addAttribute(name: AffinePrefetchOp::getIsDataCacheAttrStrName(),
                      attr: parser.getBuilder().getBoolAttr(value: cacheType == "data"));

  return success();
}
3806
/// Prints an affine.prefetch op in the custom form accepted by the parser
/// above; the map/hint/cache/rw attributes are elided from the trailing
/// attribute dictionary since they are printed inline.
void AffinePrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  AffineMapAttr mapAttr =
      (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName());
  if (mapAttr)
    p.printAffineMapOfSSAIds(mapAttr, operands: getMapOperands());
  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
    << "locality<" << getLocalityHint() << ">, "
    << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      attrs: (*this)->getAttrs(),
      /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
                       getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
  p << " : " << getMemRefType();
}
3822
/// Verifies that the subscript map matches the memref rank, that the operand
/// count is consistent with the map, and that all map operands are valid
/// affine dimension/symbol identifiers in the enclosing affine scope.
LogicalResult AffinePrefetchOp::verify() {
  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName());
  if (mapAttr) {
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != getMemRefType().getRank())
      return emitOpError(message: "affine.prefetch affine map num results must equal"
                  " memref rank");
    // +1 accounts for the memref operand itself.
    if (map.getNumInputs() + 1 != getNumOperands())
      return emitOpError(message: "too few operands");
  } else {
    // Without a map, the memref is the only operand.
    if (getNumOperands() != 1)
      return emitOpError(message: "too few operands");
  }

  Region *scope = getAffineScope(op: *this);
  for (auto idx : getMapOperands()) {
    if (!isValidAffineIndexOperand(value: idx, region: scope))
      return emitOpError(
          message: "index must be a valid dimension or symbol identifier");
  }
  return success();
}
3845
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(arg&: context);
}
3851
/// Folding hook: folds away memref.cast producers of the memref operand.
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return memref::foldMemRefCast(op: *this);
}
3857
3858//===----------------------------------------------------------------------===//
3859// AffineParallelOp
3860//===----------------------------------------------------------------------===//
3861
/// Builds an affine.parallel op iterating from 0 to each value in `ranges`
/// with unit steps, by materializing constant lower/upper bound maps and
/// delegating to the map-based builder below.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<int64_t> ranges) {
  // All lower bounds are the constant-0 map.
  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(val: 0));
  // Each upper bound is a constant map holding the corresponding range value.
  auto ubs = llvm::to_vector<4>(Range: llvm::map_range(C&: ranges, F: [&](int64_t value) {
    return builder.getConstantAffineMap(val: value);
  }));
  SmallVector<int64_t> steps(ranges.size(), 1);
  build(odsBuilder&: builder, odsState&: result, resultTypes, reductions, lbMaps: lbs, /*lbArgs=*/{}, ubMaps: ubs,
        /*ubArgs=*/{}, steps);
}
3874
/// Builds an affine.parallel op from per-dimension lower/upper bound maps and
/// their shared operand lists, reduction kinds, and constant steps. The bound
/// maps are concatenated into single flat maps plus "groups" attributes that
/// record how many results belong to each dimension. Also creates the body
/// region with one index block argument per dimension and, when there are no
/// results, the implicit terminator.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(newTypes&: resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        Elt: builder.getI64IntegerAttr(value: static_cast<int64_t>(reduction)));
  result.addAttribute(name: getReductionsAttrStrName(),
                      attr: builder.getArrayAttr(value: reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    if (maps.empty())
      return AffineMap::get(context: builder.getContext());
    SmallVector<AffineExpr> exprs;
    groups.reserve(N: groups.size() + maps.size());
    exprs.reserve(N: maps.size());
    for (AffineMap m : maps) {
      llvm::append_range(C&: exprs, R: m.getResults());
      // Record how many results this per-dimension map contributed.
      groups.push_back(Elt: m.getNumResults());
    }
    return AffineMap::get(dimCount: maps[0].getNumDims(), symbolCount: maps[0].getNumSymbols(), results: exprs,
                          context: maps[0].getContext());
  };

  // Set up the bounds.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(name: getLowerBoundsMapAttrStrName(),
                      attr: AffineMapAttr::get(value: lbMap));
  result.addAttribute(name: getLowerBoundsGroupsAttrStrName(),
                      attr: builder.getI32TensorAttr(values: lbGroups));
  result.addAttribute(name: getUpperBoundsMapAttrStrName(),
                      attr: AffineMapAttr::get(value: ubMap));
  result.addAttribute(name: getUpperBoundsGroupsAttrStrName(),
                      attr: builder.getI32TensorAttr(values: ubGroups));
  result.addAttribute(name: getStepsAttrStrName(), attr: builder.getI64ArrayAttr(values: steps));
  // Operand order: all lower bound operands first, then all upper bound
  // operands (see getLowerBoundsOperands/getUpperBoundsOperands).
  result.addOperands(newOperands: lbArgs);
  result.addOperands(newOperands: ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  Block *body = builder.createBlock(parent: bodyRegion);

  // Add all the block arguments.
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(type: IndexType::get(context: builder.getContext()), loc: result.location);
  if (resultTypes.empty())
    ensureTerminator(region&: *bodyRegion, builder, loc: result.location);
}
3956
// The single body region is the only loop region of this op.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}
3960
// The number of parallel dimensions equals the number of steps.
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3962
// Lower bound operands come first in the operand list (see build()).
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(n: getLowerBoundsMap().getNumInputs());
}
3966
// Upper bound operands follow the lower bound operands (see build()).
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(n: getLowerBoundsMap().getNumInputs());
}
3970
3971AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3972 auto values = getLowerBoundsGroups().getValues<int32_t>();
3973 unsigned start = 0;
3974 for (unsigned i = 0; i < pos; ++i)
3975 start += values[i];
3976 return getLowerBoundsMap().getSliceMap(start, length: values[pos]);
3977}
3978
3979AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3980 auto values = getUpperBoundsGroups().getValues<int32_t>();
3981 unsigned start = 0;
3982 for (unsigned i = 0; i < pos; ++i)
3983 start += values[i];
3984 return getUpperBoundsMap().getSliceMap(start, length: values[pos]);
3985}
3986
// Pairs the combined lower-bounds map with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
3990
// Pairs the combined upper-bounds map with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
3994
/// Returns the trip count of each dimension as a constant (ub - lb), or
/// std::nullopt if any bound uses min/max groups or any range is not a
/// compile-time constant.
std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  // min/max bounds make the per-dimension range ill-defined as a single
  // expression.
  if (hasMinMaxBounds())
    return std::nullopt;

  // Try to convert all the ranges to constant expressions.
  SmallVector<int64_t, 8> out;
  AffineValueMap rangesValueMap;
  AffineValueMap::difference(a: getUpperBoundsValueMap(), b: getLowerBoundsValueMap(),
                             res: &rangesValueMap);
  out.reserve(N: rangesValueMap.getNumResults());
  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(Val&: expr);
    if (!cst)
      return std::nullopt;
    out.push_back(Elt: cst.getValue());
  }
  return out;
}
4014
// The body is the single block of the op's region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4016
// Returns a builder positioned just before the body's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(x: getBody()->end()));
}
4020
/// Replaces the lower-bounds map and its operands; the upper-bound operands
/// are preserved by rebuilding the full operand list (lb operands first).
void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");

  auto ubOperands = getUpperBoundsOperands();

  SmallVector<Value, 4> newOperands(lbOperands);
  newOperands.append(in_start: ubOperands.begin(), in_end: ubOperands.end());
  (*this)->setOperands(newOperands);

  setLowerBoundsMapAttr(AffineMapAttr::get(value: map));
}
4033
/// Replaces the upper-bounds map and its operands; the lower-bound operands
/// are preserved by rebuilding the full operand list (lb operands first).
void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");

  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
  newOperands.append(in_start: ubOperands.begin(), in_end: ubOperands.end());
  (*this)->setOperands(newOperands);

  setUpperBoundsMapAttr(AffineMapAttr::get(value: map));
}
4044
// Replaces the constant steps attribute with `newSteps`.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(values: newSteps));
}
4048
4049// check whether resultType match op or not in affine.parallel
4050static bool isResultTypeMatchAtomicRMWKind(Type resultType,
4051 arith::AtomicRMWKind op) {
4052 switch (op) {
4053 case arith::AtomicRMWKind::addf:
4054 return isa<FloatType>(Val: resultType);
4055 case arith::AtomicRMWKind::addi:
4056 return isa<IntegerType>(Val: resultType);
4057 case arith::AtomicRMWKind::assign:
4058 return true;
4059 case arith::AtomicRMWKind::mulf:
4060 return isa<FloatType>(Val: resultType);
4061 case arith::AtomicRMWKind::muli:
4062 return isa<IntegerType>(Val: resultType);
4063 case arith::AtomicRMWKind::maximumf:
4064 return isa<FloatType>(Val: resultType);
4065 case arith::AtomicRMWKind::minimumf:
4066 return isa<FloatType>(Val: resultType);
4067 case arith::AtomicRMWKind::maxs: {
4068 auto intType = llvm::dyn_cast<IntegerType>(Val&: resultType);
4069 return intType && intType.isSigned();
4070 }
4071 case arith::AtomicRMWKind::mins: {
4072 auto intType = llvm::dyn_cast<IntegerType>(Val&: resultType);
4073 return intType && intType.isSigned();
4074 }
4075 case arith::AtomicRMWKind::maxu: {
4076 auto intType = llvm::dyn_cast<IntegerType>(Val&: resultType);
4077 return intType && intType.isUnsigned();
4078 }
4079 case arith::AtomicRMWKind::minu: {
4080 auto intType = llvm::dyn_cast<IntegerType>(Val&: resultType);
4081 return intType && intType.isUnsigned();
4082 }
4083 case arith::AtomicRMWKind::ori:
4084 return isa<IntegerType>(Val: resultType);
4085 case arith::AtomicRMWKind::andi:
4086 return isa<IntegerType>(Val: resultType);
4087 default:
4088 return false;
4089 }
4090}
4091
/// Verifies structural consistency of an affine.parallel op: dimension counts
/// across groups/steps/region arguments, group sizes vs. flat bound maps, one
/// reduction per result with a matching result type, and validity of bound
/// operands as affine dims/symbols.
LogicalResult AffineParallelOp::verify() {
  auto numDims = getNumDims();
  // Steps, both groups attributes, and region arguments must all agree on the
  // number of parallel dimensions.
  if (getLowerBoundsGroups().getNumElements() != numDims ||
      getUpperBoundsGroups().getNumElements() != numDims ||
      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
    return emitOpError() << "the number of region arguments ("
                         << getBody()->getNumArguments()
                         << ") and the number of map groups for lower ("
                         << getLowerBoundsGroups().getNumElements()
                         << ") and upper bound ("
                         << getUpperBoundsGroups().getNumElements()
                         << "), and the number of steps (" << getSteps().size()
                         << ") must all match";
  }

  // The flat lower-bounds map must have exactly as many results as the sum of
  // the group sizes, each group being non-empty.
  unsigned expectedNumLBResults = 0;
  for (APInt v : getLowerBoundsGroups()) {
    unsigned results = v.getZExtValue();
    if (results == 0)
      return emitOpError()
             << "expected lower bound map to have at least one result";
    expectedNumLBResults += results;
  }
  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
    return emitOpError() << "expected lower bounds map to have "
                         << expectedNumLBResults << " results";
  // Same check for the upper bounds.
  unsigned expectedNumUBResults = 0;
  for (APInt v : getUpperBoundsGroups()) {
    unsigned results = v.getZExtValue();
    if (results == 0)
      return emitOpError()
             << "expected upper bound map to have at least one result";
    expectedNumUBResults += results;
  }
  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
    return emitOpError() << "expected upper bounds map to have "
                         << expectedNumUBResults << " results";

  if (getReductions().size() != getNumResults())
    return emitOpError(message: "a reduction must be specified for each output");

  // Verify reduction ops are all valid and each result type matches reduction
  // ops
  for (auto it : llvm::enumerate(First: (getReductions()))) {
    Attribute attr = it.value();
    auto intAttr = llvm::dyn_cast<IntegerAttr>(Val&: attr);
    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
      return emitOpError(message: "invalid reduction attribute");
    auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
    if (!isResultTypeMatchAtomicRMWKind(resultType: getResult(i: it.index()).getType(), op: kind))
      return emitOpError(message: "result type cannot match reduction attribute");
  }

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(Result: verifyDimAndSymbolIdentifiers(op&: *this, operands: getLowerBoundsOperands(),
                                           numDims: getLowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
  if (failed(Result: verifyDimAndSymbolIdentifiers(op&: *this, operands: getUpperBoundsOperands(),
                                           numDims: getUpperBoundsMap().getNumDims())))
    return failure();
  return success();
}
4156
/// Composes/simplifies this value map in place. Returns success if the map or
/// operands changed, failure if it was already in canonical form.
LogicalResult AffineValueMap::canonicalize() {
  SmallVector<Value, 4> newOperands{operands};
  auto newMap = getAffineMap();
  composeAffineMapAndOperands(map: &newMap, operands: &newOperands);
  // Nothing changed: report failure so callers can skip the rewrite.
  if (newMap == getAffineMap() && newOperands == operands)
    return failure();
  reset(map: newMap, operands: newOperands);
  return success();
}
4166
4167/// Canonicalize the bounds of the given loop.
4168static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4169 AffineValueMap lb = op.getLowerBoundsValueMap();
4170 bool lbCanonicalized = succeeded(Result: lb.canonicalize());
4171
4172 AffineValueMap ub = op.getUpperBoundsValueMap();
4173 bool ubCanonicalized = succeeded(Result: ub.canonicalize());
4174
4175 // Any canonicalization change always leads to updated map(s).
4176 if (!lbCanonicalized && !ubCanonicalized)
4177 return failure();
4178
4179 if (lbCanonicalized)
4180 op.setLowerBounds(lbOperands: lb.getOperands(), map: lb.getAffineMap());
4181 if (ubCanonicalized)
4182 op.setUpperBounds(ubOperands: ub.getOperands(), map: ub.getAffineMap());
4183
4184 return success();
4185}
4186
/// Folding hook: in-place canonicalization of the loop bound maps/operands.
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return canonicalizeLoopBounds(op: *this);
}
4191
/// Prints a lower(upper) bound of an affine parallel loop with max(min)
/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
/// identifies which of those expressions form max/min groups. `operands`
/// are the SSA values of dimensions and symbols and `keyword` is either "min"
/// or "max".
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
                             DenseIntElementsAttr group, ValueRange operands,
                             StringRef keyword) {
  AffineMap map = mapAttr.getValue();
  unsigned numDims = map.getNumDims();
  ValueRange dimOperands = operands.take_front(n: numDims);
  ValueRange symOperands = operands.drop_front(n: numDims);
  unsigned start = 0;
  for (llvm::APInt groupSize : group) {
    // Comma-separate successive groups.
    if (start != 0)
      p << ", ";

    unsigned size = groupSize.getZExtValue();
    if (size == 1) {
      // Singleton groups are printed as a bare expression, no min/max wrapper.
      p.printAffineExprOfSSAIds(expr: map.getResult(idx: start), dimOperands, symOperands);
      ++start;
    } else {
      // Multi-expression groups are printed as keyword(expr, expr, ...).
      p << keyword << '(';
      AffineMap submap = map.getSliceMap(start, length: size);
      p.printAffineMapOfSSAIds(mapAttr: AffineMapAttr::get(value: submap), operands);
      p << ')';
      start += size;
    }
  }
}
4222
/// Prints an affine.parallel op in its custom form, eliding unit steps, the
/// reduce clause when there are no results, and the inline-printed attributes
/// from the trailing attribute dictionary.
void AffineParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getBody()->getArguments() << ") = (";
  printMinMaxBound(p, mapAttr: getLowerBoundsMapAttr(), group: getLowerBoundsGroupsAttr(),
                   operands: getLowerBoundsOperands(), keyword: "max");
  p << ") to (";
  printMinMaxBound(p, mapAttr: getUpperBoundsMapAttr(), group: getUpperBoundsGroupsAttr(),
                   operands: getUpperBoundsOperands(), keyword: "min");
  p << ')';
  SmallVector<int64_t, 8> steps = getSteps();
  // All-ones steps match the parser's default and are not printed.
  bool elideSteps = llvm::all_of(Range&: steps, P: [](int64_t step) { return step == 1; });
  if (!elideSteps) {
    p << " step (";
    llvm::interleaveComma(c: steps, os&: p);
    p << ')';
  }
  if (getNumResults()) {
    // Print reductions as their quoted enum names, e.g. reduce ("addf").
    p << " reduce (";
    llvm::interleaveComma(c: getReductions(), os&: p, each_fn: [&](auto &attr) {
      arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
          llvm::cast<IntegerAttr>(attr).getInt());
      p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
    });
    p << ") -> (" << getResultTypes() << ")";
  }

  p << ' ';
  p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());
  p.printOptionalAttrDict(
      attrs: (*this)->getAttrs(),
      /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
                       AffineParallelOp::getLowerBoundsMapAttrStrName(),
                       AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
                       AffineParallelOp::getUpperBoundsMapAttrStrName(),
                       AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
                       AffineParallelOp::getStepsAttrStrName()});
}
4260
/// Given a list of lists of parsed operands, populates `uniqueOperands` with
/// unique operands. Also populates `replacements` with affine expressions of
/// `kind` that can be used to update affine maps previously accepting a
/// `operands` to accept `uniqueOperands` instead.
static ParseResult deduplicateAndResolveOperands(
    OpAsmParser &parser,
    ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
    SmallVectorImpl<Value> &uniqueOperands,
    SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
         "expected operands to be dim or symbol expression");

  Type indexType = parser.getBuilder().getIndexType();
  for (const auto &list : operands) {
    SmallVector<Value> valueOperands;
    if (parser.resolveOperands(operands: list, type: indexType, result&: valueOperands))
      return failure();
    for (Value operand : valueOperands) {
      // Linear search keeps first-seen order; `pos` equals the size when the
      // operand has not been seen before.
      unsigned pos = std::distance(first: uniqueOperands.begin(),
                                   last: llvm::find(Range&: uniqueOperands, Val: operand));
      if (pos == uniqueOperands.size())
        uniqueOperands.push_back(Elt: operand);
      // Emit a dim/symbol expression pointing at the deduplicated position.
      replacements.push_back(
          Elt: kind == AffineExprKind::DimId
              ? getAffineDimExpr(position: pos, context: parser.getContext())
              : getAffineSymbolExpr(position: pos, context: parser.getContext()));
    }
  }
  return success();
}
4291
4292namespace {
4293enum class MinMaxKind { Min, Max };
4294} // namespace
4295
4296/// Parses an affine map that can contain a min/max for groups of its results,
4297/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4298/// `result` attributes with the map (flat list of expressions) and the grouping
4299/// (list of integers that specify how many expressions to put into each
4300/// min/max) attributes. Deduplicates repeated operands.
4301///
4302/// parallel-bound ::= `(` parallel-group-list `)`
4303/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4304/// parallel-group ::= simple-group | min-max-group
4305/// simple-group ::= expr-of-ssa-ids
4306/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4307/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4308///
4309/// Examples:
4310/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4311/// (%0, max(%1 - 2 * %2))
/// Parses one parallel-bound clause (see grammar in the comment above) into
/// the appropriate bounds-map and groups attributes on `result`. `kind`
/// selects whether `min` (upper bounds) or `max` (lower bounds) groups are
/// expected, which also selects the attribute names to populate.
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
                                            OperationState &result,
                                            MinMaxKind kind) {
  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
  // see: https://reviews.llvm.org/D134227#3821753
  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";

  StringRef mapName = kind == MinMaxKind::Min
                          ? AffineParallelOp::getUpperBoundsMapAttrStrName()
                          : AffineParallelOp::getLowerBoundsMapAttrStrName();
  StringRef groupsName =
      kind == MinMaxKind::Min
          ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
          : AffineParallelOp::getLowerBoundsGroupsAttrStrName();

  if (failed(Result: parser.parseLParen()))
    return failure();

  // Empty bound list `()`: zero-result map and empty groups.
  if (succeeded(Result: parser.parseOptionalRParen())) {
    result.addAttribute(
        name: mapName, attr: AffineMapAttr::get(value: parser.getBuilder().getEmptyAffineMap()));
    result.addAttribute(name: groupsName, attr: parser.getBuilder().getI32TensorAttr(values: {}));
    return success();
  }

  // Per-group parsing state: flat expressions, their dim/sym operands (one
  // list per expression), and the number of expressions in each group.
  SmallVector<AffineExpr> flatExprs;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
  SmallVector<int32_t> numMapsPerGroup;
  SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
  auto parseOperands = [&]() {
    if (succeeded(Result: parser.parseOptionalKeyword(
            keyword: kind == MinMaxKind::Min ? "min" : "max"))) {
      // min/max group: parse a whole affine map of SSA ids; the temporary
      // attribute is erased right away, only its expressions are kept.
      mapOperands.clear();
      AffineMapAttr map;
      if (failed(Result: parser.parseAffineMapOfSSAIds(operands&: mapOperands, map, attrName: tmpAttrStrName,
                                            attrs&: result.attributes,
                                            delimiter: OpAsmParser::Delimiter::Paren)))
        return failure();
      result.attributes.erase(name: tmpAttrStrName);
      llvm::append_range(C&: flatExprs, R: map.getValue().getResults());
      auto operandsRef = llvm::ArrayRef(mapOperands);
      auto dimsRef = operandsRef.take_front(N: map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
      auto symsRef = operandsRef.drop_front(N: map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
      // Replicate the operand lists once per result so every flat expression
      // has its own dim/sym operand list.
      flatDimOperands.append(NumInputs: map.getValue().getNumResults(), Elt: dims);
      flatSymOperands.append(NumInputs: map.getValue().getNumResults(), Elt: syms);
      numMapsPerGroup.push_back(Elt: map.getValue().getNumResults());
    } else {
      // Simple group: a single expression of SSA ids.
      if (failed(Result: parser.parseAffineExprOfSSAIds(dimOperands&: flatDimOperands.emplace_back(),
                                             symbOperands&: flatSymOperands.emplace_back(),
                                             expr&: flatExprs.emplace_back())))
        return failure();
      numMapsPerGroup.push_back(Elt: 1);
    }
    return success();
  };
  if (parser.parseCommaSeparatedList(parseElementFn: parseOperands) || parser.parseRParen())
    return failure();

  // Shift each expression's dims/symbols into a single combined input space.
  unsigned totalNumDims = 0;
  unsigned totalNumSyms = 0;
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
    unsigned numDims = flatDimOperands[i].size();
    unsigned numSyms = flatSymOperands[i].size();
    flatExprs[i] = flatExprs[i]
                       .shiftDims(numDims, shift: totalNumDims)
                       .shiftSymbols(numSymbols: numSyms, shift: totalNumSyms)ʼ;
    totalNumDims += numDims;
    totalNumSyms += numSyms;
  }

  // Deduplicate map operands.
  SmallVector<Value> dimOperands, symOperands;
  SmallVector<AffineExpr> dimRplacements, symRepacements;
  if (deduplicateAndResolveOperands(parser, operands: flatDimOperands, uniqueOperands&: dimOperands,
                                    replacements&: dimRplacements, kind: AffineExprKind::DimId) ||
      deduplicateAndResolveOperands(parser, operands: flatSymOperands, uniqueOperands&: symOperands,
                                    replacements&: symRepacements, kind: AffineExprKind::SymbolId))
    return failure();

  // Operand order matches the attribute encoding: dims first, then symbols.
  result.operands.append(in_start: dimOperands.begin(), in_end: dimOperands.end());
  result.operands.append(in_start: symOperands.begin(), in_end: symOperands.end());

  Builder &builder = parser.getBuilder();
  auto flatMap = AffineMap::get(dimCount: totalNumDims, symbolCount: totalNumSyms, results: flatExprs,
                                context: parser.getContext());
  // Remap the shifted dims/symbols onto the deduplicated operand positions.
  flatMap = flatMap.replaceDimsAndSymbols(
      dimReplacements: dimRplacements, symReplacements: symRepacements, numResultDims: dimOperands.size(), numResultSyms: symOperands.size());

  result.addAttribute(name: mapName, attr: AffineMapAttr::get(value: flatMap));
  result.addAttribute(name: groupsName, attr: builder.getI32TensorAttr(values: numMapsPerGroup));
  return success();
}
4407
4408//
4409// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4410// `to` parallel-bound steps? region attr-dict?
4411// steps ::= `steps` `(` integer-literals `)`
4412//
/// Parses an affine.parallel op following the grammar in the comment above:
/// induction variables, max-style lower bounds, min-style upper bounds,
/// optional constant steps, optional reduce clause with result types, and the
/// body region.
ParseResult AffineParallelOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::Argument, 4> ivs;
  // Lower bounds may contain `max` groups, upper bounds `min` groups.
  if (parser.parseArgumentList(result&: ivs, delimiter: OpAsmParser::Delimiter::Paren) ||
      parser.parseEqual() ||
      parseAffineMapWithMinMax(parser, result, kind: MinMaxKind::Max) ||
      parser.parseKeyword(keyword: "to") ||
      parseAffineMapWithMinMax(parser, result, kind: MinMaxKind::Min))
    return failure();

  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
  if (failed(Result: parser.parseOptionalKeyword(keyword: "step"))) {
    // No `step` clause: default to unit steps for every induction variable.
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(name: AffineParallelOp::getStepsAttrStrName(),
                        attr: builder.getI64ArrayAttr(values: steps));
  } else {
    // Steps are parsed as an affine map of SSA ids but must be constants.
    if (parser.parseAffineMapOfSSAIds(operands&: stepsMapOperands, map&: stepsMapAttr,
                                      attrName: AffineParallelOp::getStepsAttrStrName(),
                                      attrs&: stepsAttrs,
                                      delimiter: OpAsmParser::Delimiter::Paren))
      return failure();

    // Convert steps from an AffineMap into an I64ArrayAttr.
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(Val: result);
      if (!constExpr)
        return parser.emitError(loc: parser.getNameLoc(),
                                message: "steps must be constant integers");
      steps.push_back(Elt: constExpr.getValue());
    }
    result.addAttribute(name: AffineParallelOp::getStepsAttrStrName(),
                        attr: builder.getI64ArrayAttr(values: steps));
  }

  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
  // quoted strings are a member of the enum AtomicRMWKind.
  SmallVector<Attribute, 4> reductions;
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "reduce"))) {
    if (parser.parseLParen())
      return failure();
    auto parseAttributes = [&]() -> ParseResult {
      // Parse a single quoted string via the attribute parsing, and then
      // verify it is a member of the enum and convert to it's integer
      // representation.
      StringAttr attrVal;
      NamedAttrList attrStorage;
      auto loc = parser.getCurrentLocation();
      if (parser.parseAttribute(result&: attrVal, type: builder.getNoneType(), attrName: "reduce",
                                attrs&: attrStorage))
        return failure();
      std::optional<arith::AtomicRMWKind> reduction =
          arith::symbolizeAtomicRMWKind(attrVal.getValue());
      if (!reduction)
        return parser.emitError(loc, message: "invalid reduction value: ") << attrVal;
      reductions.push_back(
          Elt: builder.getI64IntegerAttr(value: static_cast<int64_t>(reduction.value())));
      // While we keep getting commas, keep parsing.
      return success();
    };
    if (parser.parseCommaSeparatedList(parseElementFn: parseAttributes) || parser.parseRParen())
      return failure();
  }
  // The reductions attribute is always present, possibly empty.
  result.addAttribute(name: AffineParallelOp::getReductionsAttrStrName(),
                      attr: builder.getArrayAttr(value: reductions));

  // Parse return types of reductions (if any)
  if (parser.parseOptionalArrowTypeList(result&: result.types))
    return failure();

  // Now parse the body.
  Region *body = result.addRegion();
  // All induction variables are of index type.
  for (auto &iv : ivs)
    iv.type = indexType;
  if (parser.parseRegion(region&: *body, arguments: ivs) ||
      parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();

  // Add a terminator if none was parsed.
  AffineParallelOp::ensureTerminator(region&: *body, builder, loc: result.location);
  return success();
}
4500
4501//===----------------------------------------------------------------------===//
4502// AffineYieldOp
4503//===----------------------------------------------------------------------===//
4504
4505LogicalResult AffineYieldOp::verify() {
4506 auto *parentOp = (*this)->getParentOp();
4507 auto results = parentOp->getResults();
4508 auto operands = getOperands();
4509
4510 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(Val: parentOp))
4511 return emitOpError() << "only terminates affine.if/for/parallel regions";
4512 if (parentOp->getNumResults() != getNumOperands())
4513 return emitOpError() << "parent of yield must have same number of "
4514 "results as the yield operands";
4515 for (auto it : llvm::zip(t&: results, u&: operands)) {
4516 if (std::get<0>(t&: it).getType() != std::get<1>(t&: it).getType())
4517 return emitOpError() << "types mismatch between yield op and its parent";
4518 }
4519
4520 return success();
4521}
4522
4523//===----------------------------------------------------------------------===//
4524// AffineVectorLoadOp
4525//===----------------------------------------------------------------------===//
4526
4527void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4528 VectorType resultType, AffineMap map,
4529 ValueRange operands) {
4530 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4531 result.addOperands(newOperands: operands);
4532 if (map)
4533 result.addAttribute(name: getMapAttrStrName(), attr: AffineMapAttr::get(value: map));
4534 result.types.push_back(Elt: resultType);
4535}
4536
4537void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4538 VectorType resultType, Value memref,
4539 AffineMap map, ValueRange mapOperands) {
4540 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4541 result.addOperands(newOperands: memref);
4542 result.addOperands(newOperands: mapOperands);
4543 result.addAttribute(name: getMapAttrStrName(), attr: AffineMapAttr::get(value: map));
4544 result.types.push_back(Elt: resultType);
4545}
4546
4547void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4548 VectorType resultType, Value memref,
4549 ValueRange indices) {
4550 auto memrefType = llvm::cast<MemRefType>(Val: memref.getType());
4551 int64_t rank = memrefType.getRank();
4552 // Create identity map for memrefs with at least one dimension or () -> ()
4553 // for zero-dimensional memrefs.
4554 auto map =
4555 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4556 build(builder, result, resultType, memref, map, mapOperands: indices);
4557}
4558
/// Registers the shared SimplifyAffineOp pattern for affine.vector_load,
/// which canonicalizes the op's affine map and map operands (see
/// SimplifyAffineOp elsewhere in this file).
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(arg&: context);
}
4563
4564ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4565 OperationState &result) {
4566 auto &builder = parser.getBuilder();
4567 auto indexTy = builder.getIndexType();
4568
4569 MemRefType memrefType;
4570 VectorType resultType;
4571 OpAsmParser::UnresolvedOperand memrefInfo;
4572 AffineMapAttr mapAttr;
4573 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4574 return failure(
4575 IsFailure: parser.parseOperand(result&: memrefInfo) ||
4576 parser.parseAffineMapOfSSAIds(operands&: mapOperands, map&: mapAttr,
4577 attrName: AffineVectorLoadOp::getMapAttrStrName(),
4578 attrs&: result.attributes) ||
4579 parser.parseOptionalAttrDict(result&: result.attributes) ||
4580 parser.parseColonType(result&: memrefType) || parser.parseComma() ||
4581 parser.parseType(result&: resultType) ||
4582 parser.resolveOperand(operand: memrefInfo, type: memrefType, result&: result.operands) ||
4583 parser.resolveOperands(operands&: mapOperands, type: indexTy, result&: result.operands) ||
4584 parser.addTypeToList(type: resultType, result&: result.types));
4585}
4586
4587void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4588 p << " " << getMemRef() << '[';
4589 if (AffineMapAttr mapAttr =
4590 (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()))
4591 p.printAffineMapOfSSAIds(mapAttr, operands: getMapOperands());
4592 p << ']';
4593 p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
4594 /*elidedAttrs=*/{getMapAttrStrName()});
4595 p << " : " << getMemRefType() << ", " << getType();
4596}
4597
4598/// Verify common invariants of affine.vector_load and affine.vector_store.
4599static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4600 VectorType vectorType) {
4601 // Check that memref and vector element types match.
4602 if (memrefType.getElementType() != vectorType.getElementType())
4603 return op->emitOpError(
4604 message: "requires memref and vector types of the same elemental type");
4605 return success();
4606}
4607
4608LogicalResult AffineVectorLoadOp::verify() {
4609 MemRefType memrefType = getMemRefType();
4610 if (failed(Result: verifyMemoryOpIndexing(
4611 op: *this, mapAttr: (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()),
4612 mapOperands: getMapOperands(), memrefType,
4613 /*numIndexOperands=*/getNumOperands() - 1)))
4614 return failure();
4615
4616 if (failed(Result: verifyVectorMemoryOp(op: getOperation(), memrefType, vectorType: getVectorType())))
4617 return failure();
4618
4619 return success();
4620}
4621
4622//===----------------------------------------------------------------------===//
4623// AffineVectorStoreOp
4624//===----------------------------------------------------------------------===//
4625
4626void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4627 Value valueToStore, Value memref, AffineMap map,
4628 ValueRange mapOperands) {
4629 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4630 result.addOperands(newOperands: valueToStore);
4631 result.addOperands(newOperands: memref);
4632 result.addOperands(newOperands: mapOperands);
4633 result.addAttribute(name: getMapAttrStrName(), attr: AffineMapAttr::get(value: map));
4634}
4635
4636// Use identity map.
4637void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4638 Value valueToStore, Value memref,
4639 ValueRange indices) {
4640 auto memrefType = llvm::cast<MemRefType>(Val: memref.getType());
4641 int64_t rank = memrefType.getRank();
4642 // Create identity map for memrefs with at least one dimension or () -> ()
4643 // for zero-dimensional memrefs.
4644 auto map =
4645 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4646 build(builder, result, valueToStore, memref, map, mapOperands: indices);
4647}
/// Registers the shared SimplifyAffineOp pattern for affine.vector_store,
/// which canonicalizes the op's affine map and map operands (see
/// SimplifyAffineOp elsewhere in this file).
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(arg&: context);
}
4652
4653ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4654 OperationState &result) {
4655 auto indexTy = parser.getBuilder().getIndexType();
4656
4657 MemRefType memrefType;
4658 VectorType resultType;
4659 OpAsmParser::UnresolvedOperand storeValueInfo;
4660 OpAsmParser::UnresolvedOperand memrefInfo;
4661 AffineMapAttr mapAttr;
4662 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4663 return failure(
4664 IsFailure: parser.parseOperand(result&: storeValueInfo) || parser.parseComma() ||
4665 parser.parseOperand(result&: memrefInfo) ||
4666 parser.parseAffineMapOfSSAIds(operands&: mapOperands, map&: mapAttr,
4667 attrName: AffineVectorStoreOp::getMapAttrStrName(),
4668 attrs&: result.attributes) ||
4669 parser.parseOptionalAttrDict(result&: result.attributes) ||
4670 parser.parseColonType(result&: memrefType) || parser.parseComma() ||
4671 parser.parseType(result&: resultType) ||
4672 parser.resolveOperand(operand: storeValueInfo, type: resultType, result&: result.operands) ||
4673 parser.resolveOperand(operand: memrefInfo, type: memrefType, result&: result.operands) ||
4674 parser.resolveOperands(operands&: mapOperands, type: indexTy, result&: result.operands));
4675}
4676
4677void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4678 p << " " << getValueToStore();
4679 p << ", " << getMemRef() << '[';
4680 if (AffineMapAttr mapAttr =
4681 (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()))
4682 p.printAffineMapOfSSAIds(mapAttr, operands: getMapOperands());
4683 p << ']';
4684 p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
4685 /*elidedAttrs=*/{getMapAttrStrName()});
4686 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4687}
4688
4689LogicalResult AffineVectorStoreOp::verify() {
4690 MemRefType memrefType = getMemRefType();
4691 if (failed(Result: verifyMemoryOpIndexing(
4692 op: *this, mapAttr: (*this)->getAttrOfType<AffineMapAttr>(name: getMapAttrStrName()),
4693 mapOperands: getMapOperands(), memrefType,
4694 /*numIndexOperands=*/getNumOperands() - 2)))
4695 return failure();
4696
4697 if (failed(Result: verifyVectorMemoryOp(op: *this, memrefType, vectorType: getVectorType())))
4698 return failure();
4699
4700 return success();
4701}
4702
4703//===----------------------------------------------------------------------===//
4704// DelinearizeIndexOp
4705//===----------------------------------------------------------------------===//
4706
4707void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4708 OperationState &odsState,
4709 Value linearIndex, ValueRange dynamicBasis,
4710 ArrayRef<int64_t> staticBasis,
4711 bool hasOuterBound) {
4712 SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4713 : staticBasis.size() + 1,
4714 linearIndex.getType());
4715 build(odsBuilder, odsState, multi_index: returnTypes, linear_index: linearIndex, dynamic_basis: dynamicBasis,
4716 static_basis: staticBasis);
4717}
4718
4719void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4720 OperationState &odsState,
4721 Value linearIndex, ValueRange basis,
4722 bool hasOuterBound) {
4723 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4724 hasOuterBound = false;
4725 basis = basis.drop_front();
4726 }
4727 SmallVector<Value> dynamicBasis;
4728 SmallVector<int64_t> staticBasis;
4729 dispatchIndexOpFoldResults(ofrs: getAsOpFoldResult(values: basis), dynamicVec&: dynamicBasis,
4730 staticVec&: staticBasis);
4731 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4732 hasOuterBound);
4733}
4734
4735void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4736 OperationState &odsState,
4737 Value linearIndex,
4738 ArrayRef<OpFoldResult> basis,
4739 bool hasOuterBound) {
4740 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4741 hasOuterBound = false;
4742 basis = basis.drop_front();
4743 }
4744 SmallVector<Value> dynamicBasis;
4745 SmallVector<int64_t> staticBasis;
4746 dispatchIndexOpFoldResults(ofrs: basis, dynamicVec&: dynamicBasis, staticVec&: staticBasis);
4747 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4748 hasOuterBound);
4749}
4750
4751void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4752 OperationState &odsState,
4753 Value linearIndex, ArrayRef<int64_t> basis,
4754 bool hasOuterBound) {
4755 build(odsBuilder, odsState, linearIndex, dynamicBasis: ValueRange{}, staticBasis: basis, hasOuterBound);
4756}
4757
4758LogicalResult AffineDelinearizeIndexOp::verify() {
4759 ArrayRef<int64_t> staticBasis = getStaticBasis();
4760 if (getNumResults() != staticBasis.size() &&
4761 getNumResults() != staticBasis.size() + 1)
4762 return emitOpError(message: "should return an index for each basis element and up "
4763 "to one extra index");
4764
4765 auto dynamicMarkersCount = llvm::count_if(Range&: staticBasis, P: ShapedType::isDynamic);
4766 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4767 return emitOpError(
4768 message: "mismatch between dynamic and static basis (kDynamic marker but no "
4769 "corresponding dynamic basis entry) -- this can only happen due to an "
4770 "incorrect fold/rewrite");
4771
4772 if (!llvm::all_of(Range&: staticBasis, P: [](int64_t v) {
4773 return v > 0 || ShapedType::isDynamic(dValue: v);
4774 }))
4775 return emitOpError(message: "no basis element may be statically non-positive");
4776
4777 return success();
4778}
4779
/// Given mixed basis of affine.delinearize_index/linearize_index replace
/// constant SSA values with the constant integer value and return the new
/// static basis. In case no such candidate for replacement exists, this utility
/// returns std::nullopt.
///
/// NOTE(review): on success this also mutates the op in place by erasing the
/// now-constant dynamic basis operands from `mutableDynamicBasis`.
static std::optional<SmallVector<int64_t>>
foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
                           MutableOperandRange mutableDynamicBasis,
                           ArrayRef<Attribute> dynamicBasis) {
  // Walk the fold-adaptor view of the dynamic basis. A non-null attribute
  // means the corresponding SSA operand folded to a constant: erase that
  // operand in place. Erasing (instead of incrementing) keeps
  // `dynamicBasisIndex` pointing at the next not-yet-visited operand, since
  // later operands shift down by one.
  uint64_t dynamicBasisIndex = 0;
  for (OpFoldResult basis : dynamicBasis) {
    if (basis) {
      mutableDynamicBasis.erase(subStart: dynamicBasisIndex);
    } else {
      ++dynamicBasisIndex;
    }
  }

  // No constant SSA value exists.
  // (If nothing was erased above, `dynamicBasisIndex` counted every entry.)
  if (dynamicBasisIndex == dynamicBasis.size())
    return std::nullopt;

  // Rebuild the static basis: entries that are now known constants take
  // their value; still-dynamic entries keep the kDynamic marker.
  SmallVector<int64_t> staticBasis;
  for (OpFoldResult basis : mixedBasis) {
    std::optional<int64_t> basisVal = getConstantIntValue(ofr: basis);
    if (!basisVal)
      staticBasis.push_back(Elt: ShapedType::kDynamic);
    else
      staticBasis.push_back(Elt: *basisVal);
  }

  return staticBasis;
}
4812
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &result) {
  // First, try to promote constant dynamic basis operands into the static
  // basis; if that succeeds, the op was updated in place.
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(mixedBasis: getMixedBasis(), mutableDynamicBasis: getDynamicBasisMutable(),
                                 dynamicBasis: adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return success();
  }
  // If we won't be doing any division or modulo (no basis or the one basis
  // element is purely advisory), simply return the input value.
  if (getNumResults() == 1) {
    result.push_back(Elt: getLinearIndex());
    return success();
  }

  // A full constant fold needs a constant linear index ...
  if (adaptor.getLinearIndex() == nullptr)
    return failure();

  // ... and a fully static basis.
  if (!adaptor.getDynamicBasis().empty())
    return failure();

  // Peel results off from the innermost (fastest-varying) dimension outward:
  // each result is the remainder modulo its basis element, and the quotient
  // carries to the next-outer dimension.
  int64_t highPart = cast<IntegerAttr>(Val: adaptor.getLinearIndex()).getInt();
  Type attrType = getLinearIndex().getType();

  ArrayRef<int64_t> staticBasis = getStaticBasis();
  // The outer bound only constrains the first result; it takes no part in
  // the division chain.
  if (hasOuterBound())
    staticBasis = staticBasis.drop_front();
  for (int64_t modulus : llvm::reverse(C&: staticBasis)) {
    result.push_back(Elt: IntegerAttr::get(type: attrType, value: llvm::mod(Numerator: highPart, Denominator: modulus)));
    highPart = llvm::divideFloorSigned(Numerator: highPart, Denominator: modulus);
  }
  // The outermost result is whatever remains after all divisions.
  result.push_back(Elt: IntegerAttr::get(type: attrType, value: highPart));
  // Results were accumulated innermost-first; flip to op result order.
  std::reverse(first: result.begin(), last: result.end());
  return success();
}
4850
4851SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4852 OpBuilder builder(getContext());
4853 if (hasOuterBound()) {
4854 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4855 return getMixedValues(staticValues: getStaticBasis().drop_front(),
4856 dynamicValues: getDynamicBasis().drop_front(), b&: builder);
4857
4858 return getMixedValues(staticValues: getStaticBasis().drop_front(), dynamicValues: getDynamicBasis(),
4859 b&: builder);
4860 }
4861
4862 return getMixedValues(staticValues: getStaticBasis(), dynamicValues: getDynamicBasis(), b&: builder);
4863}
4864
4865SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4866 SmallVector<OpFoldResult> ret = getMixedBasis();
4867 if (!hasOuterBound())
4868 ret.insert(I: ret.begin(), Elt: OpFoldResult());
4869 return ret;
4870}
4871
4872namespace {
4873
// Drops delinearization indices that correspond to unit-extent basis
// elements: such a result can only ever be 0, so it is replaced by a
// constant zero and removed from the basis.
struct DropUnitExtentBasis
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    // One slot per original result; null means "still produced by the new op".
    SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
    // Lazily-created zero constant, shared by all unit-basis results.
    std::optional<Value> zero = std::nullopt;
    Location loc = delinearizeOp->getLoc();
    auto getZero = [&]() -> Value {
      if (!zero)
        zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
      return zero.value();
    };

    // Replace all indices corresponding to unit-extent basis with 0.
    // Remaining basis can be used to get a new `affine.delinearize_index` op.
    // The padded basis has one (possibly null) entry per result.
    SmallVector<OpFoldResult> newBasis;
    for (auto [index, basis] :
         llvm::enumerate(First: delinearizeOp.getPaddedBasis())) {
      std::optional<int64_t> basisVal =
          basis ? getConstantIntValue(ofr: basis) : std::nullopt;
      if (basisVal == 1)
        replacements[index] = getZero();
      else
        newBasis.push_back(Elt: basis);
    }

    // Nothing dropped: no rewrite to perform.
    if (newBasis.size() == delinearizeOp.getNumResults())
      return rewriter.notifyMatchFailure(arg&: delinearizeOp,
                                         msg: "no unit basis elements");

    if (!newBasis.empty()) {
      // Will drop the leading nullptr from `basis` if there was no outer bound.
      auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
          location: loc, args: delinearizeOp.getLinearIndex(), args&: newBasis);
      int newIndex = 0;
      // Map back the new delinearized indices to the values they replace.
      for (auto &replacement : replacements) {
        if (replacement)
          continue;
        replacement = newDelinearizeOp->getResult(idx: newIndex++);
      }
    }

    rewriter.replaceOp(op: delinearizeOp, newValues: replacements);
    return success();
  }
};
4924
/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
/// disjoint` and the two operations end with the same basis elements,
/// cancel those parts of the operations out because they are inverses
/// of each other.
///
/// If the operations have the same basis, cancel them entirely.
///
/// The `disjoint` flag is needed on the `affine.linearize_index` because
/// otherwise, there is no guarantee that the inputs to the linearization are
/// in-bounds the way the outputs of the delinearization would be.
struct CancelDelinearizeOfLinearizeDisjointExactTail
    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(arg&: delinearizeOp,
                                         msg: "index doesn't come from linearize");

    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(arg&: linearizeOp, msg: "not disjoint");

    ValueRange linearizeIns = linearizeOp.getMultiIndex();
    // Note: we use the full basis so we don't lose outer bounds later.
    SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
    SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
    // Count how many trailing basis elements the two ops share.
    size_t numMatches = 0;
    for (auto [linSize, delinSize] : llvm::zip(
             t: llvm::reverse(C&: linearizeBasis), u: llvm::reverse(C&: delinearizeBasis))) {
      if (linSize != delinSize)
        break;
      ++numMatches;
    }

    if (numMatches == 0)
      return rewriter.notifyMatchFailure(
          arg&: delinearizeOp, msg: "final basis element doesn't match linearize");

    // The easy case: everything lines up and the bases match up completely.
    if (numMatches == linearizeBasis.size() &&
        numMatches == delinearizeBasis.size() &&
        linearizeIns.size() == delinearizeOp.getNumResults()) {
      rewriter.replaceOp(op: delinearizeOp, newValues: linearizeOp.getMultiIndex());
      return success();
    }

    // Otherwise, keep linearizing/delinearizing only the non-matching
    // prefixes and pass the matching tail inputs through unchanged.
    Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
        location: linearizeOp.getLoc(), args: linearizeIns.drop_back(n: numMatches),
        args: ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(N: numMatches),
        args: linearizeOp.getDisjoint());
    auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
        location: delinearizeOp.getLoc(), args&: newLinearize,
        args: ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(N: numMatches),
        args: delinearizeOp.hasOuterBound());
    // Final results: the shrunken delinearization's outputs followed by the
    // cancelled tail inputs of the linearization.
    SmallVector<Value> mergedResults(newDelinearize.getResults());
    mergedResults.append(in_start: linearizeIns.take_back(n: numMatches).begin(),
                         in_end: linearizeIns.take_back(n: numMatches).end());
    rewriter.replaceOp(op: delinearizeOp, newValues: mergedResults);
    return success();
  }
};
4989
/// If the input to a delinearization is a disjoint linearization, and the
/// last k > 1 components of the delinearization basis multiply to the
/// last component of the linearization basis, break the linearization and
/// delinearization into two parts, peeling off the last input to linearization.
///
/// For example:
///    %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
///    %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
/// becomes
///    %0 = affine.linearize_index [%z, %y] by (3, 2) : index
///    %1:2 = affine.delinearize_index %0 by (2, 3) : index
///    %2:2 = affine.delinearize_index %x by (8, 4) : index
/// where the original %1:4 is replaced by %1:2 ++ %2:2
struct SplitDelinearizeSpanningLastLinearizeArg final
    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                PatternRewriter &rewriter) const override {
    auto linearizeOp = delinearizeOp.getLinearIndex()
                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
    if (!linearizeOp)
      return rewriter.notifyMatchFailure(arg&: delinearizeOp,
                                         msg: "index doesn't come from linearize");

    if (!linearizeOp.getDisjoint())
      return rewriter.notifyMatchFailure(arg&: linearizeOp,
                                         msg: "linearize isn't disjoint");

    // The split must exactly cover the last linearize bound, so it must be a
    // known constant.
    int64_t target = linearizeOp.getStaticBasis().back();
    if (ShapedType::isDynamic(dValue: target))
      return rewriter.notifyMatchFailure(
          arg&: linearizeOp, msg: "linearize ends with dynamic basis value");

    // Scan the delinearization basis from the back, accumulating the product
    // of its elements until it exactly equals `target`.
    int64_t sizeToSplit = 1;
    size_t elemsToSplit = 0;
    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
    for (int64_t basisElem : llvm::reverse(C&: basis)) {
      if (ShapedType::isDynamic(dValue: basisElem))
        return rewriter.notifyMatchFailure(
            arg&: delinearizeOp, msg: "dynamic basis element while scanning for split");
      sizeToSplit *= basisElem;
      elemsToSplit += 1;

      if (sizeToSplit > target)
        return rewriter.notifyMatchFailure(arg&: delinearizeOp,
                                           msg: "overshot last argument size");
      if (sizeToSplit == target)
        break;
    }

    // The whole basis was consumed without reaching `target`.
    if (sizeToSplit < target)
      return rewriter.notifyMatchFailure(
          arg&: delinearizeOp, msg: "product of known basis elements doesn't exceed last "
                              "linearize argument");

    // Peeling off a single basis element would be a no-op; require k > 1.
    if (elemsToSplit < 2)
      return rewriter.notifyMatchFailure(
          arg&: delinearizeOp,
          msg: "need at least two elements to form the basis product");

    // Linearize everything except the last input, and delinearize that
    // smaller value by the unsplit prefix of the basis.
    Value linearizeWithoutBack =
        rewriter.create<affine::AffineLinearizeIndexOp>(
            location: linearizeOp.getLoc(), args: linearizeOp.getMultiIndex().drop_back(),
            args: linearizeOp.getDynamicBasis(),
            args: linearizeOp.getStaticBasis().drop_back(),
            args: linearizeOp.getDisjoint());
    auto delinearizeWithoutSplitPart =
        rewriter.create<affine::AffineDelinearizeIndexOp>(
            location: delinearizeOp.getLoc(), args&: linearizeWithoutBack,
            args: delinearizeOp.getDynamicBasis(), args: basis.drop_back(N: elemsToSplit),
            args: delinearizeOp.hasOuterBound());
    // Delinearize the peeled-off last input by the split-off basis tail.
    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
        location: delinearizeOp.getLoc(), args: linearizeOp.getMultiIndex().back(),
        args: basis.take_back(N: elemsToSplit), /*hasOuterBound=*/args: true);
    SmallVector<Value> results = llvm::to_vector(
        Range: llvm::concat<Value>(Ranges: delinearizeWithoutSplitPart.getResults(),
                             Ranges: delinearizeBack.getResults()));
    rewriter.replaceOp(op: delinearizeOp, newValues: results);

    return success();
  }
};
5073} // namespace
5074
5075void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5076 RewritePatternSet &patterns, MLIRContext *context) {
5077 patterns
5078 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5079 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5080 arg&: context);
5081}
5082
5083//===----------------------------------------------------------------------===//
5084// LinearizeIndexOp
5085//===----------------------------------------------------------------------===//
5086
5087void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5088 OperationState &odsState,
5089 ValueRange multiIndex, ValueRange basis,
5090 bool disjoint) {
5091 if (!basis.empty() && basis.front() == Value())
5092 basis = basis.drop_front();
5093 SmallVector<Value> dynamicBasis;
5094 SmallVector<int64_t> staticBasis;
5095 dispatchIndexOpFoldResults(ofrs: getAsOpFoldResult(values: basis), dynamicVec&: dynamicBasis,
5096 staticVec&: staticBasis);
5097 build(odsBuilder, odsState, multi_index: multiIndex, dynamic_basis: dynamicBasis, static_basis: staticBasis, disjoint);
5098}
5099
5100void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5101 OperationState &odsState,
5102 ValueRange multiIndex,
5103 ArrayRef<OpFoldResult> basis,
5104 bool disjoint) {
5105 if (!basis.empty() && basis.front() == OpFoldResult())
5106 basis = basis.drop_front();
5107 SmallVector<Value> dynamicBasis;
5108 SmallVector<int64_t> staticBasis;
5109 dispatchIndexOpFoldResults(ofrs: basis, dynamicVec&: dynamicBasis, staticVec&: staticBasis);
5110 build(odsBuilder, odsState, multi_index: multiIndex, dynamic_basis: dynamicBasis, static_basis: staticBasis, disjoint);
5111}
5112
5113void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5114 OperationState &odsState,
5115 ValueRange multiIndex,
5116 ArrayRef<int64_t> basis, bool disjoint) {
5117 build(odsBuilder, odsState, multi_index: multiIndex, dynamic_basis: ValueRange{}, static_basis: basis, disjoint);
5118}
5119
5120LogicalResult AffineLinearizeIndexOp::verify() {
5121 size_t numIndexes = getMultiIndex().size();
5122 size_t numBasisElems = getStaticBasis().size();
5123 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5124 return emitOpError(message: "should be passed a basis element for each index except "
5125 "possibly the first");
5126
5127 auto dynamicMarkersCount =
5128 llvm::count_if(Range: getStaticBasis(), P: ShapedType::isDynamic);
5129 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5130 return emitOpError(
5131 message: "mismatch between dynamic and static basis (kDynamic marker but no "
5132 "corresponding dynamic basis entry) -- this can only happen due to an "
5133 "incorrect fold/rewrite");
5134
5135 return success();
5136}
5137
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  // First, try to promote constant dynamic basis operands into the static
  // basis; if that succeeds, the op was updated in place.
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(mixedBasis: getMixedBasis(), mutableDynamicBasis: getDynamicBasisMutable(),
                                 dynamicBasis: adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return getResult();
  }
  // No indices linearizes to zero.
  if (getMultiIndex().empty())
    return IntegerAttr::get(type: getResult().getType(), value: 0);

  // One single index linearizes to itself.
  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  // A full constant fold needs every index to be constant ...
  if (llvm::is_contained(Range: adaptor.getMultiIndex(), Element: nullptr))
    return nullptr;

  // ... and a fully static basis.
  if (!adaptor.getDynamicBasis().empty())
    return nullptr;

  // Accumulate from the innermost (fastest-varying) index outward; `stride`
  // is the product of the basis elements consumed so far.
  int64_t result = 0;
  int64_t stride = 1;
  for (auto [length, indexAttr] :
       llvm::zip_first(t: llvm::reverse(C: getStaticBasis()),
                       u: llvm::reverse(C: adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(Val: indexAttr).getInt() * stride;
    stride = stride * length;
  }
  // Handle the index element with no basis element.
  if (!hasOuterBound())
    result =
        result +
        cast<IntegerAttr>(Val: adaptor.getMultiIndex().front()).getInt() * stride;

  return IntegerAttr::get(type: getResult().getType(), value: result);
}
5176
5177SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5178 OpBuilder builder(getContext());
5179 if (hasOuterBound()) {
5180 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5181 return getMixedValues(staticValues: getStaticBasis().drop_front(),
5182 dynamicValues: getDynamicBasis().drop_front(), b&: builder);
5183
5184 return getMixedValues(staticValues: getStaticBasis().drop_front(), dynamicValues: getDynamicBasis(),
5185 b&: builder);
5186 }
5187
5188 return getMixedValues(staticValues: getStaticBasis(), dynamicValues: getDynamicBasis(), b&: builder);
5189}
5190
5191SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5192 SmallVector<OpFoldResult> ret = getMixedBasis();
5193 if (!hasOuterBound())
5194 ret.insert(I: ret.begin(), Elt: OpFoldResult());
5195 return ret;
5196}
5197
5198namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
/// %...d)`.
///
/// Note that `disjoint` is required here: without it,
/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
/// is a valid operation where the `%c64` cannot be trivially dropped.
///
/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
/// the operation isn't asserted to be `disjoint`.
struct DropLinearizeUnitComponentsIfDisjointOrZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndices = multiIndex.size();
    SmallVector<Value> newIndices;
    newIndices.reserve(N: numIndices);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(N: numIndices);

    // An unbounded leading index has no basis element; always keep it.
    if (!op.hasOuterBound()) {
      newIndices.push_back(Elt: multiIndex.front());
      multiIndex = multiIndex.drop_front();
    }

    SmallVector<OpFoldResult> basis = op.getMixedBasis();
    for (auto [index, basisElem] : llvm::zip_equal(t&: multiIndex, u&: basis)) {
      std::optional<int64_t> basisEntry = getConstantIntValue(ofr: basisElem);
      // Keep any index whose basis element is not the constant 1.
      if (!basisEntry || *basisEntry != 1) {
        newIndices.push_back(Elt: index);
        newBasis.push_back(Elt: basisElem);
        continue;
      }

      // Unit basis element: droppable only if the op is disjoint (the index
      // is then provably 0) or the index is a known constant 0.
      std::optional<int64_t> indexValue = getConstantIntValue(ofr: index);
      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
        newIndices.push_back(Elt: index);
        newBasis.push_back(Elt: basisElem);
        continue;
      }
    }
    // Nothing dropped: no rewrite to perform.
    if (newIndices.size() == numIndices)
      return rewriter.notifyMatchFailure(arg&: op,
                                         msg: "no unit basis entries to replace");

    // All indices dropped: the linearization is the constant 0.
    if (newIndices.size() == 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, args: 0);
      return success();
    }
    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, args&: newIndices, args&: newBasis, args: op.getDisjoint());
    return success();
  }
};
5256
5257OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5258 ArrayRef<OpFoldResult> terms) {
5259 int64_t nDynamic = 0;
5260 SmallVector<Value> dynamicPart;
5261 AffineExpr result = builder.getAffineConstantExpr(constant: 1);
5262 for (OpFoldResult term : terms) {
5263 if (!term)
5264 return term;
5265 std::optional<int64_t> maybeConst = getConstantIntValue(ofr: term);
5266 if (maybeConst) {
5267 result = result * builder.getAffineConstantExpr(constant: *maybeConst);
5268 } else {
5269 dynamicPart.push_back(Elt: cast<Value>(Val&: term));
5270 result = result * builder.getAffineSymbolExpr(position: nDynamic++);
5271 }
5272 }
5273 if (auto constant = dyn_cast<AffineConstantExpr>(Val&: result))
5274 return getAsIndexOpFoldResult(ctx: builder.getContext(), val: constant.getValue());
5275 return builder.create<AffineApplyOp>(location: loc, args&: result, args&: dynamicPart).getResult();
5276}
5277
/// If consecutive outputs of a delinearize_index are linearized with the same
5279/// bounds, canonicalize away the redundant arithmetic.
5280///
5281/// That is, if we have
5282/// ```
5283/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5284/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5285/// by (...e, B1, B2, ..., BK, ...f)
5286/// ```
5287///
5288/// We can rewrite this to
5289/// ```
5290/// B = B1 * B2 ... BK
/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
5292/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5293/// ```
5294/// where we replace all results of %s unaffected by the change with results
5295/// from %sMerged.
5296///
5297/// As a special case, if all results of the delinearize are merged in this way
5298/// we can replace those usages with %x, thus cancelling the delinearization
5299/// entirely, as in
5300/// ```
5301/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5302/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5303/// ```
5304/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

private:
  // Struct representing a case where the cancellation pattern
  // applies. A `Match` means that `length` inputs to the linearize operation
  // starting at `linStart` can be cancelled with `length` outputs of
  // `delinearize`, starting from `delinStart`.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

public:
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    // Padded so basis entries line up 1:1 with the multi-index operands even
    // when there is no outer bound (the pad is a null OpFoldResult).
    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // We only want to replace one run from the same delinearize op per
    // pattern invocation lest we run into invalidation issues.
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    // Phase 1: scan the linearize arguments for runs that are consecutive
    // results of some delinearize with matching basis entries.
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(Val: multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(Val: asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      /// Result 0 of the delinearize and argument 0 of the linearize can
      /// leave their maximum value unspecified. However, even if this happens
      /// we can still sometimes start the match process. Specifically, if
      /// - The argument we're matching is result 0 and argument 0 (so the
      ///   bounds don't matter). For example,
      ///
      ///   %0:2 = affine.delinearize_index %x into (8) : index, index
      ///   %1 = affine.linearize_index [%0#0, %0#1, ...] (8, ...)
      ///   allows cancellation
      /// - The delinearization doesn't specify a bound, but the linearization
      ///  is `disjoint`, which asserts that the bound on the linearization is
      ///  correct.
      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      // Extend the run greedily: each further linearize argument must be the
      // next delinearize result and carry an identical basis entry.
      unsigned j = 1;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(i: delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // If there're multiple matches against the same delinearize_index,
      // only rewrite the first one we find to prevent invalidations. The next
      // ones will be taken care of by subsequent pattern invocations.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(Ptr: delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Elt: Match{.delinearize: delinearizeOp, .linStart: linArgIdx, .delinStart: delinArgIdx, .length: j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          arg&: linearizeOp, msg: "no run of delinearize outputs to deal with");

    // Record all the delinearize replacements so we can do them after creating
    // the new linearization operation, since the new operation might use
    // outputs of something we're replacing.
    SmallVector<SmallVector<Value>> delinearizeReplacements;

    // Phase 2: build the operands/basis of the replacement linearize op,
    // merging each matched run into a single index with a product bound.
    SmallVector<Value> newIndex;
    newIndex.reserve(N: numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(N: numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      // Copy through the unmatched arguments between the previous match and
      // this one unchanged.
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(C&: newIndex, R: multiIndex.slice(n: prevMatchEnd, m: gap));
      llvm::append_range(C&: newBasis, R: linBasisRef.slice(N: prevMatchEnd, M: gap));
      // Update here so we don't forget this during early continues
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(N: m.linStart, M: m.length);
      // We use the slice from the linearize's basis above because of the
      // "bounds inferred from `disjoint`" case above.
      OpFoldResult newSize =
          computeProduct(loc: linearizeOp.getLoc(), builder&: rewriter, terms: basisToMerge);

      // Trivial case where we can just skip past the delinearize all together
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(Elt: m.delinearize.getLinearIndex());
        newBasis.push_back(Elt: newSize);
        // Pad out set of replacements so we don't do anything with this one.
        delinearizeReplacements.push_back(Elt: SmallVector<Value>());
        continue;
      }

      // Otherwise, build a smaller delinearize whose basis has the matched
      // run collapsed into the single `newSize` entry.
      SmallVector<Value> newDelinResults;
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(CS: newDelinBasis.begin() + m.delinStart,
                          CE: newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(I: newDelinBasis.begin() + m.delinStart, Elt: newSize);
      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          location: m.delinearize.getLoc(), args: m.delinearize.getLinearIndex(),
          args&: newDelinBasis);

      // Since there may be other uses of the indices we just merged together,
      // create a residual affine.delinearize_index that delinearizes the
      // merged output into its component parts.
      Value combinedElem = newDelinearize.getResult(i: m.delinStart);
      auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          location: m.delinearize.getLoc(), args&: combinedElem, args&: basisToMerge);

      // Swap all the uses of the unaffected delinearize outputs to the new
      // delinearization so that the old code can be removed if this
      // linearize_index is the only user of the merged results.
      llvm::append_range(C&: newDelinResults,
                         R: newDelinearize.getResults().take_front(n: m.delinStart));
      llvm::append_range(C&: newDelinResults, R: residualDelinearize.getResults());
      llvm::append_range(
          C&: newDelinResults,
          R: newDelinearize.getResults().drop_front(n: m.delinStart + 1));

      delinearizeReplacements.push_back(Elt: newDelinResults);
      newIndex.push_back(Elt: combinedElem);
      newBasis.push_back(Elt: newSize);
    }
    // Copy the tail of arguments after the last match, then create the
    // replacement linearize op.
    llvm::append_range(C&: newIndex, R: multiIndex.drop_front(n: prevMatchEnd));
    llvm::append_range(C&: newBasis, R: linBasisRef.drop_front(N: prevMatchEnd));
    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
        op: linearizeOp, args&: newIndex, args&: newBasis, args: linearizeOp.getDisjoint());

    // Phase 3: now that the new linearize op exists, apply the recorded
    // delinearize replacements (empty vectors mark matches that needed none).
    for (auto [m, newResults] :
         llvm::zip_equal(t&: matches, u&: delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(op: m.delinearize, newValues: newResults);
    }

    return success();
  }
};
5480
5481/// Strip leading zero from affine.linearize_index.
5482///
5483/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5484/// to `affine.linearize_index [...a] by (...b)` in all cases.
5485struct DropLinearizeLeadingZero final
5486 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5487 using OpRewritePattern::OpRewritePattern;
5488
5489 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5490 PatternRewriter &rewriter) const override {
5491 Value leadingIdx = op.getMultiIndex().front();
5492 if (!matchPattern(value: leadingIdx, pattern: m_Zero()))
5493 return failure();
5494
5495 if (op.getMultiIndex().size() == 1) {
5496 rewriter.replaceOp(op, newValues: leadingIdx);
5497 return success();
5498 }
5499
5500 SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5501 ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5502 if (op.hasOuterBound())
5503 newMixedBasis = newMixedBasis.drop_front();
5504
5505 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5506 op, args: op.getMultiIndex().drop_front(), args&: newMixedBasis, args: op.getDisjoint());
5507 return success();
5508 }
5509};
5510} // namespace
5511
/// Register the canonicalization patterns for affine.linearize_index:
/// cancelling against matching delinearize ops, stripping a leading zero
/// index, and dropping unit-bound components.
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
               DropLinearizeUnitComponentsIfDisjointOrZero>(arg&: context);
}
5517
5518//===----------------------------------------------------------------------===//
5519// TableGen'd op method definitions
5520//===----------------------------------------------------------------------===//
5521
5522#define GET_OP_CLASSES
5523#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
5524

source code of mlir/lib/Dialect/Affine/IR/AffineOps.cpp