1//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Dialect/UB/IR/UBOps.h"
13#include "mlir/Dialect/Utils/StaticValueUtils.h"
14#include "mlir/IR/AffineExprVisitor.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/IntegerSet.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/OpDefinition.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Interfaces/ShapedOpInterfaces.h"
21#include "mlir/Interfaces/ValueBoundsOpInterface.h"
22#include "mlir/Transforms/InliningUtils.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/ScopeExit.h"
25#include "llvm/ADT/SmallBitVector.h"
26#include "llvm/ADT/SmallVectorExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/MathExtras.h"
30#include <numeric>
31#include <optional>
32
33using namespace mlir;
34using namespace mlir::affine;
35
36using llvm::divideCeilSigned;
37using llvm::divideFloorSigned;
38using llvm::mod;
39
40#define DEBUG_TYPE "affine-ops"
41
42#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
43
44/// A utility function to check if a value is defined at the top level of
45/// `region` or is an argument of `region`. A value of index type defined at the
46/// top level of a `AffineScope` region is always a valid symbol for all
47/// uses in that region.
48bool mlir::affine::isTopLevelValue(Value value, Region *region) {
49 if (auto arg = llvm::dyn_cast<BlockArgument>(value))
50 return arg.getParentRegion() == region;
51 return value.getDefiningOp()->getParentRegion() == region;
52}
53
54/// Checks if `value` known to be a legal affine dimension or symbol in `src`
55/// region remains legal if the operation that uses it is inlined into `dest`
56/// with the given value mapping. `legalityCheck` is either `isValidDim` or
57/// `isValidSymbol`, depending on the value being required to remain a valid
58/// dimension or symbol.
59static bool
60remainsLegalAfterInline(Value value, Region *src, Region *dest,
61 const IRMapping &mapping,
62 function_ref<bool(Value, Region *)> legalityCheck) {
63 // If the value is a valid dimension for any other reason than being
64 // a top-level value, it will remain valid: constants get inlined
65 // with the function, transitive affine applies also get inlined and
66 // will be checked themselves, etc.
67 if (!isTopLevelValue(value, region: src))
68 return true;
69
70 // If it's a top-level value because it's a block operand, i.e. a
71 // function argument, check whether the value replacing it after
72 // inlining is a valid dimension in the new region.
73 if (llvm::isa<BlockArgument>(Val: value))
74 return legalityCheck(mapping.lookup(from: value), dest);
75
76 // If it's a top-level value because it's defined in the region,
77 // it can only be inlined if the defining op is a constant or a
78 // `dim`, which can appear anywhere and be valid, since the defining
79 // op won't be top-level anymore after inlining.
80 Attribute operandCst;
81 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
82 return matchPattern(op: value.getDefiningOp(), pattern: m_Constant(bind_value: &operandCst)) ||
83 isDimLikeOp;
84}
85
86/// Checks if all values known to be legal affine dimensions or symbols in `src`
87/// remain so if their respective users are inlined into `dest`.
88static bool
89remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
90 const IRMapping &mapping,
91 function_ref<bool(Value, Region *)> legalityCheck) {
92 return llvm::all_of(Range&: values, P: [&](Value v) {
93 return remainsLegalAfterInline(value: v, src, dest, mapping, legalityCheck);
94 });
95}
96
97/// Checks if an affine read or write operation remains legal after inlining
98/// from `src` to `dest`.
99template <typename OpTy>
100static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
101 const IRMapping &mapping) {
102 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103 AffineWriteOpInterface>::value,
104 "only ops with affine read/write interface are supported");
105
106 AffineMap map = op.getAffineMap();
107 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
108 ValueRange symbolOperands =
109 op.getMapOperands().take_back(map.getNumSymbols());
110 if (!remainsLegalAfterInline(
111 values: dimOperands, src, dest, mapping,
112 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidDim)))
113 return false;
114 if (!remainsLegalAfterInline(
115 values: symbolOperands, src, dest, mapping,
116 legalityCheck: static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
117 return false;
118 return true;
119}
120
121/// Checks if an affine apply operation remains legal after inlining from `src`
122/// to `dest`.
123// Use "unused attribute" marker to silence clang-tidy warning stemming from
124// the inability to see through "llvm::TypeSwitch".
125template <>
126bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
127 Region *src, Region *dest,
128 const IRMapping &mapping) {
129 // If it's a valid dimension, we need to check that it remains so.
130 if (isValidDim(op.getResult(), src))
131 return remainsLegalAfterInline(
132 op.getMapOperands(), src, dest, mapping,
133 static_cast<bool (*)(Value, Region *)>(isValidDim));
134
135 // Otherwise it must be a valid symbol, check that it remains so.
136 return remainsLegalAfterInline(
137 op.getMapOperands(), src, dest, mapping,
138 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
139}
140
141//===----------------------------------------------------------------------===//
142// AffineDialect Interfaces
143//===----------------------------------------------------------------------===//
144
145namespace {
146/// This class defines the interface for handling inlining with affine
147/// operations.
148struct AffineInlinerInterface : public DialectInlinerInterface {
149 using DialectInlinerInterface::DialectInlinerInterface;
150
151 //===--------------------------------------------------------------------===//
152 // Analysis Hooks
153 //===--------------------------------------------------------------------===//
154
155 /// Returns true if the given region 'src' can be inlined into the region
156 /// 'dest' that is attached to an operation registered to the current dialect.
157 /// 'wouldBeCloned' is set if the region is cloned into its new location
158 /// rather than moved, indicating there may be other users.
159 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
160 IRMapping &valueMapping) const final {
161 // We can inline into affine loops and conditionals if this doesn't break
162 // affine value categorization rules.
163 Operation *destOp = dest->getParentOp();
164 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
165 return false;
166
167 // Multi-block regions cannot be inlined into affine constructs, all of
168 // which require single-block regions.
169 if (!llvm::hasSingleElement(C&: *src))
170 return false;
171
172 // Side-effecting operations that the affine dialect cannot understand
173 // should not be inlined.
174 Block &srcBlock = src->front();
175 for (Operation &op : srcBlock) {
176 // Ops with no side effects are fine,
177 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
178 if (iface.hasNoEffect())
179 continue;
180 }
181
182 // Assuming the inlined region is valid, we only need to check if the
183 // inlining would change it.
184 bool remainsValid =
185 llvm::TypeSwitch<Operation *, bool>(&op)
186 .Case<AffineApplyOp, AffineReadOpInterface,
187 AffineWriteOpInterface>([&](auto op) {
188 return remainsLegalAfterInline(op, src, dest, valueMapping);
189 })
190 .Default([](Operation *) {
191 // Conservatively disallow inlining ops we cannot reason about.
192 return false;
193 });
194
195 if (!remainsValid)
196 return false;
197 }
198
199 return true;
200 }
201
202 /// Returns true if the given operation 'op', that is registered to this
203 /// dialect, can be inlined into the given region, false otherwise.
204 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
205 IRMapping &valueMapping) const final {
206 // Always allow inlining affine operations into a region that is marked as
207 // affine scope, or into affine loops and conditionals. There are some edge
208 // cases when inlining *into* affine structures, but that is handled in the
209 // other 'isLegalToInline' hook above.
210 Operation *parentOp = region->getParentOp();
211 return parentOp->hasTrait<OpTrait::AffineScope>() ||
212 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
213 }
214
215 /// Affine regions should be analyzed recursively.
216 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
217};
218} // namespace
219
220//===----------------------------------------------------------------------===//
221// AffineDialect
222//===----------------------------------------------------------------------===//
223
/// Registers the affine dialect's operations and interfaces.
224void AffineDialect::initialize() {
 // The two DMA ops are listed explicitly; the rest of the op list is pulled in
 // from the GET_OP_LIST section of the included AffineOps.cpp.inc (presumably
 // the generated op list -- confirm against the build rules).
225 addOperations<AffineDmaStartOp, AffineDmaWaitOp,
226#define GET_OP_LIST
227#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
228 >();
 // Attach the inlining hooks defined earlier in this file.
229 addInterfaces<AffineInlinerInterface>();
 // Declare ValueBoundsOpInterface as promised for apply/max/min; the
 // implementations are not visible in this file.
230 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
231 AffineMinOp>();
232}
233
234/// Materialize a single constant operation from a given attribute value with
235/// the desired resultant type.
236Operation *AffineDialect::materializeConstant(OpBuilder &builder,
237 Attribute value, Type type,
238 Location loc) {
239 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
240 return builder.create<ub::PoisonOp>(loc, type, poison);
241 return arith::ConstantOp::materialize(builder, value, type, loc);
242}
243
244/// A utility function to check if a value is defined at the top level of an
245/// op with trait `AffineScope`. If the value is defined in an unlinked region,
246/// conservatively assume it is not top-level. A value of index type defined at
247/// the top level is always a valid symbol.
248bool mlir::affine::isTopLevelValue(Value value) {
249 if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
250 // The block owning the argument may be unlinked, e.g. when the surrounding
251 // region has not yet been attached to an Op, at which point the parent Op
252 // is null.
253 Operation *parentOp = arg.getOwner()->getParentOp();
254 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
255 }
256 // The defining Op may live in an unlinked block so its parent Op may be null.
257 Operation *parentOp = value.getDefiningOp()->getParentOp();
258 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
259}
260
261/// Returns the closest region enclosing `op` that is held by an operation with
262/// trait `AffineScope`; `nullptr` if there is no such region.
263Region *mlir::affine::getAffineScope(Operation *op) {
264 auto *curOp = op;
265 while (auto *parentOp = curOp->getParentOp()) {
266 if (parentOp->hasTrait<OpTrait::AffineScope>())
267 return curOp->getParentRegion();
268 curOp = parentOp;
269 }
270 return nullptr;
271}
272
273Region *mlir::affine::getAffineAnalysisScope(Operation *op) {
274 Operation *curOp = op;
275 while (auto *parentOp = curOp->getParentOp()) {
276 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
277 return curOp->getParentRegion();
278 curOp = parentOp;
279 }
280 return nullptr;
281}
282
283// A Value can be used as a dimension id iff it meets one of the following
284// conditions:
285// *) It is valid as a symbol.
286// *) It is an induction variable.
287// *) It is the result of affine apply operation with dimension id arguments.
288bool mlir::affine::isValidDim(Value value) {
289 // The value must be an index type.
290 if (!value.getType().isIndex())
291 return false;
292
293 if (auto *defOp = value.getDefiningOp())
294 return isValidDim(value, region: getAffineScope(op: defOp));
295
296 // This value has to be a block argument for an op that has the
297 // `AffineScope` trait or an induction var of an affine.for or
298 // affine.parallel.
299 if (isAffineInductionVar(val: value))
300 return true;
301 auto *parentOp = llvm::cast<BlockArgument>(Val&: value).getOwner()->getParentOp();
302 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
303}
304
305// Value can be used as a dimension id iff it meets one of the following
306// conditions:
307// *) It is valid as a symbol.
308// *) It is an induction variable.
309// *) It is the result of an affine apply operation with dimension id operands.
310// *) It is the result of a more specialized index transformation (ex.
311// delinearize_index or linearize_index) with dimension id operands.
312bool mlir::affine::isValidDim(Value value, Region *region) {
313 // The value must be an index type.
314 if (!value.getType().isIndex())
315 return false;
316
317 // All valid symbols are okay.
318 if (isValidSymbol(value, region))
319 return true;
320
321 auto *op = value.getDefiningOp();
322 if (!op) {
323 // This value has to be an induction var for an affine.for or an
324 // affine.parallel.
325 return isAffineInductionVar(val: value);
326 }
327
328 // Affine apply operation is ok if all of its operands are ok.
329 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
330 return applyOp.isValidDim(region);
331 // delinearize_index and linearize_index are special forms of apply
332 // and so are valid dimensions if all their arguments are valid dimensions.
333 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
334 return llvm::all_of(Range: op->getOperands(),
335 P: [&](Value arg) { return ::isValidDim(value: arg, region); });
336 // The dim op is okay if its operand memref/tensor is defined at the top
337 // level.
338 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
339 return isTopLevelValue(dimOp.getShapedValue());
340 return false;
341}
342
343/// Returns true if the 'index' dimension of the `memref` defined by
344/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
345/// for `region`.
346template <typename AnyMemRefDefOp>
347static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
348 Region *region) {
349 MemRefType memRefType = memrefDefOp.getType();
350
351 // Dimension index is out of bounds.
352 if (index >= memRefType.getRank()) {
353 return false;
354 }
355
356 // Statically shaped.
357 if (!memRefType.isDynamicDim(index))
358 return true;
359 // Get the position of the dimension among dynamic dimensions;
360 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
361 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
362 region);
363}
364
365/// Returns true if the result of the dim op is a valid symbol for `region`.
366static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
367 // The dim op is okay if its source is defined at the top level.
368 if (isTopLevelValue(dimOp.getShapedValue()))
369 return true;
370
371 // Conservatively handle remaining BlockArguments as non-valid symbols.
372 // E.g. scf.for iterArgs.
373 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
374 return false;
375
376 // The dim op is also okay if its operand memref is a view/subview whose
377 // corresponding size is a valid symbol.
378 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
379
380 // Be conservative if we can't understand the dimension.
381 if (!index.has_value())
382 return false;
383
384 // Skip over all memref.cast ops (if any).
385 Operation *op = dimOp.getShapedValue().getDefiningOp();
386 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
387 // Bail on unranked memrefs.
388 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
389 return false;
390 op = castOp.getSource().getDefiningOp();
391 if (!op)
392 return false;
393 }
394
395 int64_t i = index.value();
396 return TypeSwitch<Operation *, bool>(op)
397 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
398 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
399 .Default([](Operation *) { return false; });
400}
401
402// A value can be used as a symbol (at all its use sites) iff it meets one of
403// the following conditions:
404// *) It is a constant.
405// *) Its defining op or block arg appearance is immediately enclosed by an op
406// with `AffineScope` trait.
407// *) It is the result of an affine.apply operation with symbol operands.
408// *) It is a result of the dim op on a memref whose corresponding size is a
409// valid symbol.
410bool mlir::affine::isValidSymbol(Value value) {
411 if (!value)
412 return false;
413
414 // The value must be an index type.
415 if (!value.getType().isIndex())
416 return false;
417
418 // Check that the value is a top level value.
419 if (isTopLevelValue(value))
420 return true;
421
422 if (auto *defOp = value.getDefiningOp())
423 return isValidSymbol(value, region: getAffineScope(op: defOp));
424
425 return false;
426}
427
428/// A value can be used as a symbol for `region` iff it meets one of the
429/// following conditions:
430/// *) It is a constant.
431/// *) It is a result of a `Pure` operation whose operands are valid symbolic
432/// *) identifiers.
433/// *) It is a result of the dim op on a memref whose corresponding size is
434/// a valid symbol.
435/// *) It is defined at the top level of 'region' or is its argument.
436/// *) It dominates `region`'s parent op.
437/// If `region` is null, conservatively assume the symbol definition scope does
438/// not exist and only accept the values that would be symbols regardless of
439/// the surrounding region structure, i.e. the first three cases above.
440bool mlir::affine::isValidSymbol(Value value, Region *region) {
441 // The value must be an index type.
442 if (!value.getType().isIndex())
443 return false;
444
445 // A top-level value is a valid symbol.
446 if (region && ::isTopLevelValue(value, region))
447 return true;
448
449 auto *defOp = value.getDefiningOp();
450 if (!defOp) {
451 // A block argument that is not a top-level value is a valid symbol if it
452 // dominates region's parent op.
453 Operation *regionOp = region ? region->getParentOp() : nullptr;
454 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
455 if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
456 return isValidSymbol(value, region: parentOpRegion);
457 return false;
458 }
459
460 // Constant operation is ok.
461 Attribute operandCst;
462 if (matchPattern(op: defOp, pattern: m_Constant(bind_value: &operandCst)))
463 return true;
464
465 // `Pure` operation that whose operands are valid symbolic identifiers.
466 if (isPure(op: defOp) && llvm::all_of(Range: defOp->getOperands(), P: [&](Value operand) {
467 return affine::isValidSymbol(value: operand, region);
468 })) {
469 return true;
470 }
471
472 // Dim op results could be valid symbols at any level.
473 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
474 return isDimOpValidSymbol(dimOp, region);
475
476 // Check for values dominating `region`'s parent op.
477 Operation *regionOp = region ? region->getParentOp() : nullptr;
478 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
479 if (auto *parentRegion = region->getParentOp()->getParentRegion())
480 return isValidSymbol(value, region: parentRegion);
481
482 return false;
483}
484
485// Returns true if 'value' is a valid index to an affine operation (e.g.
486// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
487// `region` provides the polyhedral symbol scope. Returns false otherwise.
488static bool isValidAffineIndexOperand(Value value, Region *region) {
489 return isValidDim(value, region) || isValidSymbol(value, region);
490}
491
492/// Prints dimension and symbol list.
493static void printDimAndSymbolList(Operation::operand_iterator begin,
494 Operation::operand_iterator end,
495 unsigned numDims, OpAsmPrinter &printer) {
496 OperandRange operands(begin, end);
497 printer << '(' << operands.take_front(n: numDims) << ')';
498 if (operands.size() > numDims)
499 printer << '[' << operands.drop_front(n: numDims) << ']';
500}
501
502/// Parses dimension and symbol list and returns true if parsing failed.
503ParseResult mlir::affine::parseDimAndSymbolList(
504 OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
505 SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
506 if (parser.parseOperandList(result&: opInfos, delimiter: OpAsmParser::Delimiter::Paren))
507 return failure();
508 // Store number of dimensions for validation by caller.
509 numDims = opInfos.size();
510
511 // Parse the optional symbol operands.
512 auto indexTy = parser.getBuilder().getIndexType();
513 return failure(parser.parseOperandList(
514 result&: opInfos, delimiter: OpAsmParser::Delimiter::OptionalSquare) ||
515 parser.resolveOperands(opInfos, indexTy, operands));
516}
517
518/// Utility function to verify that a set of operands are valid dimension and
519/// symbol identifiers. The operands should be laid out such that the dimension
520/// operands are before the symbol operands. This function returns failure if
521/// there was an invalid operand. An operation is provided to emit any necessary
522/// errors.
523template <typename OpTy>
524static LogicalResult
525verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
526 unsigned numDims) {
527 unsigned opIt = 0;
528 for (auto operand : operands) {
529 if (opIt++ < numDims) {
530 if (!isValidDim(operand, getAffineScope(op)))
531 return op.emitOpError("operand cannot be used as a dimension id");
532 } else if (!isValidSymbol(operand, getAffineScope(op))) {
533 return op.emitOpError("operand cannot be used as a symbol");
534 }
535 }
536 return success();
537}
538
539//===----------------------------------------------------------------------===//
540// AffineApplyOp
541//===----------------------------------------------------------------------===//
542
543AffineValueMap AffineApplyOp::getAffineValueMap() {
544 return AffineValueMap(getAffineMap(), getOperands(), getResult());
545}
546
547ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
548 auto &builder = parser.getBuilder();
549 auto indexTy = builder.getIndexType();
550
551 AffineMapAttr mapAttr;
552 unsigned numDims;
553 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
554 parseDimAndSymbolList(parser, result.operands, numDims) ||
555 parser.parseOptionalAttrDict(result.attributes))
556 return failure();
557 auto map = mapAttr.getValue();
558
559 if (map.getNumDims() != numDims ||
560 numDims + map.getNumSymbols() != result.operands.size()) {
561 return parser.emitError(parser.getNameLoc(),
562 "dimension or symbol index mismatch");
563 }
564
565 result.types.append(map.getNumResults(), indexTy);
566 return success();
567}
568
569void AffineApplyOp::print(OpAsmPrinter &p) {
570 p << " " << getMapAttr();
571 printDimAndSymbolList(operand_begin(), operand_end(),
572 getAffineMap().getNumDims(), p);
573 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
574}
575
576LogicalResult AffineApplyOp::verify() {
577 // Check input and output dimensions match.
578 AffineMap affineMap = getMap();
579
580 // Verify that operand count matches affine map dimension and symbol count.
581 if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
582 return emitOpError(
583 "operand count and affine map dimension and symbol count must match");
584
585 // Verify that the map only produces one result.
586 if (affineMap.getNumResults() != 1)
587 return emitOpError("mapping must produce one value");
588
589 // Do not allow valid dims to be used in symbol positions. We do allow
590 // affine.apply to use operands for values that may neither qualify as affine
591 // dims or affine symbols due to usage outside of affine ops, analyses, etc.
592 Region *region = getAffineScope(*this);
593 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
594 if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
595 return emitError("dimensional operand cannot be used as a symbol");
596 }
597
598 return success();
599}
600
601// The result of the affine apply operation can be used as a dimension id if all
602// its operands are valid dimension ids.
603bool AffineApplyOp::isValidDim() {
604 return llvm::all_of(getOperands(),
605 [](Value op) { return affine::isValidDim(op); });
606}
607
608// The result of the affine apply operation can be used as a dimension id if all
609// its operands are valid dimension ids with the parent operation of `region`
610// defining the polyhedral scope for symbols.
611bool AffineApplyOp::isValidDim(Region *region) {
612 return llvm::all_of(getOperands(),
613 [&](Value op) { return ::isValidDim(op, region); });
614}
615
616// The result of the affine apply operation can be used as a symbol if all its
617// operands are symbols.
618bool AffineApplyOp::isValidSymbol() {
619 return llvm::all_of(getOperands(),
620 [](Value op) { return affine::isValidSymbol(op); });
621}
622
623// The result of the affine apply operation can be used as a symbol in `region`
624// if all its operands are symbols in `region`.
625bool AffineApplyOp::isValidSymbol(Region *region) {
626 return llvm::all_of(getOperands(), [&](Value operand) {
627 return affine::isValidSymbol(operand, region);
628 });
629}
630
631OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
632 auto map = getAffineMap();
633
634 // Fold dims and symbols to existing values.
635 auto expr = map.getResult(0);
636 if (auto dim = dyn_cast<AffineDimExpr>(expr))
637 return getOperand(dim.getPosition());
638 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
639 return getOperand(map.getNumDims() + sym.getPosition());
640
641 // Otherwise, default to folding the map.
642 SmallVector<Attribute, 1> result;
643 bool hasPoison = false;
644 auto foldResult =
645 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
646 if (hasPoison)
647 return ub::PoisonAttr::get(getContext());
648 if (failed(foldResult))
649 return {};
650 return result[0];
651}
652
653/// Returns the largest known divisor of `e`. Exploits information from the
654/// values in `operands`.
655static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
656 // This method isn't aware of `operands`.
657 int64_t div = e.getLargestKnownDivisor();
658
659 // We now make use of operands for the case `e` is a dim expression.
660 // TODO: More powerful simplification would have to modify
661 // getLargestKnownDivisor to take `operands` and exploit that information as
662 // well for dim/sym expressions, but in that case, getLargestKnownDivisor
663 // can't be part of the IR library but of the `Analysis` library. The IR
664 // library can only really depend on simple O(1) checks.
665 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e);
666 // If it's not a dim expr, `div` is the best we have.
667 if (!dimExpr)
668 return div;
669
670 // We simply exploit information from loop IVs.
671 // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
672 // desired simplifications are expected to be part of other
673 // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
674 // LoopAnalysis library.
675 Value operand = operands[dimExpr.getPosition()];
676 int64_t operandDivisor = 1;
677 // TODO: With the right accessors, this can be extended to
678 // LoopLikeOpInterface.
679 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
680 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
681 operandDivisor = forOp.getStepAsInt();
682 } else {
683 uint64_t lbLargestKnownDivisor =
684 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
685 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
686 }
687 }
688 return operandDivisor;
689}
690
691/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
692/// being an affine dim expression or a constant.
693static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
694 int64_t k) {
695 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) {
696 int64_t constVal = constExpr.getValue();
697 return constVal >= 0 && constVal < k;
698 }
699 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: e);
700 if (!dimExpr)
701 return false;
702 Value operand = operands[dimExpr.getPosition()];
703 // TODO: With the right accessors, this can be extended to
704 // LoopLikeOpInterface.
705 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
706 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
707 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
708 return true;
709 }
710 }
711
712 // We don't consider other cases like `operand` being defined by a constant or
713 // an affine.apply op since such cases will already be handled by other
714 // patterns and propagation of loop IVs or constant would happen.
715 return false;
716}
717
718/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
719/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
720/// expression is in that form.
721static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
722 AffineExpr &quotientTimesDiv, AffineExpr &rem) {
723 auto bin = dyn_cast<AffineBinaryOpExpr>(Val&: e);
724 if (!bin || bin.getKind() != AffineExprKind::Add)
725 return false;
726
727 AffineExpr llhs = bin.getLHS();
728 AffineExpr rlhs = bin.getRHS();
729 div = getLargestKnownDivisor(e: llhs, operands);
730 if (isNonNegativeBoundedBy(e: rlhs, operands, k: div)) {
731 quotientTimesDiv = llhs;
732 rem = rlhs;
733 return true;
734 }
735 div = getLargestKnownDivisor(e: rlhs, operands);
736 if (isNonNegativeBoundedBy(e: llhs, operands, k: div)) {
737 quotientTimesDiv = rlhs;
738 rem = llhs;
739 return true;
740 }
741 return false;
742}
743
744/// Gets the constant lower bound on an `iv`.
745static std::optional<int64_t> getLowerBound(Value iv) {
746 AffineForOp forOp = getForInductionVarOwner(iv);
747 if (forOp && forOp.hasConstantLowerBound())
748 return forOp.getConstantLowerBound();
749 return std::nullopt;
750}
751
752/// Gets the constant upper bound on an affine.for `iv`.
753static std::optional<int64_t> getUpperBound(Value iv) {
754 AffineForOp forOp = getForInductionVarOwner(iv);
755 if (!forOp || !forOp.hasConstantUpperBound())
756 return std::nullopt;
757
758 // If its lower bound is also known, we can get a more precise bound
759 // whenever the step is not one.
760 if (forOp.hasConstantLowerBound()) {
761 return forOp.getConstantUpperBound() - 1 -
762 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
763 forOp.getStepAsInt();
764 }
765 return forOp.getConstantUpperBound() - 1;
766}
767
768/// Determine a constant upper bound for `expr` if one exists while exploiting
769/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
770/// is guaranteed to be less than or equal to it.
771static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
772 unsigned numSymbols,
773 ArrayRef<Value> operands) {
774 // Get the constant lower or upper bounds on the operands.
775 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
776 constLowerBounds.reserve(N: operands.size());
777 constUpperBounds.reserve(N: operands.size());
778 for (Value operand : operands) {
779 constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
780 constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
781 }
782
783 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr))
784 return constExpr.getValue();
785
786 return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
787 constUpperBounds,
788 /*isUpper=*/true);
789}
790
791/// Determine a constant lower bound for `expr` if one exists while exploiting
792/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
793/// is guaranteed to be less than or equal to it.
794static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
795 unsigned numSymbols,
796 ArrayRef<Value> operands) {
797 // Get the constant lower or upper bounds on the operands.
798 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
799 constLowerBounds.reserve(N: operands.size());
800 constUpperBounds.reserve(N: operands.size());
801 for (Value operand : operands) {
802 constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
803 constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
804 }
805
806 std::optional<int64_t> lowerBound;
807 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) {
808 lowerBound = constExpr.getValue();
809 } else {
810 lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
811 constLowerBounds, constUpperBounds,
812 /*isUpper=*/false);
813 }
814 return lowerBound;
815}
816
817/// Simplify `expr` while exploiting information from the values in `operands`.
818static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
819 unsigned numSymbols,
820 ArrayRef<Value> operands) {
821 // We do this only for certain floordiv/mod expressions.
822 auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
823 if (!binExpr)
824 return;
825
826 // Simplify the child expressions first.
827 AffineExpr lhs = binExpr.getLHS();
828 AffineExpr rhs = binExpr.getRHS();
829 simplifyExprAndOperands(expr&: lhs, numDims, numSymbols, operands);
830 simplifyExprAndOperands(expr&: rhs, numDims, numSymbols, operands);
831 expr = getAffineBinaryOpExpr(kind: binExpr.getKind(), lhs, rhs);
832
833 binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
834 if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
835 expr.getKind() != AffineExprKind::CeilDiv &&
836 expr.getKind() != AffineExprKind::Mod)) {
837 return;
838 }
839
840 // The `lhs` and `rhs` may be different post construction of simplified expr.
841 lhs = binExpr.getLHS();
842 rhs = binExpr.getRHS();
843 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
844 if (!rhsConst)
845 return;
846
847 int64_t rhsConstVal = rhsConst.getValue();
848 // Undefined exprsessions aren't touched; IR can still be valid with them.
849 if (rhsConstVal <= 0)
850 return;
851
852 // Exploit constant lower/upper bounds to simplify a floordiv or mod.
853 MLIRContext *context = expr.getContext();
854 std::optional<int64_t> lhsLbConst =
855 getLowerBound(expr: lhs, numDims, numSymbols, operands);
856 std::optional<int64_t> lhsUbConst =
857 getUpperBound(expr: lhs, numDims, numSymbols, operands);
858 if (lhsLbConst && lhsUbConst) {
859 int64_t lhsLbConstVal = *lhsLbConst;
860 int64_t lhsUbConstVal = *lhsUbConst;
861 // lhs floordiv c is a single value lhs is bounded in a range `c` that has
862 // the same quotient.
863 if (binExpr.getKind() == AffineExprKind::FloorDiv &&
864 divideFloorSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal) ==
865 divideFloorSigned(Numerator: lhsUbConstVal, Denominator: rhsConstVal)) {
866 expr = getAffineConstantExpr(
867 constant: divideFloorSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal), context);
868 return;
869 }
870 // lhs ceildiv c is a single value if the entire range has the same ceil
871 // quotient.
872 if (binExpr.getKind() == AffineExprKind::CeilDiv &&
873 divideCeilSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal) ==
874 divideCeilSigned(Numerator: lhsUbConstVal, Denominator: rhsConstVal)) {
875 expr = getAffineConstantExpr(constant: divideCeilSigned(Numerator: lhsLbConstVal, Denominator: rhsConstVal),
876 context);
877 return;
878 }
879 // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
880 if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
881 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
882 expr = lhs;
883 return;
884 }
885 }
886
887 // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
888 // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
889 // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
890 // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
891 AffineExpr quotientTimesDiv, rem;
892 int64_t divisor;
893 if (isQTimesDPlusR(e: lhs, operands, div&: divisor, quotientTimesDiv, rem)) {
894 if (rhsConstVal % divisor == 0 &&
895 binExpr.getKind() == AffineExprKind::FloorDiv) {
896 expr = quotientTimesDiv.floorDiv(other: rhsConst);
897 } else if (divisor % rhsConstVal == 0 &&
898 binExpr.getKind() == AffineExprKind::Mod) {
899 expr = rem % rhsConst;
900 }
901 return;
902 }
903
904 // Handle the simple case when the LHS expression can be either upper
905 // bounded or is a known multiple of RHS constant.
906 // lhs floordiv c -> 0 if 0 <= lhs < c,
907 // lhs mod c -> 0 if lhs % c = 0.
908 if ((isNonNegativeBoundedBy(e: lhs, operands, k: rhsConstVal) &&
909 binExpr.getKind() == AffineExprKind::FloorDiv) ||
910 (getLargestKnownDivisor(e: lhs, operands) % rhsConstVal == 0 &&
911 binExpr.getKind() == AffineExprKind::Mod)) {
912 expr = getAffineConstantExpr(constant: 0, context: expr.getContext());
913 }
914}
915
916/// Simplify the expressions in `map` while making use of lower or upper bounds
917/// of its operands. If `isMax` is true, the map is to be treated as a max of
918/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
919/// d1) can be simplified to (8) if the operands are respectively lower bounded
920/// by 2 and 0 (the second expression can't be lower than 8).
921static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
922 ArrayRef<Value> operands,
923 bool isMax) {
924 // Can't simplify.
925 if (operands.empty())
926 return;
927
928 // Get the upper or lower bound on an affine.for op IV using its range.
929 // Get the constant lower or upper bounds on the operands.
930 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
931 constLowerBounds.reserve(N: operands.size());
932 constUpperBounds.reserve(N: operands.size());
933 for (Value operand : operands) {
934 constLowerBounds.push_back(Elt: getLowerBound(iv: operand));
935 constUpperBounds.push_back(Elt: getUpperBound(iv: operand));
936 }
937
938 // We will compute the lower and upper bounds on each of the expressions
939 // Then, we will check (depending on max or min) as to whether a specific
940 // bound is redundant by checking if its highest (in case of max) and its
941 // lowest (in the case of min) value is already lower than (or higher than)
942 // the lower bound (or upper bound in the case of min) of another bound.
943 SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
944 lowerBounds.reserve(N: map.getNumResults());
945 upperBounds.reserve(N: map.getNumResults());
946 for (AffineExpr e : map.getResults()) {
947 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: e)) {
948 lowerBounds.push_back(Elt: constExpr.getValue());
949 upperBounds.push_back(Elt: constExpr.getValue());
950 } else {
951 lowerBounds.push_back(
952 Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
953 constLowerBounds, constUpperBounds,
954 /*isUpper=*/false));
955 upperBounds.push_back(
956 Elt: getBoundForAffineExpr(expr: e, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
957 constLowerBounds, constUpperBounds,
958 /*isUpper=*/true));
959 }
960 }
961
962 // Collect expressions that are not redundant.
963 SmallVector<AffineExpr, 4> irredundantExprs;
964 for (auto exprEn : llvm::enumerate(First: map.getResults())) {
965 AffineExpr e = exprEn.value();
966 unsigned i = exprEn.index();
967 // Some expressions can be turned into constants.
968 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
969 e = getAffineConstantExpr(constant: *lowerBounds[i], context: e.getContext());
970
971 // Check if the expression is redundant.
972 if (isMax) {
973 if (!upperBounds[i]) {
974 irredundantExprs.push_back(Elt: e);
975 continue;
976 }
977 // If there exists another expression such that its lower bound is greater
978 // than this expression's upper bound, it's redundant.
979 if (!llvm::any_of(Range: llvm::enumerate(First&: lowerBounds), P: [&](const auto &en) {
980 auto otherLowerBound = en.value();
981 unsigned pos = en.index();
982 if (pos == i || !otherLowerBound)
983 return false;
984 if (*otherLowerBound > *upperBounds[i])
985 return true;
986 if (*otherLowerBound < *upperBounds[i])
987 return false;
988 // Equality case. When both expressions are considered redundant, we
989 // don't want to get both of them. We keep the one that appears
990 // first.
991 if (upperBounds[pos] && lowerBounds[i] &&
992 lowerBounds[i] == upperBounds[i] &&
993 otherLowerBound == *upperBounds[pos] && i < pos)
994 return false;
995 return true;
996 }))
997 irredundantExprs.push_back(Elt: e);
998 } else {
999 if (!lowerBounds[i]) {
1000 irredundantExprs.push_back(Elt: e);
1001 continue;
1002 }
1003 // Likewise for the `min` case. Use the complement of the condition above.
1004 if (!llvm::any_of(Range: llvm::enumerate(First&: upperBounds), P: [&](const auto &en) {
1005 auto otherUpperBound = en.value();
1006 unsigned pos = en.index();
1007 if (pos == i || !otherUpperBound)
1008 return false;
1009 if (*otherUpperBound < *lowerBounds[i])
1010 return true;
1011 if (*otherUpperBound > *lowerBounds[i])
1012 return false;
1013 if (lowerBounds[pos] && upperBounds[i] &&
1014 lowerBounds[i] == upperBounds[i] &&
1015 otherUpperBound == lowerBounds[pos] && i < pos)
1016 return false;
1017 return true;
1018 }))
1019 irredundantExprs.push_back(Elt: e);
1020 }
1021 }
1022
1023 // Create the map without the redundant expressions.
1024 map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: irredundantExprs,
1025 context: map.getContext());
1026}
1027
1028/// Simplify the map while exploiting information on the values in `operands`.
1029// Use "unused attribute" marker to silence warning stemming from the inability
1030// to see through the template expansion.
1031static void LLVM_ATTRIBUTE_UNUSED
1032simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
1033 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1034 SmallVector<AffineExpr> newResults;
1035 newResults.reserve(N: map.getNumResults());
1036 for (AffineExpr expr : map.getResults()) {
1037 simplifyExprAndOperands(expr, numDims: map.getNumDims(), numSymbols: map.getNumSymbols(),
1038 operands);
1039 newResults.push_back(Elt: expr);
1040 }
1041 map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newResults,
1042 context: map.getContext());
1043}
1044
1045/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1046/// defining AffineApplyOp expression and operands.
1047/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1048/// When `dimOrSymbolPosition >= dims.size()`,
1049/// AffineSymbolExpr@[pos - dims.size()] is replaced.
1050/// Mutate `map`,`dims` and `syms` in place as follows:
1051/// 1. `dims` and `syms` are only appended to.
1052/// 2. `map` dim and symbols are gradually shifted to higher positions.
1053/// 3. Old `dim` and `sym` entries are replaced by nullptr
1054/// This avoids the need for any bookkeeping.
1055static LogicalResult replaceDimOrSym(AffineMap *map,
1056 unsigned dimOrSymbolPosition,
1057 SmallVectorImpl<Value> &dims,
1058 SmallVectorImpl<Value> &syms) {
1059 MLIRContext *ctx = map->getContext();
1060 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1061 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1062 : dimOrSymbolPosition - dims.size();
1063 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1064 if (!v)
1065 return failure();
1066
1067 auto affineApply = v.getDefiningOp<AffineApplyOp>();
1068 if (!affineApply)
1069 return failure();
1070
1071 // At this point we will perform a replacement of `v`, set the entry in `dim`
1072 // or `sym` to nullptr immediately.
1073 v = nullptr;
1074
1075 // Compute the map, dims and symbols coming from the AffineApplyOp.
1076 AffineMap composeMap = affineApply.getAffineMap();
1077 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1078 SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1079 affineApply.getMapOperands().end());
1080 // Canonicalize the map to promote dims to symbols when possible. This is to
1081 // avoid generating invalid maps.
1082 canonicalizeMapAndOperands(map: &composeMap, operands: &composeOperands);
1083 AffineExpr replacementExpr =
1084 composeMap.shiftDims(shift: dims.size()).shiftSymbols(shift: syms.size()).getResult(idx: 0);
1085 ValueRange composeDims =
1086 ArrayRef<Value>(composeOperands).take_front(N: composeMap.getNumDims());
1087 ValueRange composeSyms =
1088 ArrayRef<Value>(composeOperands).take_back(N: composeMap.getNumSymbols());
1089 AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(position: pos, context: ctx)
1090 : getAffineSymbolExpr(position: pos, context: ctx);
1091
1092 // Append the dims and symbols where relevant and perform the replacement.
1093 dims.append(in_start: composeDims.begin(), in_end: composeDims.end());
1094 syms.append(in_start: composeSyms.begin(), in_end: composeSyms.end());
1095 *map = map->replace(expr: toReplace, replacement: replacementExpr, numResultDims: dims.size(), numResultSyms: syms.size());
1096
1097 return success();
1098}
1099
1100/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1101/// iteratively. Perform canonicalization of map and operands as well as
1102/// AffineMap simplification. `map` and `operands` are mutated in place.
1103static void composeAffineMapAndOperands(AffineMap *map,
1104 SmallVectorImpl<Value> *operands) {
1105 if (map->getNumResults() == 0) {
1106 canonicalizeMapAndOperands(map, operands);
1107 *map = simplifyAffineMap(map: *map);
1108 return;
1109 }
1110
1111 MLIRContext *ctx = map->getContext();
1112 SmallVector<Value, 4> dims(operands->begin(),
1113 operands->begin() + map->getNumDims());
1114 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1115 operands->end());
1116
1117 // Iterate over dims and symbols coming from AffineApplyOp and replace until
1118 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1119 // and `syms` can only increase by construction.
1120 // The implementation uses a `while` loop to support the case of symbols
1121 // that may be constructed from dims ;this may be overkill.
1122 while (true) {
1123 bool changed = false;
1124 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1125 if ((changed |= succeeded(Result: replaceDimOrSym(map, dimOrSymbolPosition: pos, dims, syms))))
1126 break;
1127 if (!changed)
1128 break;
1129 }
1130
1131 // Clear operands so we can fill them anew.
1132 operands->clear();
1133
1134 // At this point we may have introduced null operands, prune them out before
1135 // canonicalizing map and operands.
1136 unsigned nDims = 0, nSyms = 0;
1137 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1138 dimReplacements.reserve(N: dims.size());
1139 symReplacements.reserve(N: syms.size());
1140 for (auto *container : {&dims, &syms}) {
1141 bool isDim = (container == &dims);
1142 auto &repls = isDim ? dimReplacements : symReplacements;
1143 for (const auto &en : llvm::enumerate(First&: *container)) {
1144 Value v = en.value();
1145 if (!v) {
1146 assert(isDim ? !map->isFunctionOfDim(en.index())
1147 : !map->isFunctionOfSymbol(en.index()) &&
1148 "map is function of unexpected expr@pos");
1149 repls.push_back(Elt: getAffineConstantExpr(constant: 0, context: ctx));
1150 continue;
1151 }
1152 repls.push_back(Elt: isDim ? getAffineDimExpr(position: nDims++, context: ctx)
1153 : getAffineSymbolExpr(position: nSyms++, context: ctx));
1154 operands->push_back(Elt: v);
1155 }
1156 }
1157 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, numResultDims: nDims,
1158 numResultSyms: nSyms);
1159
1160 // Canonicalize and simplify before returning.
1161 canonicalizeMapAndOperands(map, operands);
1162 *map = simplifyAffineMap(map: *map);
1163}
1164
1165void mlir::affine::fullyComposeAffineMapAndOperands(
1166 AffineMap *map, SmallVectorImpl<Value> *operands) {
1167 while (llvm::any_of(Range&: *operands, P: [](Value v) {
1168 return isa_and_nonnull<AffineApplyOp>(Val: v.getDefiningOp());
1169 })) {
1170 composeAffineMapAndOperands(map, operands);
1171 }
1172}
1173
1174AffineApplyOp
1175mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1176 ArrayRef<OpFoldResult> operands) {
1177 SmallVector<Value> valueOperands;
1178 map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands);
1179 composeAffineMapAndOperands(map: &map, operands: &valueOperands);
1180 assert(map);
1181 return b.create<AffineApplyOp>(loc, map, valueOperands);
1182}
1183
1184AffineApplyOp
1185mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1186 ArrayRef<OpFoldResult> operands) {
1187 return makeComposedAffineApply(
1188 b, loc,
1189 AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{e}, context: b.getContext())
1190 .front(),
1191 operands);
1192}
1193
1194/// Composes the given affine map with the given list of operands, pulling in
1195/// the maps from any affine.apply operations that supply the operands.
1196static void composeMultiResultAffineMap(AffineMap &map,
1197 SmallVectorImpl<Value> &operands) {
1198 // Compose and canonicalize each expression in the map individually because
1199 // composition only applies to single-result maps, collecting potentially
1200 // duplicate operands in a single list with shifted dimensions and symbols.
1201 SmallVector<Value> dims, symbols;
1202 SmallVector<AffineExpr> exprs;
1203 for (unsigned i : llvm::seq<unsigned>(Begin: 0, End: map.getNumResults())) {
1204 SmallVector<Value> submapOperands(operands.begin(), operands.end());
1205 AffineMap submap = map.getSubMap(resultPos: {i});
1206 fullyComposeAffineMapAndOperands(map: &submap, operands: &submapOperands);
1207 canonicalizeMapAndOperands(map: &submap, operands: &submapOperands);
1208 unsigned numNewDims = submap.getNumDims();
1209 submap = submap.shiftDims(shift: dims.size()).shiftSymbols(shift: symbols.size());
1210 llvm::append_range(C&: dims,
1211 R: ArrayRef<Value>(submapOperands).take_front(N: numNewDims));
1212 llvm::append_range(C&: symbols,
1213 R: ArrayRef<Value>(submapOperands).drop_front(N: numNewDims));
1214 exprs.push_back(Elt: submap.getResult(idx: 0));
1215 }
1216
1217 // Canonicalize the map created from composed expressions to deduplicate the
1218 // dimension and symbol operands.
1219 operands = llvm::to_vector(Range: llvm::concat<Value>(Ranges&: dims, Ranges&: symbols));
1220 map = AffineMap::get(dimCount: dims.size(), symbolCount: symbols.size(), results: exprs, context: map.getContext());
1221 canonicalizeMapAndOperands(map: &map, operands: &operands);
1222}
1223
1224OpFoldResult
1225mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1226 AffineMap map,
1227 ArrayRef<OpFoldResult> operands) {
1228 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1229
1230 // Create new builder without a listener, so that no notification is
1231 // triggered if the op is folded.
1232 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1233 // workaround is no longer needed.
1234 OpBuilder newBuilder(b.getContext());
1235 newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());
1236
1237 // Create op.
1238 AffineApplyOp applyOp =
1239 makeComposedAffineApply(newBuilder, loc, map, operands);
1240
1241 // Get constant operands.
1242 SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1243 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1244 matchPattern(applyOp->getOperand(i), m_Constant(bind_value: &constOperands[i]));
1245
1246 // Try to fold the operation.
1247 SmallVector<OpFoldResult> foldResults;
1248 if (failed(applyOp->fold(constOperands, foldResults)) ||
1249 foldResults.empty()) {
1250 if (OpBuilder::Listener *listener = b.getListener())
1251 listener->notifyOperationInserted(op: applyOp, /*previous=*/{});
1252 return applyOp.getResult();
1253 }
1254
1255 applyOp->erase();
1256 return llvm::getSingleElement(C&: foldResults);
1257}
1258
1259OpFoldResult
1260mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1261 AffineExpr expr,
1262 ArrayRef<OpFoldResult> operands) {
1263 return makeComposedFoldedAffineApply(
1264 b, loc,
1265 map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{expr}, context: b.getContext())
1266 .front(),
1267 operands);
1268}
1269
1270SmallVector<OpFoldResult>
1271mlir::affine::makeComposedFoldedMultiResultAffineApply(
1272 OpBuilder &b, Location loc, AffineMap map,
1273 ArrayRef<OpFoldResult> operands) {
1274 return llvm::map_to_vector(C: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults()),
1275 F: [&](unsigned i) {
1276 return makeComposedFoldedAffineApply(
1277 b, loc, map: map.getSubMap(resultPos: {i}), operands);
1278 });
1279}
1280
1281template <typename OpTy>
1282static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
1283 ArrayRef<OpFoldResult> operands) {
1284 SmallVector<Value> valueOperands;
1285 map = foldAttributesIntoMap(b, map, operands, remainingValues&: valueOperands);
1286 composeMultiResultAffineMap(map, operands&: valueOperands);
1287 return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1288}
1289
/// Creates an AffineMinOp at `loc` for `map` applied to `operands` by
/// delegating to makeComposedMinMax, which folds attribute operands into the
/// map and composes it with the affine.apply ops producing the operands.
AffineMinOp
mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
                                    ArrayRef<OpFoldResult> operands) {
  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
}
1295
1296template <typename OpTy>
1297static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
1298 AffineMap map,
1299 ArrayRef<OpFoldResult> operands) {
1300 // Create new builder without a listener, so that no notification is
1301 // triggered if the op is folded.
1302 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1303 // workaround is no longer needed.
1304 OpBuilder newBuilder(b.getContext());
1305 newBuilder.setInsertionPoint(block: b.getInsertionBlock(), insertPoint: b.getInsertionPoint());
1306
1307 // Create op.
1308 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1309
1310 // Get constant operands.
1311 SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1312 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1313 matchPattern(minMaxOp->getOperand(i), m_Constant(bind_value: &constOperands[i]));
1314
1315 // Try to fold the operation.
1316 SmallVector<OpFoldResult> foldResults;
1317 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1318 foldResults.empty()) {
1319 if (OpBuilder::Listener *listener = b.getListener())
1320 listener->notifyOperationInserted(op: minMaxOp, /*previous=*/{});
1321 return minMaxOp.getResult();
1322 }
1323
1324 minMaxOp->erase();
1325 return llvm::getSingleElement(C&: foldResults);
1326}
1327
/// Builds a composed affine.min for `map` and `operands` and attempts to fold
/// it immediately by delegating to makeComposedFoldedMinMax<AffineMinOp>.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
}
1334
/// Builds a composed affine.max for `map` and `operands` and attempts to fold
/// it immediately by delegating to makeComposedFoldedMinMax<AffineMaxOp>.
OpFoldResult
mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
}
1341
1342// A symbol may appear as a dim in affine.apply operations. This function
1343// canonicalizes dims that are valid symbols into actual symbols.
1344template <class MapOrSet>
1345static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1346 SmallVectorImpl<Value> *operands) {
1347 if (!mapOrSet || operands->empty())
1348 return;
1349
1350 assert(mapOrSet->getNumInputs() == operands->size() &&
1351 "map/set inputs must match number of operands");
1352
1353 auto *context = mapOrSet->getContext();
1354 SmallVector<Value, 8> resultOperands;
1355 resultOperands.reserve(N: operands->size());
1356 SmallVector<Value, 8> remappedSymbols;
1357 remappedSymbols.reserve(N: operands->size());
1358 unsigned nextDim = 0;
1359 unsigned nextSym = 0;
1360 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1361 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1362 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1363 if (i < mapOrSet->getNumDims()) {
1364 if (isValidSymbol(value: (*operands)[i])) {
1365 // This is a valid symbol that appears as a dim, canonicalize it.
1366 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1367 remappedSymbols.push_back(Elt: (*operands)[i]);
1368 } else {
1369 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1370 resultOperands.push_back(Elt: (*operands)[i]);
1371 }
1372 } else {
1373 resultOperands.push_back(Elt: (*operands)[i]);
1374 }
1375 }
1376
1377 resultOperands.append(in_start: remappedSymbols.begin(), in_end: remappedSymbols.end());
1378 *operands = resultOperands;
1379 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1380 dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1381
1382 assert(mapOrSet->getNumInputs() == operands->size() &&
1383 "map/set inputs must match number of operands");
1384}
1385
1386/// A valid affine dimension may appear as a symbol in affine.apply operations.
1387/// Given an application of `operands` to an affine map or integer set
1388/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
1389/// dims, but not valid symbols into actual dims. Without such a legalization,
1390/// the affine.apply will be invalid. This method is the exact inverse of
1391/// canonicalizePromotedSymbols.
1392template <class MapOrSet>
1393static void legalizeDemotedDims(MapOrSet &mapOrSet,
1394 SmallVectorImpl<Value> &operands) {
1395 if (!mapOrSet || operands.empty())
1396 return;
1397
1398 unsigned numOperands = operands.size();
1399
1400 assert(mapOrSet.getNumInputs() == numOperands &&
1401 "map/set inputs must match number of operands");
1402
1403 auto *context = mapOrSet.getContext();
1404 SmallVector<Value, 8> resultOperands;
1405 resultOperands.reserve(N: numOperands);
1406 SmallVector<Value, 8> remappedDims;
1407 remappedDims.reserve(N: numOperands);
1408 SmallVector<Value, 8> symOperands;
1409 symOperands.reserve(N: mapOrSet.getNumSymbols());
1410 unsigned nextSym = 0;
1411 unsigned nextDim = 0;
1412 unsigned oldNumDims = mapOrSet.getNumDims();
1413 SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
1414 resultOperands.assign(in_start: operands.begin(), in_end: operands.begin() + oldNumDims);
1415 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1416 if (operands[i] && isValidDim(value: operands[i]) && !isValidSymbol(value: operands[i])) {
1417 // This is a valid dim that appears as a symbol, legalize it.
1418 symRemapping[i - oldNumDims] =
1419 getAffineDimExpr(oldNumDims + nextDim++, context);
1420 remappedDims.push_back(Elt: operands[i]);
1421 } else {
1422 symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1423 symOperands.push_back(Elt: operands[i]);
1424 }
1425 }
1426
1427 append_range(C&: resultOperands, R&: remappedDims);
1428 append_range(C&: resultOperands, R&: symOperands);
1429 operands = resultOperands;
1430 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1431 /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1432
1433 assert(mapOrSet.getNumInputs() == operands.size() &&
1434 "map/set inputs must match number of operands");
1435}
1436
1437// Works for either an affine map or an integer set.
1438template <class MapOrSet>
1439static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1440 SmallVectorImpl<Value> *operands) {
1441 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1442 "Argument must be either of AffineMap or IntegerSet type");
1443
1444 if (!mapOrSet || operands->empty())
1445 return;
1446
1447 assert(mapOrSet->getNumInputs() == operands->size() &&
1448 "map/set inputs must match number of operands");
1449
1450 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1451 legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
1452
1453 // Check to see what dims are used.
1454 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1455 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1456 mapOrSet->walkExprs([&](AffineExpr expr) {
1457 if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr))
1458 usedDims[dimExpr.getPosition()] = true;
1459 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(Val&: expr))
1460 usedSyms[symExpr.getPosition()] = true;
1461 });
1462
1463 auto *context = mapOrSet->getContext();
1464
1465 SmallVector<Value, 8> resultOperands;
1466 resultOperands.reserve(N: operands->size());
1467
1468 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1469 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1470 unsigned nextDim = 0;
1471 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1472 if (usedDims[i]) {
1473 // Remap dim positions for duplicate operands.
1474 auto it = seenDims.find(Val: (*operands)[i]);
1475 if (it == seenDims.end()) {
1476 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1477 resultOperands.push_back(Elt: (*operands)[i]);
1478 seenDims.insert(KV: std::make_pair(x&: (*operands)[i], y&: dimRemapping[i]));
1479 } else {
1480 dimRemapping[i] = it->second;
1481 }
1482 }
1483 }
1484 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1485 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1486 unsigned nextSym = 0;
1487 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1488 if (!usedSyms[i])
1489 continue;
1490 // Handle constant operands (only needed for symbolic operands since
1491 // constant operands in dimensional positions would have already been
1492 // promoted to symbolic positions above).
1493 IntegerAttr operandCst;
1494 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1495 m_Constant(&operandCst))) {
1496 symRemapping[i] =
1497 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1498 continue;
1499 }
1500 // Remap symbol positions for duplicate operands.
1501 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1502 if (it == seenSymbols.end()) {
1503 symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1504 resultOperands.push_back(Elt: (*operands)[i + mapOrSet->getNumDims()]);
1505 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1506 symRemapping[i]));
1507 } else {
1508 symRemapping[i] = it->second;
1509 }
1510 }
1511 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1512 nextDim, nextSym);
1513 *operands = resultOperands;
1514}
1515
1516void mlir::affine::canonicalizeMapAndOperands(
1517 AffineMap *map, SmallVectorImpl<Value> *operands) {
1518 canonicalizeMapOrSetAndOperands<AffineMap>(mapOrSet: map, operands);
1519}
1520
1521void mlir::affine::canonicalizeSetAndOperands(
1522 IntegerSet *set, SmallVectorImpl<Value> *operands) {
1523 canonicalizeMapOrSetAndOperands<IntegerSet>(mapOrSet: set, operands);
1524}
1525
1526namespace {
1527/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1528/// maps that supply results into them.
1529///
1530template <typename AffineOpTy>
1531struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1532 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1533
1534 /// Replace the affine op with another instance of it with the supplied
1535 /// map and mapOperands.
1536 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1537 AffineMap map, ArrayRef<Value> mapOperands) const;
1538
1539 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1540 PatternRewriter &rewriter) const override {
1541 static_assert(
1542 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1543 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1544 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1545 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1546 "expected");
1547 auto map = affineOp.getAffineMap();
1548 AffineMap oldMap = map;
1549 auto oldOperands = affineOp.getMapOperands();
1550 SmallVector<Value, 8> resultOperands(oldOperands);
1551 composeAffineMapAndOperands(&map, &resultOperands);
1552 canonicalizeMapAndOperands(&map, &resultOperands);
1553 simplifyMapWithOperands(map, resultOperands);
1554 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1555 resultOperands.begin()))
1556 return failure();
1557
1558 replaceAffineOp(rewriter, affineOp, map, mapOperands: resultOperands);
1559 return success();
1560 }
1561};
1562
1563// Specialize the template to account for the different build signatures for
1564// affine load, store, and apply ops.
1565template <>
1566void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1567 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1568 ArrayRef<Value> mapOperands) const {
1569 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1570 mapOperands);
1571}
1572template <>
1573void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1574 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1575 ArrayRef<Value> mapOperands) const {
1576 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1577 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1578 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1579}
1580template <>
1581void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1582 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1583 ArrayRef<Value> mapOperands) const {
1584 rewriter.replaceOpWithNewOp<AffineStoreOp>(
1585 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1586}
1587template <>
1588void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1589 PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1590 ArrayRef<Value> mapOperands) const {
1591 rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1592 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1593 mapOperands);
1594}
1595template <>
1596void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1597 PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1598 ArrayRef<Value> mapOperands) const {
1599 rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1600 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1601 mapOperands);
1602}
1603
1604// Generic version for ops that don't have extra operands.
1605template <typename AffineOpTy>
1606void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1607 PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1608 ArrayRef<Value> mapOperands) const {
1609 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1610}
1611} // namespace
1612
1613void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1614 MLIRContext *context) {
1615 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1616}
1617
1618//===----------------------------------------------------------------------===//
1619// AffineDmaStartOp
1620//===----------------------------------------------------------------------===//
1621
1622// TODO: Check that map operands are loop IVs or symbols.
1623void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1624 Value srcMemRef, AffineMap srcMap,
1625 ValueRange srcIndices, Value destMemRef,
1626 AffineMap dstMap, ValueRange destIndices,
1627 Value tagMemRef, AffineMap tagMap,
1628 ValueRange tagIndices, Value numElements,
1629 Value stride, Value elementsPerStride) {
1630 result.addOperands(newOperands: srcMemRef);
1631 result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1632 result.addOperands(newOperands: srcIndices);
1633 result.addOperands(newOperands: destMemRef);
1634 result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1635 result.addOperands(newOperands: destIndices);
1636 result.addOperands(newOperands: tagMemRef);
1637 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1638 result.addOperands(newOperands: tagIndices);
1639 result.addOperands(newOperands: numElements);
1640 if (stride) {
1641 result.addOperands(newOperands: {stride, elementsPerStride});
1642 }
1643}
1644
1645void AffineDmaStartOp::print(OpAsmPrinter &p) {
1646 p << " " << getSrcMemRef() << '[';
1647 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1648 p << "], " << getDstMemRef() << '[';
1649 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1650 p << "], " << getTagMemRef() << '[';
1651 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1652 p << "], " << getNumElements();
1653 if (isStrided()) {
1654 p << ", " << getStride();
1655 p << ", " << getNumElementsPerStride();
1656 }
1657 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1658 << getTagMemRefType();
1659}
1660
1661// Parse AffineDmaStartOp.
1662// Ex:
1663// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1664// %stride, %num_elt_per_stride
1665// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1666//
1667ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1668 OperationState &result) {
1669 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1670 AffineMapAttr srcMapAttr;
1671 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
1672 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1673 AffineMapAttr dstMapAttr;
1674 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
1675 OpAsmParser::UnresolvedOperand tagMemRefInfo;
1676 AffineMapAttr tagMapAttr;
1677 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
1678 OpAsmParser::UnresolvedOperand numElementsInfo;
1679 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1680
1681 SmallVector<Type, 3> types;
1682 auto indexType = parser.getBuilder().getIndexType();
1683
1684 // Parse and resolve the following list of operands:
1685 // *) dst memref followed by its affine maps operands (in square brackets).
1686 // *) src memref followed by its affine map operands (in square brackets).
1687 // *) tag memref followed by its affine map operands (in square brackets).
1688 // *) number of elements transferred by DMA operation.
1689 if (parser.parseOperand(result&: srcMemRefInfo) ||
1690 parser.parseAffineMapOfSSAIds(operands&: srcMapOperands, map&: srcMapAttr,
1691 attrName: getSrcMapAttrStrName(),
1692 attrs&: result.attributes) ||
1693 parser.parseComma() || parser.parseOperand(result&: dstMemRefInfo) ||
1694 parser.parseAffineMapOfSSAIds(operands&: dstMapOperands, map&: dstMapAttr,
1695 attrName: getDstMapAttrStrName(),
1696 attrs&: result.attributes) ||
1697 parser.parseComma() || parser.parseOperand(result&: tagMemRefInfo) ||
1698 parser.parseAffineMapOfSSAIds(operands&: tagMapOperands, map&: tagMapAttr,
1699 attrName: getTagMapAttrStrName(),
1700 attrs&: result.attributes) ||
1701 parser.parseComma() || parser.parseOperand(result&: numElementsInfo))
1702 return failure();
1703
1704 // Parse optional stride and elements per stride.
1705 if (parser.parseTrailingOperandList(result&: strideInfo))
1706 return failure();
1707
1708 if (!strideInfo.empty() && strideInfo.size() != 2) {
1709 return parser.emitError(loc: parser.getNameLoc(),
1710 message: "expected two stride related operands");
1711 }
1712 bool isStrided = strideInfo.size() == 2;
1713
1714 if (parser.parseColonTypeList(result&: types))
1715 return failure();
1716
1717 if (types.size() != 3)
1718 return parser.emitError(loc: parser.getNameLoc(), message: "expected three types");
1719
1720 if (parser.resolveOperand(operand: srcMemRefInfo, type: types[0], result&: result.operands) ||
1721 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1722 parser.resolveOperand(operand: dstMemRefInfo, type: types[1], result&: result.operands) ||
1723 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1724 parser.resolveOperand(operand: tagMemRefInfo, type: types[2], result&: result.operands) ||
1725 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1726 parser.resolveOperand(operand: numElementsInfo, type: indexType, result&: result.operands))
1727 return failure();
1728
1729 if (isStrided) {
1730 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1731 return failure();
1732 }
1733
1734 // Check that src/dst/tag operand counts match their map.numInputs.
1735 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1736 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1737 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1738 return parser.emitError(loc: parser.getNameLoc(),
1739 message: "memref operand count not equal to map.numInputs");
1740 return success();
1741}
1742
1743LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1744 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1745 return emitOpError("expected DMA source to be of memref type");
1746 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1747 return emitOpError("expected DMA destination to be of memref type");
1748 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1749 return emitOpError("expected DMA tag to be of memref type");
1750
1751 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1752 getDstMap().getNumInputs() +
1753 getTagMap().getNumInputs();
1754 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1755 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1756 return emitOpError("incorrect number of operands");
1757 }
1758
1759 Region *scope = getAffineScope(*this);
1760 for (auto idx : getSrcIndices()) {
1761 if (!idx.getType().isIndex())
1762 return emitOpError("src index to dma_start must have 'index' type");
1763 if (!isValidAffineIndexOperand(idx, scope))
1764 return emitOpError(
1765 "src index must be a valid dimension or symbol identifier");
1766 }
1767 for (auto idx : getDstIndices()) {
1768 if (!idx.getType().isIndex())
1769 return emitOpError("dst index to dma_start must have 'index' type");
1770 if (!isValidAffineIndexOperand(idx, scope))
1771 return emitOpError(
1772 "dst index must be a valid dimension or symbol identifier");
1773 }
1774 for (auto idx : getTagIndices()) {
1775 if (!idx.getType().isIndex())
1776 return emitOpError("tag index to dma_start must have 'index' type");
1777 if (!isValidAffineIndexOperand(idx, scope))
1778 return emitOpError(
1779 "tag index must be a valid dimension or symbol identifier");
1780 }
1781 return success();
1782}
1783
1784LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1785 SmallVectorImpl<OpFoldResult> &results) {
1786 /// dma_start(memrefcast) -> dma_start
1787 return memref::foldMemRefCast(*this);
1788}
1789
1790void AffineDmaStartOp::getEffects(
1791 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1792 &effects) {
1793 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &getSrcMemRefMutable(),
1794 Args: SideEffects::DefaultResource::get());
1795 effects.emplace_back(Args: MemoryEffects::Write::get(), Args: &getDstMemRefMutable(),
1796 Args: SideEffects::DefaultResource::get());
1797 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &getTagMemRefMutable(),
1798 Args: SideEffects::DefaultResource::get());
1799}
1800
1801//===----------------------------------------------------------------------===//
1802// AffineDmaWaitOp
1803//===----------------------------------------------------------------------===//
1804
1805// TODO: Check that map operands are loop IVs or symbols.
1806void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1807 Value tagMemRef, AffineMap tagMap,
1808 ValueRange tagIndices, Value numElements) {
1809 result.addOperands(newOperands: tagMemRef);
1810 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1811 result.addOperands(newOperands: tagIndices);
1812 result.addOperands(newOperands: numElements);
1813}
1814
1815void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1816 p << " " << getTagMemRef() << '[';
1817 SmallVector<Value, 2> operands(getTagIndices());
1818 p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1819 p << "], ";
1820 p.printOperand(value: getNumElements());
1821 p << " : " << getTagMemRef().getType();
1822}
1823
1824// Parse AffineDmaWaitOp.
1825// Eg:
1826// affine.dma_wait %tag[%index], %num_elements
1827// : memref<1 x i32, (d0) -> (d0), 4>
1828//
1829ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1830 OperationState &result) {
1831 OpAsmParser::UnresolvedOperand tagMemRefInfo;
1832 AffineMapAttr tagMapAttr;
1833 SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
1834 Type type;
1835 auto indexType = parser.getBuilder().getIndexType();
1836 OpAsmParser::UnresolvedOperand numElementsInfo;
1837
1838 // Parse tag memref, its map operands, and dma size.
1839 if (parser.parseOperand(result&: tagMemRefInfo) ||
1840 parser.parseAffineMapOfSSAIds(operands&: tagMapOperands, map&: tagMapAttr,
1841 attrName: getTagMapAttrStrName(),
1842 attrs&: result.attributes) ||
1843 parser.parseComma() || parser.parseOperand(result&: numElementsInfo) ||
1844 parser.parseColonType(result&: type) ||
1845 parser.resolveOperand(operand: tagMemRefInfo, type, result&: result.operands) ||
1846 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1847 parser.resolveOperand(operand: numElementsInfo, type: indexType, result&: result.operands))
1848 return failure();
1849
1850 if (!llvm::isa<MemRefType>(Val: type))
1851 return parser.emitError(loc: parser.getNameLoc(),
1852 message: "expected tag to be of memref type");
1853
1854 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1855 return parser.emitError(loc: parser.getNameLoc(),
1856 message: "tag memref operand count != to map.numInputs");
1857 return success();
1858}
1859
1860LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1861 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1862 return emitOpError("expected DMA tag to be of memref type");
1863 Region *scope = getAffineScope(*this);
1864 for (auto idx : getTagIndices()) {
1865 if (!idx.getType().isIndex())
1866 return emitOpError("index to dma_wait must have 'index' type");
1867 if (!isValidAffineIndexOperand(idx, scope))
1868 return emitOpError(
1869 "index must be a valid dimension or symbol identifier");
1870 }
1871 return success();
1872}
1873
1874LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1875 SmallVectorImpl<OpFoldResult> &results) {
1876 /// dma_wait(memrefcast) -> dma_wait
1877 return memref::foldMemRefCast(*this);
1878}
1879
1880void AffineDmaWaitOp::getEffects(
1881 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1882 &effects) {
1883 effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &getTagMemRefMutable(),
1884 Args: SideEffects::DefaultResource::get());
1885}
1886
1887//===----------------------------------------------------------------------===//
1888// AffineForOp
1889//===----------------------------------------------------------------------===//
1890
1891/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1892/// bodyBuilder are empty/null, we include default terminator op.
1893void AffineForOp::build(OpBuilder &builder, OperationState &result,
1894 ValueRange lbOperands, AffineMap lbMap,
1895 ValueRange ubOperands, AffineMap ubMap, int64_t step,
1896 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1897 assert(((!lbMap && lbOperands.empty()) ||
1898 lbOperands.size() == lbMap.getNumInputs()) &&
1899 "lower bound operand count does not match the affine map");
1900 assert(((!ubMap && ubOperands.empty()) ||
1901 ubOperands.size() == ubMap.getNumInputs()) &&
1902 "upper bound operand count does not match the affine map");
1903 assert(step > 0 && "step has to be a positive integer constant");
1904
1905 OpBuilder::InsertionGuard guard(builder);
1906
1907 // Set variadic segment sizes.
1908 result.addAttribute(
1909 getOperandSegmentSizeAttr(),
1910 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1911 static_cast<int32_t>(ubOperands.size()),
1912 static_cast<int32_t>(iterArgs.size())}));
1913
1914 for (Value val : iterArgs)
1915 result.addTypes(val.getType());
1916
1917 // Add an attribute for the step.
1918 result.addAttribute(getStepAttrName(result.name),
1919 builder.getIntegerAttr(builder.getIndexType(), step));
1920
1921 // Add the lower bound.
1922 result.addAttribute(getLowerBoundMapAttrName(result.name),
1923 AffineMapAttr::get(lbMap));
1924 result.addOperands(lbOperands);
1925
1926 // Add the upper bound.
1927 result.addAttribute(getUpperBoundMapAttrName(result.name),
1928 AffineMapAttr::get(ubMap));
1929 result.addOperands(ubOperands);
1930
1931 result.addOperands(iterArgs);
1932 // Create a region and a block for the body. The argument of the region is
1933 // the loop induction variable.
1934 Region *bodyRegion = result.addRegion();
1935 Block *bodyBlock = builder.createBlock(bodyRegion);
1936 Value inductionVar =
1937 bodyBlock->addArgument(builder.getIndexType(), result.location);
1938 for (Value val : iterArgs)
1939 bodyBlock->addArgument(val.getType(), val.getLoc());
1940
1941 // Create the default terminator if the builder is not provided and if the
1942 // iteration arguments are not provided. Otherwise, leave this to the caller
1943 // because we don't know which values to return from the loop.
1944 if (iterArgs.empty() && !bodyBuilder) {
1945 ensureTerminator(*bodyRegion, builder, result.location);
1946 } else if (bodyBuilder) {
1947 OpBuilder::InsertionGuard guard(builder);
1948 builder.setInsertionPointToStart(bodyBlock);
1949 bodyBuilder(builder, result.location, inductionVar,
1950 bodyBlock->getArguments().drop_front());
1951 }
1952}
1953
1954void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1955 int64_t ub, int64_t step, ValueRange iterArgs,
1956 BodyBuilderFn bodyBuilder) {
1957 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1958 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1959 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1960 bodyBuilder);
1961}
1962
1963LogicalResult AffineForOp::verifyRegions() {
1964 // Check that the body defines as single block argument for the induction
1965 // variable.
1966 auto *body = getBody();
1967 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1968 return emitOpError("expected body to have a single index argument for the "
1969 "induction variable");
1970
1971 // Verify that the bound operands are valid dimension/symbols.
1972 /// Lower bound.
1973 if (getLowerBoundMap().getNumInputs() > 0)
1974 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
1975 getLowerBoundMap().getNumDims())))
1976 return failure();
1977 /// Upper bound.
1978 if (getUpperBoundMap().getNumInputs() > 0)
1979 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
1980 getUpperBoundMap().getNumDims())))
1981 return failure();
1982 if (getLowerBoundMap().getNumResults() < 1)
1983 return emitOpError("expected lower bound map to have at least one result");
1984 if (getUpperBoundMap().getNumResults() < 1)
1985 return emitOpError("expected upper bound map to have at least one result");
1986
1987 unsigned opNumResults = getNumResults();
1988 if (opNumResults == 0)
1989 return success();
1990
1991 // If ForOp defines values, check that the number and types of the defined
1992 // values match ForOp initial iter operands and backedge basic block
1993 // arguments.
1994 if (getNumIterOperands() != opNumResults)
1995 return emitOpError(
1996 "mismatch between the number of loop-carried values and results");
1997 if (getNumRegionIterArgs() != opNumResults)
1998 return emitOpError(
1999 "mismatch between the number of basic block args and results");
2000
2001 return success();
2002}
2003
2004/// Parse a for operation loop bounds.
2005static ParseResult parseBound(bool isLower, OperationState &result,
2006 OpAsmParser &p) {
2007 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2008 // the map has multiple results.
2009 bool failedToParsedMinMax =
2010 failed(Result: p.parseOptionalKeyword(keyword: isLower ? "max" : "min"));
2011
2012 auto &builder = p.getBuilder();
2013 auto boundAttrStrName =
2014 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2015 : AffineForOp::getUpperBoundMapAttrName(result.name);
2016
2017 // Parse ssa-id as identity map.
2018 SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
2019 if (p.parseOperandList(result&: boundOpInfos))
2020 return failure();
2021
2022 if (!boundOpInfos.empty()) {
2023 // Check that only one operand was parsed.
2024 if (boundOpInfos.size() > 1)
2025 return p.emitError(loc: p.getNameLoc(),
2026 message: "expected only one loop bound operand");
2027
2028 // TODO: improve error message when SSA value is not of index type.
2029 // Currently it is 'use of value ... expects different type than prior uses'
2030 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2031 result.operands))
2032 return failure();
2033
2034 // Create an identity map using symbol id. This representation is optimized
2035 // for storage. Analysis passes may expand it into a multi-dimensional map
2036 // if desired.
2037 AffineMap map = builder.getSymbolIdentityMap();
2038 result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2039 return success();
2040 }
2041
2042 // Get the attribute location.
2043 SMLoc attrLoc = p.getCurrentLocation();
2044
2045 Attribute boundAttr;
2046 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2047 result.attributes))
2048 return failure();
2049
2050 // Parse full form - affine map followed by dim and symbol list.
2051 if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
2052 unsigned currentNumOperands = result.operands.size();
2053 unsigned numDims;
2054 if (parseDimAndSymbolList(parser&: p, operands&: result.operands, numDims))
2055 return failure();
2056
2057 auto map = affineMapAttr.getValue();
2058 if (map.getNumDims() != numDims)
2059 return p.emitError(
2060 loc: p.getNameLoc(),
2061 message: "dim operand count and affine map dim count must match");
2062
2063 unsigned numDimAndSymbolOperands =
2064 result.operands.size() - currentNumOperands;
2065 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2066 return p.emitError(
2067 loc: p.getNameLoc(),
2068 message: "symbol operand count and affine map symbol count must match");
2069
2070 // If the map has multiple results, make sure that we parsed the min/max
2071 // prefix.
2072 if (map.getNumResults() > 1 && failedToParsedMinMax) {
2073 if (isLower) {
2074 return p.emitError(loc: attrLoc, message: "lower loop bound affine map with "
2075 "multiple results requires 'max' prefix");
2076 }
2077 return p.emitError(loc: attrLoc, message: "upper loop bound affine map with multiple "
2078 "results requires 'min' prefix");
2079 }
2080 return success();
2081 }
2082
2083 // Parse custom assembly form.
2084 if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
2085 result.attributes.pop_back();
2086 result.addAttribute(
2087 boundAttrStrName,
2088 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2089 return success();
2090 }
2091
2092 return p.emitError(
2093 loc: p.getNameLoc(),
2094 message: "expected valid affine map representation for loop bounds");
2095}
2096
2097ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2098 auto &builder = parser.getBuilder();
2099 OpAsmParser::Argument inductionVariable;
2100 inductionVariable.type = builder.getIndexType();
2101 // Parse the induction variable followed by '='.
2102 if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2103 return failure();
2104
2105 // Parse loop bounds.
2106 int64_t numOperands = result.operands.size();
2107 if (parseBound(/*isLower=*/true, result, parser))
2108 return failure();
2109 int64_t numLbOperands = result.operands.size() - numOperands;
2110 if (parser.parseKeyword("to", " between bounds"))
2111 return failure();
2112 numOperands = result.operands.size();
2113 if (parseBound(/*isLower=*/false, result, parser))
2114 return failure();
2115 int64_t numUbOperands = result.operands.size() - numOperands;
2116
2117 // Parse the optional loop step, we default to 1 if one is not present.
2118 if (parser.parseOptionalKeyword("step")) {
2119 result.addAttribute(
2120 getStepAttrName(result.name),
2121 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2122 } else {
2123 SMLoc stepLoc = parser.getCurrentLocation();
2124 IntegerAttr stepAttr;
2125 if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2126 getStepAttrName(result.name).data(),
2127 result.attributes))
2128 return failure();
2129
2130 if (stepAttr.getValue().isNegative())
2131 return parser.emitError(
2132 stepLoc,
2133 "expected step to be representable as a positive signed integer");
2134 }
2135
2136 // Parse the optional initial iteration arguments.
2137 SmallVector<OpAsmParser::Argument, 4> regionArgs;
2138 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2139
2140 // Induction variable.
2141 regionArgs.push_back(inductionVariable);
2142
2143 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2144 // Parse assignment list and results type list.
2145 if (parser.parseAssignmentList(regionArgs, operands) ||
2146 parser.parseArrowTypeList(result.types))
2147 return failure();
2148 // Resolve input operands.
2149 for (auto argOperandType :
2150 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2151 Type type = std::get<2>(argOperandType);
2152 std::get<0>(argOperandType).type = type;
2153 if (parser.resolveOperand(std::get<1>(argOperandType), type,
2154 result.operands))
2155 return failure();
2156 }
2157 }
2158
2159 result.addAttribute(
2160 getOperandSegmentSizeAttr(),
2161 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2162 static_cast<int32_t>(numUbOperands),
2163 static_cast<int32_t>(operands.size())}));
2164
2165 // Parse the body region.
2166 Region *body = result.addRegion();
2167 if (regionArgs.size() != result.types.size() + 1)
2168 return parser.emitError(
2169 parser.getNameLoc(),
2170 "mismatch between the number of loop-carried values and results");
2171 if (parser.parseRegion(*body, regionArgs))
2172 return failure();
2173
2174 AffineForOp::ensureTerminator(*body, builder, result.location);
2175
2176 // Parse the optional attribute list.
2177 return parser.parseOptionalAttrDict(result.attributes);
2178}
2179
2180static void printBound(AffineMapAttr boundMap,
2181 Operation::operand_range boundOperands,
2182 const char *prefix, OpAsmPrinter &p) {
2183 AffineMap map = boundMap.getValue();
2184
2185 // Check if this bound should be printed using custom assembly form.
2186 // The decision to restrict printing custom assembly form to trivial cases
2187 // comes from the will to roundtrip MLIR binary -> text -> binary in a
2188 // lossless way.
2189 // Therefore, custom assembly form parsing and printing is only supported for
2190 // zero-operand constant maps and single symbol operand identity maps.
2191 if (map.getNumResults() == 1) {
2192 AffineExpr expr = map.getResult(idx: 0);
2193
2194 // Print constant bound.
2195 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2196 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2197 p << constExpr.getValue();
2198 return;
2199 }
2200 }
2201
2202 // Print bound that consists of a single SSA symbol if the map is over a
2203 // single symbol.
2204 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2205 if (isa<AffineSymbolExpr>(Val: expr)) {
2206 p.printOperand(value: *boundOperands.begin());
2207 return;
2208 }
2209 }
2210 } else {
2211 // Map has multiple results. Print 'min' or 'max' prefix.
2212 p << prefix << ' ';
2213 }
2214
2215 // Print the map and its operands.
2216 p << boundMap;
2217 printDimAndSymbolList(begin: boundOperands.begin(), end: boundOperands.end(),
2218 numDims: map.getNumDims(), printer&: p);
2219}
2220
2221unsigned AffineForOp::getNumIterOperands() {
2222 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2223 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2224
2225 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2226}
2227
2228std::optional<MutableArrayRef<OpOperand>>
2229AffineForOp::getYieldedValuesMutable() {
2230 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2231}
2232
2233void AffineForOp::print(OpAsmPrinter &p) {
2234 p << ' ';
2235 p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2236 /*omitType=*/true);
2237 p << " = ";
2238 printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2239 p << " to ";
2240 printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2241
2242 if (getStepAsInt() != 1)
2243 p << " step " << getStepAsInt();
2244
2245 bool printBlockTerminators = false;
2246 if (getNumIterOperands() > 0) {
2247 p << " iter_args(";
2248 auto regionArgs = getRegionIterArgs();
2249 auto operands = getInits();
2250
2251 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2252 p << std::get<0>(it) << " = " << std::get<1>(it);
2253 });
2254 p << ") -> (" << getResultTypes() << ")";
2255 printBlockTerminators = true;
2256 }
2257
2258 p << ' ';
2259 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2260 printBlockTerminators);
2261 p.printOptionalAttrDict(
2262 (*this)->getAttrs(),
2263 /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2264 getUpperBoundMapAttrName(getOperation()->getName()),
2265 getStepAttrName(getOperation()->getName()),
2266 getOperandSegmentSizeAttr()});
2267}
2268
2269/// Fold the constant bounds of a loop.
2270static LogicalResult foldLoopBounds(AffineForOp forOp) {
2271 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2272 // Check to see if each of the operands is the result of a constant. If
2273 // so, get the value. If not, ignore it.
2274 SmallVector<Attribute, 8> operandConstants;
2275 auto boundOperands =
2276 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2277 for (auto operand : boundOperands) {
2278 Attribute operandCst;
2279 matchPattern(operand, m_Constant(&operandCst));
2280 operandConstants.push_back(operandCst);
2281 }
2282
2283 AffineMap boundMap =
2284 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2285 assert(boundMap.getNumResults() >= 1 &&
2286 "bound maps should have at least one result");
2287 SmallVector<Attribute, 4> foldedResults;
2288 if (failed(Result: boundMap.constantFold(operandConstants, results&: foldedResults)))
2289 return failure();
2290
2291 // Compute the max or min as applicable over the results.
2292 assert(!foldedResults.empty() && "bounds should have at least one result");
2293 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2294 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2295 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2296 maxOrMin = lower ? llvm::APIntOps::smax(A: maxOrMin, B: foldedResult)
2297 : llvm::APIntOps::smin(A: maxOrMin, B: foldedResult);
2298 }
2299 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2300 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2301 return success();
2302 };
2303
2304 // Try to fold the lower bound.
2305 bool folded = false;
2306 if (!forOp.hasConstantLowerBound())
2307 folded |= succeeded(Result: foldLowerOrUpperBound(/*lower=*/true));
2308
2309 // Try to fold the upper bound.
2310 if (!forOp.hasConstantUpperBound())
2311 folded |= succeeded(Result: foldLowerOrUpperBound(/*lower=*/false));
2312 return success(IsSuccess: folded);
2313}
2314
2315/// Canonicalize the bounds of the given loop.
2316static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2317 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2318 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2319
2320 auto lbMap = forOp.getLowerBoundMap();
2321 auto ubMap = forOp.getUpperBoundMap();
2322 auto prevLbMap = lbMap;
2323 auto prevUbMap = ubMap;
2324
2325 composeAffineMapAndOperands(&lbMap, &lbOperands);
2326 canonicalizeMapAndOperands(&lbMap, &lbOperands);
2327 simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2328 simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2329 lbMap = removeDuplicateExprs(lbMap);
2330
2331 composeAffineMapAndOperands(&ubMap, &ubOperands);
2332 canonicalizeMapAndOperands(&ubMap, &ubOperands);
2333 ubMap = removeDuplicateExprs(ubMap);
2334
2335 // Any canonicalization change always leads to updated map(s).
2336 if (lbMap == prevLbMap && ubMap == prevUbMap)
2337 return failure();
2338
2339 if (lbMap != prevLbMap)
2340 forOp.setLowerBound(lbOperands, lbMap);
2341 if (ubMap != prevUbMap)
2342 forOp.setUpperBound(ubOperands, ubMap);
2343 return success();
2344}
2345
2346namespace {
2347/// Returns constant trip count in trivial cases.
2348static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2349 int64_t step = forOp.getStepAsInt();
2350 if (!forOp.hasConstantBounds() || step <= 0)
2351 return std::nullopt;
2352 int64_t lb = forOp.getConstantLowerBound();
2353 int64_t ub = forOp.getConstantUpperBound();
2354 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2355}
2356
2357/// This is a pattern to fold trivially empty loop bodies.
2358/// TODO: This should be moved into the folding hook.
2359struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2360 using OpRewritePattern<AffineForOp>::OpRewritePattern;
2361
2362 LogicalResult matchAndRewrite(AffineForOp forOp,
2363 PatternRewriter &rewriter) const override {
2364 // Check that the body only contains a yield.
2365 if (!llvm::hasSingleElement(*forOp.getBody()))
2366 return failure();
2367 if (forOp.getNumResults() == 0)
2368 return success();
2369 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2370 if (tripCount && *tripCount == 0) {
2371 // The initial values of the iteration arguments would be the op's
2372 // results.
2373 rewriter.replaceOp(forOp, forOp.getInits());
2374 return success();
2375 }
2376 SmallVector<Value, 4> replacements;
2377 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2378 auto iterArgs = forOp.getRegionIterArgs();
2379 bool hasValDefinedOutsideLoop = false;
2380 bool iterArgsNotInOrder = false;
2381 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2382 Value val = yieldOp.getOperand(i);
2383 auto *iterArgIt = llvm::find(iterArgs, val);
2384 // TODO: It should be possible to perform a replacement by computing the
2385 // last value of the IV based on the bounds and the step.
2386 if (val == forOp.getInductionVar())
2387 return failure();
2388 if (iterArgIt == iterArgs.end()) {
2389 // `val` is defined outside of the loop.
2390 assert(forOp.isDefinedOutsideOfLoop(val) &&
2391 "must be defined outside of the loop");
2392 hasValDefinedOutsideLoop = true;
2393 replacements.push_back(Elt: val);
2394 } else {
2395 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2396 if (pos != i)
2397 iterArgsNotInOrder = true;
2398 replacements.push_back(Elt: forOp.getInits()[pos]);
2399 }
2400 }
2401 // Bail out when the trip count is unknown and the loop returns any value
2402 // defined outside of the loop or any iterArg out of order.
2403 if (!tripCount.has_value() &&
2404 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2405 return failure();
2406 // Bail out when the loop iterates more than once and it returns any iterArg
2407 // out of order.
2408 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2409 return failure();
2410 rewriter.replaceOp(forOp, replacements);
2411 return success();
2412 }
2413};
2414} // namespace
2415
2416void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2417 MLIRContext *context) {
2418 results.add<AffineForEmptyLoopFolder>(context);
2419}
2420
2421OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2422 assert((point.isParent() || point == getRegion()) && "invalid region point");
2423
2424 // The initial operands map to the loop arguments after the induction
2425 // variable or are forwarded to the results when the trip count is zero.
2426 return getInits();
2427}
2428
2429void AffineForOp::getSuccessorRegions(
2430 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2431 assert((point.isParent() || point == getRegion()) && "expected loop region");
2432 // The loop may typically branch back to its body or to the parent operation.
2433 // If the predecessor is the parent op and the trip count is known to be at
2434 // least one, branch into the body using the iterator arguments. And in cases
2435 // we know the trip count is zero, it can only branch back to its parent.
2436 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2437 if (point.isParent() && tripCount.has_value()) {
2438 if (tripCount.value() > 0) {
2439 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2440 return;
2441 }
2442 if (tripCount.value() == 0) {
2443 regions.push_back(RegionSuccessor(getResults()));
2444 return;
2445 }
2446 }
2447
2448 // From the loop body, if the trip count is one, we can only branch back to
2449 // the parent.
2450 if (!point.isParent() && tripCount && *tripCount == 1) {
2451 regions.push_back(RegionSuccessor(getResults()));
2452 return;
2453 }
2454
2455 // In all other cases, the loop may branch back to itself or the parent
2456 // operation.
2457 regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2458 regions.push_back(RegionSuccessor(getResults()));
2459}
2460
2461/// Returns true if the affine.for has zero iterations in trivial cases.
2462static bool hasTrivialZeroTripCount(AffineForOp op) {
2463 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2464 return tripCount && *tripCount == 0;
2465}
2466
2467LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2468 SmallVectorImpl<OpFoldResult> &results) {
2469 bool folded = succeeded(foldLoopBounds(*this));
2470 folded |= succeeded(canonicalizeLoopBounds(*this));
2471 if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2472 // The initial values of the loop-carried variables (iter_args) are the
2473 // results of the op. But this must be avoided for an affine.for op that
2474 // does not return any results. Since ops that do not return results cannot
2475 // be folded away, we would enter an infinite loop of folds on the same
2476 // affine.for op.
2477 results.assign(getInits().begin(), getInits().end());
2478 folded = true;
2479 }
2480 return success(folded);
2481}
2482
2483AffineBound AffineForOp::getLowerBound() {
2484 return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2485}
2486
2487AffineBound AffineForOp::getUpperBound() {
2488 return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2489}
2490
2491void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2492 assert(lbOperands.size() == map.getNumInputs());
2493 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2494 getLowerBoundOperandsMutable().assign(lbOperands);
2495 setLowerBoundMap(map);
2496}
2497
2498void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2499 assert(ubOperands.size() == map.getNumInputs());
2500 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2501 getUpperBoundOperandsMutable().assign(ubOperands);
2502 setUpperBoundMap(map);
2503}
2504
2505bool AffineForOp::hasConstantLowerBound() {
2506 return getLowerBoundMap().isSingleConstant();
2507}
2508
2509bool AffineForOp::hasConstantUpperBound() {
2510 return getUpperBoundMap().isSingleConstant();
2511}
2512
2513int64_t AffineForOp::getConstantLowerBound() {
2514 return getLowerBoundMap().getSingleConstantResult();
2515}
2516
2517int64_t AffineForOp::getConstantUpperBound() {
2518 return getUpperBoundMap().getSingleConstantResult();
2519}
2520
2521void AffineForOp::setConstantLowerBound(int64_t value) {
2522 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2523}
2524
2525void AffineForOp::setConstantUpperBound(int64_t value) {
2526 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2527}
2528
2529AffineForOp::operand_range AffineForOp::getControlOperands() {
2530 return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2531 getUpperBoundOperands().size()};
2532}
2533
2534bool AffineForOp::matchingBoundOperandList() {
2535 auto lbMap = getLowerBoundMap();
2536 auto ubMap = getUpperBoundMap();
2537 if (lbMap.getNumDims() != ubMap.getNumDims() ||
2538 lbMap.getNumSymbols() != ubMap.getNumSymbols())
2539 return false;
2540
2541 unsigned numOperands = lbMap.getNumInputs();
2542 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2543 // Compare Value 's.
2544 if (getOperand(i) != getOperand(numOperands + i))
2545 return false;
2546 }
2547 return true;
2548}
2549
2550SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2551
2552std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2553 return SmallVector<Value>{getInductionVar()};
2554}
2555
2556std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2557 if (!hasConstantLowerBound())
2558 return std::nullopt;
2559 OpBuilder b(getContext());
2560 return SmallVector<OpFoldResult>{
2561 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2562}
2563
2564std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2565 OpBuilder b(getContext());
2566 return SmallVector<OpFoldResult>{
2567 OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2568}
2569
2570std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2571 if (!hasConstantUpperBound())
2572 return {};
2573 OpBuilder b(getContext());
2574 return SmallVector<OpFoldResult>{
2575 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2576}
2577
2578FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2579 RewriterBase &rewriter, ValueRange newInitOperands,
2580 bool replaceInitOperandUsesInLoop,
2581 const NewYieldValuesFn &newYieldValuesFn) {
2582 // Create a new loop before the existing one, with the extra operands.
2583 OpBuilder::InsertionGuard g(rewriter);
2584 rewriter.setInsertionPoint(getOperation());
2585 auto inits = llvm::to_vector(getInits());
2586 inits.append(newInitOperands.begin(), newInitOperands.end());
2587 AffineForOp newLoop = rewriter.create<AffineForOp>(
2588 getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2589 getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2590
2591 // Generate the new yield values and append them to the scf.yield operation.
2592 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2593 ArrayRef<BlockArgument> newIterArgs =
2594 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2595 {
2596 OpBuilder::InsertionGuard g(rewriter);
2597 rewriter.setInsertionPoint(yieldOp);
2598 SmallVector<Value> newYieldedValues =
2599 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2600 assert(newInitOperands.size() == newYieldedValues.size() &&
2601 "expected as many new yield values as new iter operands");
2602 rewriter.modifyOpInPlace(yieldOp, [&]() {
2603 yieldOp.getOperandsMutable().append(newYieldedValues);
2604 });
2605 }
2606
2607 // Move the loop body to the new op.
2608 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2609 newLoop.getBody()->getArguments().take_front(
2610 getBody()->getNumArguments()));
2611
2612 if (replaceInitOperandUsesInLoop) {
2613 // Replace all uses of `newInitOperands` with the corresponding basic block
2614 // arguments.
2615 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2616 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2617 [&](OpOperand &use) {
2618 Operation *user = use.getOwner();
2619 return newLoop->isProperAncestor(user);
2620 });
2621 }
2622 }
2623
2624 // Replace the old loop.
2625 rewriter.replaceOp(getOperation(),
2626 newLoop->getResults().take_front(getNumResults()));
2627 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2628}
2629
2630Speculation::Speculatability AffineForOp::getSpeculatability() {
2631 // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2632 // Start and End.
2633 //
2634 // For Step != 1, the loop may not terminate. We can add more smarts here if
2635 // needed.
2636 return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2637 : Speculation::NotSpeculatable;
2638}
2639
2640/// Returns true if the provided value is the induction variable of a
2641/// AffineForOp.
2642bool mlir::affine::isAffineForInductionVar(Value val) {
2643 return getForInductionVarOwner(val) != AffineForOp();
2644}
2645
2646bool mlir::affine::isAffineParallelInductionVar(Value val) {
2647 return getAffineParallelInductionVarOwner(val) != nullptr;
2648}
2649
2650bool mlir::affine::isAffineInductionVar(Value val) {
2651 return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
2652}
2653
2654AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
2655 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
2656 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2657 return AffineForOp();
2658 if (auto forOp =
2659 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2660 // Check to make sure `val` is the induction variable, not an iter_arg.
2661 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2662 return AffineForOp();
2663}
2664
2665AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
2666 auto ivArg = llvm::dyn_cast<BlockArgument>(Val&: val);
2667 if (!ivArg || !ivArg.getOwner())
2668 return nullptr;
2669 Operation *containingOp = ivArg.getOwner()->getParentOp();
2670 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2671 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2672 return parallelOp;
2673 return nullptr;
2674}
2675
2676/// Extracts the induction variables from a list of AffineForOps and returns
2677/// them.
2678void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2679 SmallVectorImpl<Value> *ivs) {
2680 ivs->reserve(N: forInsts.size());
2681 for (auto forInst : forInsts)
2682 ivs->push_back(forInst.getInductionVar());
2683}
2684
2685void mlir::affine::extractInductionVars(ArrayRef<mlir::Operation *> affineOps,
2686 SmallVectorImpl<mlir::Value> &ivs) {
2687 ivs.reserve(N: affineOps.size());
2688 for (Operation *op : affineOps) {
2689 // Add constraints from forOp's bounds.
2690 if (auto forOp = dyn_cast<AffineForOp>(op))
2691 ivs.push_back(Elt: forOp.getInductionVar());
2692 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2693 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2694 ivs.push_back(Elt: parallelOp.getBody()->getArgument(i));
2695 }
2696}
2697
2698/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2699/// operations.
2700template <typename BoundListTy, typename LoopCreatorTy>
2701static void buildAffineLoopNestImpl(
2702 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2703 ArrayRef<int64_t> steps,
2704 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2705 LoopCreatorTy &&loopCreatorFn) {
2706 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2707 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2708
2709 // If there are no loops to be constructed, construct the body anyway.
2710 OpBuilder::InsertionGuard guard(builder);
2711 if (lbs.empty()) {
2712 if (bodyBuilderFn)
2713 bodyBuilderFn(builder, loc, ValueRange());
2714 return;
2715 }
2716
2717 // Create the loops iteratively and store the induction variables.
2718 SmallVector<Value, 4> ivs;
2719 ivs.reserve(N: lbs.size());
2720 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2721 // Callback for creating the loop body, always creates the terminator.
2722 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2723 ValueRange iterArgs) {
2724 ivs.push_back(Elt: iv);
2725 // In the innermost loop, call the body builder.
2726 if (i == e - 1 && bodyBuilderFn) {
2727 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2728 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2729 }
2730 nestedBuilder.create<AffineYieldOp>(nestedLoc);
2731 };
2732
2733 // Delegate actual loop creation to the callback in order to dispatch
2734 // between constant- and variable-bound loops.
2735 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2736 builder.setInsertionPointToStart(loop.getBody());
2737 }
2738}
2739
2740/// Creates an affine loop from the bounds known to be constants.
2741static AffineForOp
2742buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
2743 int64_t ub, int64_t step,
2744 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2745 return builder.create<AffineForOp>(loc, lb, ub, step,
2746 /*iterArgs=*/std::nullopt, bodyBuilderFn);
2747}
2748
2749/// Creates an affine loop from the bounds that may or may not be constants.
2750static AffineForOp
2751buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2752 int64_t step,
2753 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2754 std::optional<int64_t> lbConst = getConstantIntValue(ofr: lb);
2755 std::optional<int64_t> ubConst = getConstantIntValue(ofr: ub);
2756 if (lbConst && ubConst)
2757 return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2758 ubConst.value(), step, bodyBuilderFn);
2759 return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2760 builder.getDimIdentityMap(), step,
2761 /*iterArgs=*/std::nullopt, bodyBuilderFn);
2762}
2763
2764void mlir::affine::buildAffineLoopNest(
2765 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2766 ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
2767 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2768 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2769 buildAffineLoopFromConstants);
2770}
2771
2772void mlir::affine::buildAffineLoopNest(
2773 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2774 ArrayRef<int64_t> steps,
2775 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2776 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2777 buildAffineLoopFromValues);
2778}
2779
2780//===----------------------------------------------------------------------===//
2781// AffineIfOp
2782//===----------------------------------------------------------------------===//
2783
2784namespace {
2785/// Remove else blocks that have nothing other than a zero value yield.
2786struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2787 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2788
2789 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2790 PatternRewriter &rewriter) const override {
2791 if (ifOp.getElseRegion().empty() ||
2792 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2793 return failure();
2794
2795 rewriter.startOpModification(op: ifOp);
2796 rewriter.eraseBlock(block: ifOp.getElseBlock());
2797 rewriter.finalizeOpModification(op: ifOp);
2798 return success();
2799 }
2800};
2801
2802/// Removes affine.if cond if the condition is always true or false in certain
2803/// trivial cases. Promotes the then/else block in the parent operation block.
2804struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2805 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2806
2807 LogicalResult matchAndRewrite(AffineIfOp op,
2808 PatternRewriter &rewriter) const override {
2809
2810 auto isTriviallyFalse = [](IntegerSet iSet) {
2811 return iSet.isEmptyIntegerSet();
2812 };
2813
2814 auto isTriviallyTrue = [](IntegerSet iSet) {
2815 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2816 iSet.getConstraint(idx: 0) == 0);
2817 };
2818
2819 IntegerSet affineIfConditions = op.getIntegerSet();
2820 Block *blockToMove;
2821 if (isTriviallyFalse(affineIfConditions)) {
2822 // The absence, or equivalently, the emptiness of the else region need not
2823 // be checked when affine.if is returning results because if an affine.if
2824 // operation is returning results, it always has a non-empty else region.
2825 if (op.getNumResults() == 0 && !op.hasElse()) {
2826 // If the else region is absent, or equivalently, empty, remove the
2827 // affine.if operation (which is not returning any results).
2828 rewriter.eraseOp(op: op);
2829 return success();
2830 }
2831 blockToMove = op.getElseBlock();
2832 } else if (isTriviallyTrue(affineIfConditions)) {
2833 blockToMove = op.getThenBlock();
2834 } else {
2835 return failure();
2836 }
2837 Operation *blockToMoveTerminator = blockToMove->getTerminator();
2838 // Promote the "blockToMove" block to the parent operation block between the
2839 // prologue and epilogue of "op".
2840 rewriter.inlineBlockBefore(blockToMove, op);
2841 // Replace the "op" operation with the operands of the
2842 // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2843 // the affine.yield operation present in the "blockToMove" block. It has no
2844 // operands when affine.if is not returning results and therefore, in that
2845 // case, replaceOp just erases "op". When affine.if is not returning
2846 // results, the affine.yield operation can be omitted. It gets inserted
2847 // implicitly.
2848 rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2849 // Erase the "blockToMoveTerminator" operation since it is now in the parent
2850 // operation block, which already has its own terminator.
2851 rewriter.eraseOp(op: blockToMoveTerminator);
2852 return success();
2853 }
2854};
2855} // namespace
2856
2857/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2858/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2859void AffineIfOp::getSuccessorRegions(
2860 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2861 // If the predecessor is an AffineIfOp, then branching into both `then` and
2862 // `else` region is valid.
2863 if (point.isParent()) {
2864 regions.reserve(2);
2865 regions.push_back(
2866 RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2867 // If the "else" region is empty, branch bach into parent.
2868 if (getElseRegion().empty()) {
2869 regions.push_back(getResults());
2870 } else {
2871 regions.push_back(
2872 RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2873 }
2874 return;
2875 }
2876
2877 // If the predecessor is the `else`/`then` region, then branching into parent
2878 // op is valid.
2879 regions.push_back(RegionSuccessor(getResults()));
2880}
2881
2882LogicalResult AffineIfOp::verify() {
2883 // Verify that we have a condition attribute.
2884 // FIXME: This should be specified in the arguments list in ODS.
2885 auto conditionAttr =
2886 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2887 if (!conditionAttr)
2888 return emitOpError("requires an integer set attribute named 'condition'");
2889
2890 // Verify that there are enough operands for the condition.
2891 IntegerSet condition = conditionAttr.getValue();
2892 if (getNumOperands() != condition.getNumInputs())
2893 return emitOpError("operand count and condition integer set dimension and "
2894 "symbol count must match");
2895
2896 // Verify that the operands are valid dimension/symbols.
2897 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2898 condition.getNumDims())))
2899 return failure();
2900
2901 return success();
2902}
2903
2904ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2905 // Parse the condition attribute set.
2906 IntegerSetAttr conditionAttr;
2907 unsigned numDims;
2908 if (parser.parseAttribute(conditionAttr,
2909 AffineIfOp::getConditionAttrStrName(),
2910 result.attributes) ||
2911 parseDimAndSymbolList(parser, result.operands, numDims))
2912 return failure();
2913
2914 // Verify the condition operands.
2915 auto set = conditionAttr.getValue();
2916 if (set.getNumDims() != numDims)
2917 return parser.emitError(
2918 parser.getNameLoc(),
2919 "dim operand count and integer set dim count must match");
2920 if (numDims + set.getNumSymbols() != result.operands.size())
2921 return parser.emitError(
2922 parser.getNameLoc(),
2923 "symbol operand count and integer set symbol count must match");
2924
2925 if (parser.parseOptionalArrowTypeList(result.types))
2926 return failure();
2927
2928 // Create the regions for 'then' and 'else'. The latter must be created even
2929 // if it remains empty for the validity of the operation.
2930 result.regions.reserve(2);
2931 Region *thenRegion = result.addRegion();
2932 Region *elseRegion = result.addRegion();
2933
2934 // Parse the 'then' region.
2935 if (parser.parseRegion(*thenRegion, {}, {}))
2936 return failure();
2937 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2938 result.location);
2939
2940 // If we find an 'else' keyword then parse the 'else' region.
2941 if (!parser.parseOptionalKeyword("else")) {
2942 if (parser.parseRegion(*elseRegion, {}, {}))
2943 return failure();
2944 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2945 result.location);
2946 }
2947
2948 // Parse the optional attribute list.
2949 if (parser.parseOptionalAttrDict(result.attributes))
2950 return failure();
2951
2952 return success();
2953}
2954
2955void AffineIfOp::print(OpAsmPrinter &p) {
2956 auto conditionAttr =
2957 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2958 p << " " << conditionAttr;
2959 printDimAndSymbolList(operand_begin(), operand_end(),
2960 conditionAttr.getValue().getNumDims(), p);
2961 p.printOptionalArrowTypeList(getResultTypes());
2962 p << ' ';
2963 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2964 /*printBlockTerminators=*/getNumResults());
2965
2966 // Print the 'else' regions if it has any blocks.
2967 auto &elseRegion = this->getElseRegion();
2968 if (!elseRegion.empty()) {
2969 p << " else ";
2970 p.printRegion(elseRegion,
2971 /*printEntryBlockArgs=*/false,
2972 /*printBlockTerminators=*/getNumResults());
2973 }
2974
2975 // Print the attribute list.
2976 p.printOptionalAttrDict((*this)->getAttrs(),
2977 /*elidedAttrs=*/getConditionAttrStrName());
2978}
2979
2980IntegerSet AffineIfOp::getIntegerSet() {
2981 return (*this)
2982 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2983 .getValue();
2984}
2985
2986void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2987 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2988}
2989
2990void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2991 setIntegerSet(set);
2992 (*this)->setOperands(operands);
2993}
2994
2995void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2996 TypeRange resultTypes, IntegerSet set, ValueRange args,
2997 bool withElseRegion) {
2998 assert(resultTypes.empty() || withElseRegion);
2999 OpBuilder::InsertionGuard guard(builder);
3000
3001 result.addTypes(resultTypes);
3002 result.addOperands(args);
3003 result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3004
3005 Region *thenRegion = result.addRegion();
3006 builder.createBlock(thenRegion);
3007 if (resultTypes.empty())
3008 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3009
3010 Region *elseRegion = result.addRegion();
3011 if (withElseRegion) {
3012 builder.createBlock(elseRegion);
3013 if (resultTypes.empty())
3014 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3015 }
3016}
3017
3018void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3019 IntegerSet set, ValueRange args, bool withElseRegion) {
3020 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3021 withElseRegion);
3022}
3023
3024/// Compose any affine.apply ops feeding into `operands` of the integer set
3025/// `set` by composing the maps of such affine.apply ops with the integer
3026/// set constraints.
3027static void composeSetAndOperands(IntegerSet &set,
3028 SmallVectorImpl<Value> &operands) {
3029 // We will simply reuse the API of the map composition by viewing the LHSs of
3030 // the equalities and inequalities of `set` as the affine exprs of an affine
3031 // map. Convert to equivalent map, compose, and convert back to set.
3032 auto map = AffineMap::get(dimCount: set.getNumDims(), symbolCount: set.getNumSymbols(),
3033 results: set.getConstraints(), context: set.getContext());
3034 // Check if any composition is possible.
3035 if (llvm::none_of(Range&: operands,
3036 P: [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3037 return;
3038
3039 composeAffineMapAndOperands(map: &map, operands: &operands);
3040 set = IntegerSet::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), constraints: map.getResults(),
3041 eqFlags: set.getEqFlags());
3042}
3043
3044/// Canonicalize an affine if op's conditional (integer set + operands).
3045LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3046 auto set = getIntegerSet();
3047 SmallVector<Value, 4> operands(getOperands());
3048 composeSetAndOperands(set, operands);
3049 canonicalizeSetAndOperands(&set, &operands);
3050
3051 // Check if the canonicalization or composition led to any change.
3052 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3053 return failure();
3054
3055 setConditional(set, operands);
3056 return success();
3057}
3058
3059void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3060 MLIRContext *context) {
3061 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3062}
3063
3064//===----------------------------------------------------------------------===//
3065// AffineLoadOp
3066//===----------------------------------------------------------------------===//
3067
3068void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3069 AffineMap map, ValueRange operands) {
3070 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3071 result.addOperands(operands);
3072 if (map)
3073 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3074 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3075 result.types.push_back(memrefType.getElementType());
3076}
3077
3078void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3079 Value memref, AffineMap map, ValueRange mapOperands) {
3080 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3081 result.addOperands(memref);
3082 result.addOperands(mapOperands);
3083 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3084 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3085 result.types.push_back(memrefType.getElementType());
3086}
3087
3088void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3089 Value memref, ValueRange indices) {
3090 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3091 int64_t rank = memrefType.getRank();
3092 // Create identity map for memrefs with at least one dimension or () -> ()
3093 // for zero-dimensional memrefs.
3094 auto map =
3095 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3096 build(builder, result, memref, map, indices);
3097}
3098
3099ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3100 auto &builder = parser.getBuilder();
3101 auto indexTy = builder.getIndexType();
3102
3103 MemRefType type;
3104 OpAsmParser::UnresolvedOperand memrefInfo;
3105 AffineMapAttr mapAttr;
3106 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3107 return failure(
3108 parser.parseOperand(memrefInfo) ||
3109 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3110 AffineLoadOp::getMapAttrStrName(),
3111 result.attributes) ||
3112 parser.parseOptionalAttrDict(result.attributes) ||
3113 parser.parseColonType(type) ||
3114 parser.resolveOperand(memrefInfo, type, result.operands) ||
3115 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3116 parser.addTypeToList(type.getElementType(), result.types));
3117}
3118
3119void AffineLoadOp::print(OpAsmPrinter &p) {
3120 p << " " << getMemRef() << '[';
3121 if (AffineMapAttr mapAttr =
3122 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3123 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3124 p << ']';
3125 p.printOptionalAttrDict((*this)->getAttrs(),
3126 /*elidedAttrs=*/{getMapAttrStrName()});
3127 p << " : " << getMemRefType();
3128}
3129
3130/// Verify common indexing invariants of affine.load, affine.store,
3131/// affine.vector_load and affine.vector_store.
3132template <typename AffineMemOpTy>
3133static LogicalResult
3134verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3135 Operation::operand_range mapOperands,
3136 MemRefType memrefType, unsigned numIndexOperands) {
3137 AffineMap map = mapAttr.getValue();
3138 if (map.getNumResults() != memrefType.getRank())
3139 return op->emitOpError("affine map num results must equal memref rank");
3140 if (map.getNumInputs() != numIndexOperands)
3141 return op->emitOpError("expects as many subscripts as affine map inputs");
3142
3143 for (auto idx : mapOperands) {
3144 if (!idx.getType().isIndex())
3145 return op->emitOpError("index to load must have 'index' type");
3146 }
3147 if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3148 return failure();
3149
3150 return success();
3151}
3152
3153LogicalResult AffineLoadOp::verify() {
3154 auto memrefType = getMemRefType();
3155 if (getType() != memrefType.getElementType())
3156 return emitOpError("result type must match element type of memref");
3157
3158 if (failed(verifyMemoryOpIndexing(
3159 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3160 getMapOperands(), memrefType,
3161 /*numIndexOperands=*/getNumOperands() - 1)))
3162 return failure();
3163
3164 return success();
3165}
3166
3167void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3168 MLIRContext *context) {
3169 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3170}
3171
3172OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3173 /// load(memrefcast) -> load
3174 if (succeeded(memref::foldMemRefCast(*this)))
3175 return getResult();
3176
3177 // Fold load from a global constant memref.
3178 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3179 if (!getGlobalOp)
3180 return {};
3181 // Get to the memref.global defining the symbol.
3182 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3183 if (!symbolTableOp)
3184 return {};
3185 auto global = dyn_cast_or_null<memref::GlobalOp>(
3186 SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3187 if (!global)
3188 return {};
3189
3190 // Check if the global memref is a constant.
3191 auto cstAttr =
3192 llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3193 if (!cstAttr)
3194 return {};
3195 // If it's a splat constant, we can fold irrespective of indices.
3196 if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3197 return splatAttr.getSplatValue<Attribute>();
3198 // Otherwise, we can fold only if we know the indices.
3199 if (!getAffineMap().isConstant())
3200 return {};
3201 auto indices = llvm::to_vector<4>(
3202 llvm::map_range(getAffineMap().getConstantResults(),
3203 [](int64_t v) -> uint64_t { return v; }));
3204 return cstAttr.getValues<Attribute>()[indices];
3205}
3206
3207//===----------------------------------------------------------------------===//
3208// AffineStoreOp
3209//===----------------------------------------------------------------------===//
3210
3211void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3212 Value valueToStore, Value memref, AffineMap map,
3213 ValueRange mapOperands) {
3214 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3215 result.addOperands(valueToStore);
3216 result.addOperands(memref);
3217 result.addOperands(mapOperands);
3218 result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3219}
3220
3221// Use identity map.
3222void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3223 Value valueToStore, Value memref,
3224 ValueRange indices) {
3225 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3226 int64_t rank = memrefType.getRank();
3227 // Create identity map for memrefs with at least one dimension or () -> ()
3228 // for zero-dimensional memrefs.
3229 auto map =
3230 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3231 build(builder, result, valueToStore, memref, map, indices);
3232}
3233
3234ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3235 auto indexTy = parser.getBuilder().getIndexType();
3236
3237 MemRefType type;
3238 OpAsmParser::UnresolvedOperand storeValueInfo;
3239 OpAsmParser::UnresolvedOperand memrefInfo;
3240 AffineMapAttr mapAttr;
3241 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3242 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3243 parser.parseOperand(memrefInfo) ||
3244 parser.parseAffineMapOfSSAIds(
3245 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3246 result.attributes) ||
3247 parser.parseOptionalAttrDict(result.attributes) ||
3248 parser.parseColonType(type) ||
3249 parser.resolveOperand(storeValueInfo, type.getElementType(),
3250 result.operands) ||
3251 parser.resolveOperand(memrefInfo, type, result.operands) ||
3252 parser.resolveOperands(mapOperands, indexTy, result.operands));
3253}
3254
3255void AffineStoreOp::print(OpAsmPrinter &p) {
3256 p << " " << getValueToStore();
3257 p << ", " << getMemRef() << '[';
3258 if (AffineMapAttr mapAttr =
3259 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3260 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3261 p << ']';
3262 p.printOptionalAttrDict((*this)->getAttrs(),
3263 /*elidedAttrs=*/{getMapAttrStrName()});
3264 p << " : " << getMemRefType();
3265}
3266
3267LogicalResult AffineStoreOp::verify() {
3268 // The value to store must have the same type as memref element type.
3269 auto memrefType = getMemRefType();
3270 if (getValueToStore().getType() != memrefType.getElementType())
3271 return emitOpError(
3272 "value to store must have the same type as memref element type");
3273
3274 if (failed(verifyMemoryOpIndexing(
3275 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3276 getMapOperands(), memrefType,
3277 /*numIndexOperands=*/getNumOperands() - 2)))
3278 return failure();
3279
3280 return success();
3281}
3282
3283void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3284 MLIRContext *context) {
3285 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3286}
3287
3288LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3289 SmallVectorImpl<OpFoldResult> &results) {
3290 /// store(memrefcast) -> store
3291 return memref::foldMemRefCast(*this, getValueToStore());
3292}
3293
3294//===----------------------------------------------------------------------===//
3295// AffineMinMaxOpBase
3296//===----------------------------------------------------------------------===//
3297
3298template <typename T>
3299static LogicalResult verifyAffineMinMaxOp(T op) {
3300 // Verify that operand count matches affine map dimension and symbol count.
3301 if (op.getNumOperands() !=
3302 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3303 return op.emitOpError(
3304 "operand count and affine map dimension and symbol count must match");
3305
3306 if (op.getMap().getNumResults() == 0)
3307 return op.emitOpError("affine map expect at least one result");
3308 return success();
3309}
3310
3311template <typename T>
3312static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3313 p << ' ' << op->getAttr(T::getMapAttrStrName());
3314 auto operands = op.getOperands();
3315 unsigned numDims = op.getMap().getNumDims();
3316 p << '(' << operands.take_front(numDims) << ')';
3317
3318 if (operands.size() != numDims)
3319 p << '[' << operands.drop_front(numDims) << ']';
3320 p.printOptionalAttrDict(attrs: op->getAttrs(),
3321 /*elidedAttrs=*/{T::getMapAttrStrName()});
3322}
3323
3324template <typename T>
3325static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3326 OperationState &result) {
3327 auto &builder = parser.getBuilder();
3328 auto indexType = builder.getIndexType();
3329 SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
3330 SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
3331 AffineMapAttr mapAttr;
3332 return failure(
3333 parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3334 result.attributes) ||
3335 parser.parseOperandList(result&: dimInfos, delimiter: OpAsmParser::Delimiter::Paren) ||
3336 parser.parseOperandList(result&: symInfos,
3337 delimiter: OpAsmParser::Delimiter::OptionalSquare) ||
3338 parser.parseOptionalAttrDict(result&: result.attributes) ||
3339 parser.resolveOperands(dimInfos, indexType, result.operands) ||
3340 parser.resolveOperands(symInfos, indexType, result.operands) ||
3341 parser.addTypeToList(type: indexType, result&: result.types));
3342}
3343
3344/// Fold an affine min or max operation with the given operands. The operand
3345/// list may contain nulls, which are interpreted as the operand not being a
3346/// constant.
3347template <typename T>
3348static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
3349 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3350 "expected affine min or max op");
3351
3352 // Fold the affine map.
3353 // TODO: Fold more cases:
3354 // min(some_affine, some_affine + constant, ...), etc.
3355 SmallVector<int64_t, 2> results;
3356 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3357
3358 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3359 return op.getOperand(0);
3360
3361 // If some of the map results are not constant, try changing the map in-place.
3362 if (results.empty()) {
3363 // If the map is the same, report that folding did not happen.
3364 if (foldedMap == op.getMap())
3365 return {};
3366 op->setAttr("map", AffineMapAttr::get(foldedMap));
3367 return op.getResult();
3368 }
3369
3370 // Otherwise, completely fold the op into a constant.
3371 auto resultIt = std::is_same<T, AffineMinOp>::value
3372 ? llvm::min_element(Range&: results)
3373 : llvm::max_element(Range&: results);
3374 if (resultIt == results.end())
3375 return {};
3376 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3377}
3378
3379/// Remove duplicated expressions in affine min/max ops.
3380template <typename T>
3381struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
3382 using OpRewritePattern<T>::OpRewritePattern;
3383
3384 LogicalResult matchAndRewrite(T affineOp,
3385 PatternRewriter &rewriter) const override {
3386 AffineMap oldMap = affineOp.getAffineMap();
3387
3388 SmallVector<AffineExpr, 4> newExprs;
3389 for (AffineExpr expr : oldMap.getResults()) {
3390 // This is a linear scan over newExprs, but it should be fine given that
3391 // we typically just have a few expressions per op.
3392 if (!llvm::is_contained(Range&: newExprs, Element: expr))
3393 newExprs.push_back(Elt: expr);
3394 }
3395
3396 if (newExprs.size() == oldMap.getNumResults())
3397 return failure();
3398
3399 auto newMap = AffineMap::get(dimCount: oldMap.getNumDims(), symbolCount: oldMap.getNumSymbols(),
3400 results: newExprs, context: rewriter.getContext());
3401 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3402
3403 return success();
3404 }
3405};
3406
3407/// Merge an affine min/max op to its consumers if its consumer is also an
3408/// affine min/max op.
3409///
3410/// This pattern requires the producer affine min/max op is bound to a
3411/// dimension/symbol that is used as a standalone expression in the consumer
3412/// affine op's map.
3413///
3414/// For example, a pattern like the following:
3415///
3416/// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3417/// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3418///
3419/// Can be turned into:
3420///
3421/// %1 = affine.min affine_map<
3422/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3423template <typename T>
3424struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
3425 using OpRewritePattern<T>::OpRewritePattern;
3426
3427 LogicalResult matchAndRewrite(T affineOp,
3428 PatternRewriter &rewriter) const override {
3429 AffineMap oldMap = affineOp.getAffineMap();
3430 ValueRange dimOperands =
3431 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3432 ValueRange symOperands =
3433 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3434
3435 auto newDimOperands = llvm::to_vector<8>(Range&: dimOperands);
3436 auto newSymOperands = llvm::to_vector<8>(Range&: symOperands);
3437 SmallVector<AffineExpr, 4> newExprs;
3438 SmallVector<T, 4> producerOps;
3439
3440 // Go over each expression to see whether it's a single dimension/symbol
3441 // with the corresponding operand which is the result of another affine
3442 // min/max op. If So it can be merged into this affine op.
3443 for (AffineExpr expr : oldMap.getResults()) {
3444 if (auto symExpr = dyn_cast<AffineSymbolExpr>(Val&: expr)) {
3445 Value symValue = symOperands[symExpr.getPosition()];
3446 if (auto producerOp = symValue.getDefiningOp<T>()) {
3447 producerOps.push_back(producerOp);
3448 continue;
3449 }
3450 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(Val&: expr)) {
3451 Value dimValue = dimOperands[dimExpr.getPosition()];
3452 if (auto producerOp = dimValue.getDefiningOp<T>()) {
3453 producerOps.push_back(producerOp);
3454 continue;
3455 }
3456 }
3457 // For the above cases we will remove the expression by merging the
3458 // producer affine min/max's affine expressions. Otherwise we need to
3459 // keep the existing expression.
3460 newExprs.push_back(Elt: expr);
3461 }
3462
3463 if (producerOps.empty())
3464 return failure();
3465
3466 unsigned numUsedDims = oldMap.getNumDims();
3467 unsigned numUsedSyms = oldMap.getNumSymbols();
3468
3469 // Now go over all producer affine ops and merge their expressions.
3470 for (T producerOp : producerOps) {
3471 AffineMap producerMap = producerOp.getAffineMap();
3472 unsigned numProducerDims = producerMap.getNumDims();
3473 unsigned numProducerSyms = producerMap.getNumSymbols();
3474
3475 // Collect all dimension/symbol values.
3476 ValueRange dimValues =
3477 producerOp.getMapOperands().take_front(numProducerDims);
3478 ValueRange symValues =
3479 producerOp.getMapOperands().take_back(numProducerSyms);
3480 newDimOperands.append(in_start: dimValues.begin(), in_end: dimValues.end());
3481 newSymOperands.append(in_start: symValues.begin(), in_end: symValues.end());
3482
3483 // For expressions we need to shift to avoid overlap.
3484 for (AffineExpr expr : producerMap.getResults()) {
3485 newExprs.push_back(Elt: expr.shiftDims(numDims: numProducerDims, shift: numUsedDims)
3486 .shiftSymbols(numSymbols: numProducerSyms, shift: numUsedSyms));
3487 }
3488
3489 numUsedDims += numProducerDims;
3490 numUsedSyms += numProducerSyms;
3491 }
3492
3493 auto newMap = AffineMap::get(dimCount: numUsedDims, symbolCount: numUsedSyms, results: newExprs,
3494 context: rewriter.getContext());
3495 auto newOperands =
3496 llvm::to_vector<8>(Range: llvm::concat<Value>(Ranges&: newDimOperands, Ranges&: newSymOperands));
3497 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3498
3499 return success();
3500 }
3501};
3502
3503/// Canonicalize the result expression order of an affine map and return success
3504/// if the order changed.
3505///
3506/// The function flattens the map's affine expressions to coefficient arrays and
3507/// sorts them in lexicographic order. A coefficient array contains a multiplier
3508/// for every dimension/symbol and a constant term. The canonicalization fails
3509/// if a result expression is not pure or if the flattening requires local
3510/// variables that, unlike dimensions and symbols, have no global order.
3511static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3512 SmallVector<SmallVector<int64_t>> flattenedExprs;
3513 for (const AffineExpr &resultExpr : map.getResults()) {
3514 // Fail if the expression is not pure.
3515 if (!resultExpr.isPureAffine())
3516 return failure();
3517
3518 SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3519 auto flattenResult = flattener.walkPostOrder(expr: resultExpr);
3520 if (failed(Result: flattenResult))
3521 return failure();
3522
3523 // Fail if the flattened expression has local variables.
3524 if (flattener.operandExprStack.back().size() !=
3525 map.getNumDims() + map.getNumSymbols() + 1)
3526 return failure();
3527
3528 flattenedExprs.emplace_back(Args: flattener.operandExprStack.back().begin(),
3529 Args: flattener.operandExprStack.back().end());
3530 }
3531
3532 // Fail if sorting is not necessary.
3533 if (llvm::is_sorted(Range&: flattenedExprs))
3534 return failure();
3535
3536 // Reorder the result expressions according to their flattened form.
3537 SmallVector<unsigned> resultPermutation =
3538 llvm::to_vector(Range: llvm::seq<unsigned>(Begin: 0, End: map.getNumResults()));
3539 llvm::sort(C&: resultPermutation, Comp: [&](unsigned lhs, unsigned rhs) {
3540 return flattenedExprs[lhs] < flattenedExprs[rhs];
3541 });
3542 SmallVector<AffineExpr> newExprs;
3543 for (unsigned idx : resultPermutation)
3544 newExprs.push_back(Elt: map.getResult(idx));
3545
3546 map = AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newExprs,
3547 context: map.getContext());
3548 return success();
3549}
3550
3551/// Canonicalize the affine map result expression order of an affine min/max
3552/// operation.
3553///
3554/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3555/// expressions and replaces the operation if the order changed.
3556///
3557/// For example, the following operation:
3558///
3559/// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3560///
3561/// Turns into:
3562///
3563/// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3564template <typename T>
3565struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3566 using OpRewritePattern<T>::OpRewritePattern;
3567
3568 LogicalResult matchAndRewrite(T affineOp,
3569 PatternRewriter &rewriter) const override {
3570 AffineMap map = affineOp.getAffineMap();
3571 if (failed(Result: canonicalizeMapExprAndTermOrder(map)))
3572 return failure();
3573 rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3574 return success();
3575 }
3576};
3577
3578template <typename T>
3579struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3580 using OpRewritePattern<T>::OpRewritePattern;
3581
3582 LogicalResult matchAndRewrite(T affineOp,
3583 PatternRewriter &rewriter) const override {
3584 if (affineOp.getMap().getNumResults() != 1)
3585 return failure();
3586 rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3587 affineOp.getOperands());
3588 return success();
3589 }
3590};
3591
3592//===----------------------------------------------------------------------===//
3593// AffineMinOp
3594//===----------------------------------------------------------------------===//
3595//
//   %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)> (%i0)
3597//
3598
3599OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3600 return foldMinMaxOp(*this, adaptor.getOperands());
3601}
3602
3603void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3604 MLIRContext *context) {
3605 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
3606 DeduplicateAffineMinMaxExpressions<AffineMinOp>,
3607 MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3608 CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
3609 context);
3610}
3611
3612LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3613
3614ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3615 return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3616}
3617
3618void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3619
3620//===----------------------------------------------------------------------===//
3621// AffineMaxOp
3622//===----------------------------------------------------------------------===//
3623//
//   %0 = affine.max affine_map<(d0) -> (1000, d0 + 512)> (%i0)
3625//
3626
3627OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3628 return foldMinMaxOp(*this, adaptor.getOperands());
3629}
3630
3631void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3632 MLIRContext *context) {
3633 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
3634 DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
3635 MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3636 CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
3637 context);
3638}
3639
3640LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3641
3642ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3643 return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3644}
3645
3646void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3647
3648//===----------------------------------------------------------------------===//
3649// AffinePrefetchOp
3650//===----------------------------------------------------------------------===//
3651
3652//
3653// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3654//
3655ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3656 OperationState &result) {
3657 auto &builder = parser.getBuilder();
3658 auto indexTy = builder.getIndexType();
3659
3660 MemRefType type;
3661 OpAsmParser::UnresolvedOperand memrefInfo;
3662 IntegerAttr hintInfo;
3663 auto i32Type = parser.getBuilder().getIntegerType(32);
3664 StringRef readOrWrite, cacheType;
3665
3666 AffineMapAttr mapAttr;
3667 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3668 if (parser.parseOperand(memrefInfo) ||
3669 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3670 AffinePrefetchOp::getMapAttrStrName(),
3671 result.attributes) ||
3672 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3673 parser.parseComma() || parser.parseKeyword("locality") ||
3674 parser.parseLess() ||
3675 parser.parseAttribute(hintInfo, i32Type,
3676 AffinePrefetchOp::getLocalityHintAttrStrName(),
3677 result.attributes) ||
3678 parser.parseGreater() || parser.parseComma() ||
3679 parser.parseKeyword(&cacheType) ||
3680 parser.parseOptionalAttrDict(result.attributes) ||
3681 parser.parseColonType(type) ||
3682 parser.resolveOperand(memrefInfo, type, result.operands) ||
3683 parser.resolveOperands(mapOperands, indexTy, result.operands))
3684 return failure();
3685
3686 if (readOrWrite != "read" && readOrWrite != "write")
3687 return parser.emitError(parser.getNameLoc(),
3688 "rw specifier has to be 'read' or 'write'");
3689 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3690 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3691
3692 if (cacheType != "data" && cacheType != "instr")
3693 return parser.emitError(parser.getNameLoc(),
3694 "cache type has to be 'data' or 'instr'");
3695
3696 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3697 parser.getBuilder().getBoolAttr(cacheType == "data"));
3698
3699 return success();
3700}
3701
3702void AffinePrefetchOp::print(OpAsmPrinter &p) {
3703 p << " " << getMemref() << '[';
3704 AffineMapAttr mapAttr =
3705 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3706 if (mapAttr)
3707 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3708 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3709 << "locality<" << getLocalityHint() << ">, "
3710 << (getIsDataCache() ? "data" : "instr");
3711 p.printOptionalAttrDict(
3712 (*this)->getAttrs(),
3713 /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3714 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3715 p << " : " << getMemRefType();
3716}
3717
3718LogicalResult AffinePrefetchOp::verify() {
3719 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3720 if (mapAttr) {
3721 AffineMap map = mapAttr.getValue();
3722 if (map.getNumResults() != getMemRefType().getRank())
3723 return emitOpError("affine.prefetch affine map num results must equal"
3724 " memref rank");
3725 if (map.getNumInputs() + 1 != getNumOperands())
3726 return emitOpError("too few operands");
3727 } else {
3728 if (getNumOperands() != 1)
3729 return emitOpError("too few operands");
3730 }
3731
3732 Region *scope = getAffineScope(*this);
3733 for (auto idx : getMapOperands()) {
3734 if (!isValidAffineIndexOperand(idx, scope))
3735 return emitOpError(
3736 "index must be a valid dimension or symbol identifier");
3737 }
3738 return success();
3739}
3740
3741void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3742 MLIRContext *context) {
3743 // prefetch(memrefcast) -> prefetch
3744 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3745}
3746
3747LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3748 SmallVectorImpl<OpFoldResult> &results) {
3749 /// prefetch(memrefcast) -> prefetch
3750 return memref::foldMemRefCast(*this);
3751}
3752
3753//===----------------------------------------------------------------------===//
3754// AffineParallelOp
3755//===----------------------------------------------------------------------===//
3756
3757void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3758 TypeRange resultTypes,
3759 ArrayRef<arith::AtomicRMWKind> reductions,
3760 ArrayRef<int64_t> ranges) {
3761 SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3762 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3763 return builder.getConstantAffineMap(value);
3764 }));
3765 SmallVector<int64_t> steps(ranges.size(), 1);
3766 build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3767 /*ubArgs=*/{}, steps);
3768}
3769
3770void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3771 TypeRange resultTypes,
3772 ArrayRef<arith::AtomicRMWKind> reductions,
3773 ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3774 ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3775 ArrayRef<int64_t> steps) {
3776 assert(llvm::all_of(lbMaps,
3777 [lbMaps](AffineMap m) {
3778 return m.getNumDims() == lbMaps[0].getNumDims() &&
3779 m.getNumSymbols() == lbMaps[0].getNumSymbols();
3780 }) &&
3781 "expected all lower bounds maps to have the same number of dimensions "
3782 "and symbols");
3783 assert(llvm::all_of(ubMaps,
3784 [ubMaps](AffineMap m) {
3785 return m.getNumDims() == ubMaps[0].getNumDims() &&
3786 m.getNumSymbols() == ubMaps[0].getNumSymbols();
3787 }) &&
3788 "expected all upper bounds maps to have the same number of dimensions "
3789 "and symbols");
3790 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3791 "expected lower bound maps to have as many inputs as lower bound "
3792 "operands");
3793 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3794 "expected upper bound maps to have as many inputs as upper bound "
3795 "operands");
3796
3797 OpBuilder::InsertionGuard guard(builder);
3798 result.addTypes(resultTypes);
3799
3800 // Convert the reductions to integer attributes.
3801 SmallVector<Attribute, 4> reductionAttrs;
3802 for (arith::AtomicRMWKind reduction : reductions)
3803 reductionAttrs.push_back(
3804 builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3805 result.addAttribute(getReductionsAttrStrName(),
3806 builder.getArrayAttr(reductionAttrs));
3807
3808 // Concatenates maps defined in the same input space (same dimensions and
3809 // symbols), assumes there is at least one map.
3810 auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3811 SmallVectorImpl<int32_t> &groups) {
3812 if (maps.empty())
3813 return AffineMap::get(builder.getContext());
3814 SmallVector<AffineExpr> exprs;
3815 groups.reserve(groups.size() + maps.size());
3816 exprs.reserve(maps.size());
3817 for (AffineMap m : maps) {
3818 llvm::append_range(exprs, m.getResults());
3819 groups.push_back(m.getNumResults());
3820 }
3821 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3822 maps[0].getContext());
3823 };
3824
3825 // Set up the bounds.
3826 SmallVector<int32_t> lbGroups, ubGroups;
3827 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3828 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3829 result.addAttribute(getLowerBoundsMapAttrStrName(),
3830 AffineMapAttr::get(lbMap));
3831 result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3832 builder.getI32TensorAttr(lbGroups));
3833 result.addAttribute(getUpperBoundsMapAttrStrName(),
3834 AffineMapAttr::get(ubMap));
3835 result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3836 builder.getI32TensorAttr(ubGroups));
3837 result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3838 result.addOperands(lbArgs);
3839 result.addOperands(ubArgs);
3840
3841 // Create a region and a block for the body.
3842 auto *bodyRegion = result.addRegion();
3843 Block *body = builder.createBlock(bodyRegion);
3844
3845 // Add all the block arguments.
3846 for (unsigned i = 0, e = steps.size(); i < e; ++i)
3847 body->addArgument(IndexType::get(builder.getContext()), result.location);
3848 if (resultTypes.empty())
3849 ensureTerminator(*bodyRegion, builder, result.location);
3850}
3851
3852SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3853 return {&getRegion()};
3854}
3855
3856unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3857
3858AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3859 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3860}
3861
3862AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3863 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3864}
3865
3866AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3867 auto values = getLowerBoundsGroups().getValues<int32_t>();
3868 unsigned start = 0;
3869 for (unsigned i = 0; i < pos; ++i)
3870 start += values[i];
3871 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3872}
3873
3874AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3875 auto values = getUpperBoundsGroups().getValues<int32_t>();
3876 unsigned start = 0;
3877 for (unsigned i = 0; i < pos; ++i)
3878 start += values[i];
3879 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3880}
3881
3882AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3883 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3884}
3885
3886AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3887 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3888}
3889
3890std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3891 if (hasMinMaxBounds())
3892 return std::nullopt;
3893
3894 // Try to convert all the ranges to constant expressions.
3895 SmallVector<int64_t, 8> out;
3896 AffineValueMap rangesValueMap;
3897 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3898 &rangesValueMap);
3899 out.reserve(rangesValueMap.getNumResults());
3900 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3901 auto expr = rangesValueMap.getResult(i);
3902 auto cst = dyn_cast<AffineConstantExpr>(expr);
3903 if (!cst)
3904 return std::nullopt;
3905 out.push_back(cst.getValue());
3906 }
3907 return out;
3908}
3909
3910Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3911
3912OpBuilder AffineParallelOp::getBodyBuilder() {
3913 return OpBuilder(getBody(), std::prev(getBody()->end()));
3914}
3915
3916void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3917 assert(lbOperands.size() == map.getNumInputs() &&
3918 "operands to map must match number of inputs");
3919
3920 auto ubOperands = getUpperBoundsOperands();
3921
3922 SmallVector<Value, 4> newOperands(lbOperands);
3923 newOperands.append(ubOperands.begin(), ubOperands.end());
3924 (*this)->setOperands(newOperands);
3925
3926 setLowerBoundsMapAttr(AffineMapAttr::get(map));
3927}
3928
3929void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3930 assert(ubOperands.size() == map.getNumInputs() &&
3931 "operands to map must match number of inputs");
3932
3933 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3934 newOperands.append(ubOperands.begin(), ubOperands.end());
3935 (*this)->setOperands(newOperands);
3936
3937 setUpperBoundsMapAttr(AffineMapAttr::get(map));
3938}
3939
3940void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3941 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3942}
3943
3944// check whether resultType match op or not in affine.parallel
3945static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3946 arith::AtomicRMWKind op) {
3947 switch (op) {
3948 case arith::AtomicRMWKind::addf:
3949 return isa<FloatType>(Val: resultType);
3950 case arith::AtomicRMWKind::addi:
3951 return isa<IntegerType>(Val: resultType);
3952 case arith::AtomicRMWKind::assign:
3953 return true;
3954 case arith::AtomicRMWKind::mulf:
3955 return isa<FloatType>(Val: resultType);
3956 case arith::AtomicRMWKind::muli:
3957 return isa<IntegerType>(Val: resultType);
3958 case arith::AtomicRMWKind::maximumf:
3959 return isa<FloatType>(Val: resultType);
3960 case arith::AtomicRMWKind::minimumf:
3961 return isa<FloatType>(Val: resultType);
3962 case arith::AtomicRMWKind::maxs: {
3963 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3964 return intType && intType.isSigned();
3965 }
3966 case arith::AtomicRMWKind::mins: {
3967 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3968 return intType && intType.isSigned();
3969 }
3970 case arith::AtomicRMWKind::maxu: {
3971 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3972 return intType && intType.isUnsigned();
3973 }
3974 case arith::AtomicRMWKind::minu: {
3975 auto intType = llvm::dyn_cast<IntegerType>(resultType);
3976 return intType && intType.isUnsigned();
3977 }
3978 case arith::AtomicRMWKind::ori:
3979 return isa<IntegerType>(Val: resultType);
3980 case arith::AtomicRMWKind::andi:
3981 return isa<IntegerType>(Val: resultType);
3982 default:
3983 return false;
3984 }
3985}
3986
3987LogicalResult AffineParallelOp::verify() {
3988 auto numDims = getNumDims();
3989 if (getLowerBoundsGroups().getNumElements() != numDims ||
3990 getUpperBoundsGroups().getNumElements() != numDims ||
3991 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3992 return emitOpError() << "the number of region arguments ("
3993 << getBody()->getNumArguments()
3994 << ") and the number of map groups for lower ("
3995 << getLowerBoundsGroups().getNumElements()
3996 << ") and upper bound ("
3997 << getUpperBoundsGroups().getNumElements()
3998 << "), and the number of steps (" << getSteps().size()
3999 << ") must all match";
4000 }
4001
4002 unsigned expectedNumLBResults = 0;
4003 for (APInt v : getLowerBoundsGroups()) {
4004 unsigned results = v.getZExtValue();
4005 if (results == 0)
4006 return emitOpError()
4007 << "expected lower bound map to have at least one result";
4008 expectedNumLBResults += results;
4009 }
4010 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4011 return emitOpError() << "expected lower bounds map to have "
4012 << expectedNumLBResults << " results";
4013 unsigned expectedNumUBResults = 0;
4014 for (APInt v : getUpperBoundsGroups()) {
4015 unsigned results = v.getZExtValue();
4016 if (results == 0)
4017 return emitOpError()
4018 << "expected upper bound map to have at least one result";
4019 expectedNumUBResults += results;
4020 }
4021 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4022 return emitOpError() << "expected upper bounds map to have "
4023 << expectedNumUBResults << " results";
4024
4025 if (getReductions().size() != getNumResults())
4026 return emitOpError("a reduction must be specified for each output");
4027
4028 // Verify reduction ops are all valid and each result type matches reduction
4029 // ops
4030 for (auto it : llvm::enumerate((getReductions()))) {
4031 Attribute attr = it.value();
4032 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
4033 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4034 return emitOpError("invalid reduction attribute");
4035 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4036 if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4037 return emitOpError("result type cannot match reduction attribute");
4038 }
4039
4040 // Verify that the bound operands are valid dimension/symbols.
4041 /// Lower bounds.
4042 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4043 getLowerBoundsMap().getNumDims())))
4044 return failure();
4045 /// Upper bounds.
4046 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4047 getUpperBoundsMap().getNumDims())))
4048 return failure();
4049 return success();
4050}
4051
4052LogicalResult AffineValueMap::canonicalize() {
4053 SmallVector<Value, 4> newOperands{operands};
4054 auto newMap = getAffineMap();
4055 composeAffineMapAndOperands(map: &newMap, operands: &newOperands);
4056 if (newMap == getAffineMap() && newOperands == operands)
4057 return failure();
4058 reset(map: newMap, operands: newOperands);
4059 return success();
4060}
4061
4062/// Canonicalize the bounds of the given loop.
4063static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4064 AffineValueMap lb = op.getLowerBoundsValueMap();
4065 bool lbCanonicalized = succeeded(Result: lb.canonicalize());
4066
4067 AffineValueMap ub = op.getUpperBoundsValueMap();
4068 bool ubCanonicalized = succeeded(Result: ub.canonicalize());
4069
4070 // Any canonicalization change always leads to updated map(s).
4071 if (!lbCanonicalized && !ubCanonicalized)
4072 return failure();
4073
4074 if (lbCanonicalized)
4075 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4076 if (ubCanonicalized)
4077 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4078
4079 return success();
4080}
4081
4082LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4083 SmallVectorImpl<OpFoldResult> &results) {
4084 return canonicalizeLoopBounds(*this);
4085}
4086
4087/// Prints a lower(upper) bound of an affine parallel loop with max(min)
4088/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4089/// identifies which of the those expressions form max/min groups. `operands`
4090/// are the SSA values of dimensions and symbols and `keyword` is either "min"
4091/// or "max".
4092static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4093 DenseIntElementsAttr group, ValueRange operands,
4094 StringRef keyword) {
4095 AffineMap map = mapAttr.getValue();
4096 unsigned numDims = map.getNumDims();
4097 ValueRange dimOperands = operands.take_front(n: numDims);
4098 ValueRange symOperands = operands.drop_front(n: numDims);
4099 unsigned start = 0;
4100 for (llvm::APInt groupSize : group) {
4101 if (start != 0)
4102 p << ", ";
4103
4104 unsigned size = groupSize.getZExtValue();
4105 if (size == 1) {
4106 p.printAffineExprOfSSAIds(expr: map.getResult(idx: start), dimOperands, symOperands);
4107 ++start;
4108 } else {
4109 p << keyword << '(';
4110 AffineMap submap = map.getSliceMap(start, length: size);
4111 p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4112 p << ')';
4113 start += size;
4114 }
4115 }
4116}
4117
4118void AffineParallelOp::print(OpAsmPrinter &p) {
4119 p << " (" << getBody()->getArguments() << ") = (";
4120 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4121 getLowerBoundsOperands(), "max");
4122 p << ") to (";
4123 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4124 getUpperBoundsOperands(), "min");
4125 p << ')';
4126 SmallVector<int64_t, 8> steps = getSteps();
4127 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4128 if (!elideSteps) {
4129 p << " step (";
4130 llvm::interleaveComma(steps, p);
4131 p << ')';
4132 }
4133 if (getNumResults()) {
4134 p << " reduce (";
4135 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4136 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4137 llvm::cast<IntegerAttr>(attr).getInt());
4138 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4139 });
4140 p << ") -> (" << getResultTypes() << ")";
4141 }
4142
4143 p << ' ';
4144 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4145 /*printBlockTerminators=*/getNumResults());
4146 p.printOptionalAttrDict(
4147 (*this)->getAttrs(),
4148 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4149 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4150 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4151 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4152 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4153 AffineParallelOp::getStepsAttrStrName()});
4154}
4155
4156/// Given a list of lists of parsed operands, populates `uniqueOperands` with
4157/// unique operands. Also populates `replacements with affine expressions of
4158/// `kind` that can be used to update affine maps previously accepting a
4159/// `operands` to accept `uniqueOperands` instead.
4160static ParseResult deduplicateAndResolveOperands(
4161 OpAsmParser &parser,
4162 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4163 SmallVectorImpl<Value> &uniqueOperands,
4164 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4165 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4166 "expected operands to be dim or symbol expression");
4167
4168 Type indexType = parser.getBuilder().getIndexType();
4169 for (const auto &list : operands) {
4170 SmallVector<Value> valueOperands;
4171 if (parser.resolveOperands(operands: list, type: indexType, result&: valueOperands))
4172 return failure();
4173 for (Value operand : valueOperands) {
4174 unsigned pos = std::distance(first: uniqueOperands.begin(),
4175 last: llvm::find(Range&: uniqueOperands, Val: operand));
4176 if (pos == uniqueOperands.size())
4177 uniqueOperands.push_back(Elt: operand);
4178 replacements.push_back(
4179 Elt: kind == AffineExprKind::DimId
4180 ? getAffineDimExpr(position: pos, context: parser.getContext())
4181 : getAffineSymbolExpr(position: pos, context: parser.getContext()));
4182 }
4183 }
4184 return success();
4185}
4186
namespace {
/// Selects which grouping keyword a bound parser accepts.
enum class MinMaxKind {
  /// Groups are introduced by the `min` keyword.
  Min,
  /// Groups are introduced by the `max` keyword.
  Max
};
} // namespace
4190
4191/// Parses an affine map that can contain a min/max for groups of its results,
4192/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4193/// `result` attributes with the map (flat list of expressions) and the grouping
4194/// (list of integers that specify how many expressions to put into each
4195/// min/max) attributes. Deduplicates repeated operands.
4196///
4197/// parallel-bound ::= `(` parallel-group-list `)`
4198/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4199/// parallel-group ::= simple-group | min-max-group
4200/// simple-group ::= expr-of-ssa-ids
4201/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4202/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4203///
4204/// Examples:
4205/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4206/// (%0, max(%1 - 2 * %2))
4207static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4208 OperationState &result,
4209 MinMaxKind kind) {
4210 // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4211 // see: https://reviews.llvm.org/D134227#3821753
4212 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4213
4214 StringRef mapName = kind == MinMaxKind::Min
4215 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4216 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4217 StringRef groupsName =
4218 kind == MinMaxKind::Min
4219 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4220 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4221
4222 if (failed(Result: parser.parseLParen()))
4223 return failure();
4224
4225 if (succeeded(Result: parser.parseOptionalRParen())) {
4226 result.addAttribute(
4227 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4228 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr(values: {}));
4229 return success();
4230 }
4231
4232 SmallVector<AffineExpr> flatExprs;
4233 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4234 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4235 SmallVector<int32_t> numMapsPerGroup;
4236 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4237 auto parseOperands = [&]() {
4238 if (succeeded(Result: parser.parseOptionalKeyword(
4239 keyword: kind == MinMaxKind::Min ? "min" : "max"))) {
4240 mapOperands.clear();
4241 AffineMapAttr map;
4242 if (failed(parser.parseAffineMapOfSSAIds(operands&: mapOperands, map&: map, attrName: tmpAttrStrName,
4243 attrs&: result.attributes,
4244 delimiter: OpAsmParser::Delimiter::Paren)))
4245 return failure();
4246 result.attributes.erase(name: tmpAttrStrName);
4247 llvm::append_range(flatExprs, map.getValue().getResults());
4248 auto operandsRef = llvm::ArrayRef(mapOperands);
4249 auto dimsRef = operandsRef.take_front(N: map.getValue().getNumDims());
4250 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
4251 auto symsRef = operandsRef.drop_front(N: map.getValue().getNumDims());
4252 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
4253 flatDimOperands.append(map.getValue().getNumResults(), dims);
4254 flatSymOperands.append(map.getValue().getNumResults(), syms);
4255 numMapsPerGroup.push_back(Elt: map.getValue().getNumResults());
4256 } else {
4257 if (failed(Result: parser.parseAffineExprOfSSAIds(dimOperands&: flatDimOperands.emplace_back(),
4258 symbOperands&: flatSymOperands.emplace_back(),
4259 expr&: flatExprs.emplace_back())))
4260 return failure();
4261 numMapsPerGroup.push_back(Elt: 1);
4262 }
4263 return success();
4264 };
4265 if (parser.parseCommaSeparatedList(parseElementFn: parseOperands) || parser.parseRParen())
4266 return failure();
4267
4268 unsigned totalNumDims = 0;
4269 unsigned totalNumSyms = 0;
4270 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4271 unsigned numDims = flatDimOperands[i].size();
4272 unsigned numSyms = flatSymOperands[i].size();
4273 flatExprs[i] = flatExprs[i]
4274 .shiftDims(numDims, shift: totalNumDims)
4275 .shiftSymbols(numSymbols: numSyms, shift: totalNumSyms);
4276 totalNumDims += numDims;
4277 totalNumSyms += numSyms;
4278 }
4279
4280 // Deduplicate map operands.
4281 SmallVector<Value> dimOperands, symOperands;
4282 SmallVector<AffineExpr> dimRplacements, symRepacements;
4283 if (deduplicateAndResolveOperands(parser, operands: flatDimOperands, uniqueOperands&: dimOperands,
4284 replacements&: dimRplacements, kind: AffineExprKind::DimId) ||
4285 deduplicateAndResolveOperands(parser, operands: flatSymOperands, uniqueOperands&: symOperands,
4286 replacements&: symRepacements, kind: AffineExprKind::SymbolId))
4287 return failure();
4288
4289 result.operands.append(in_start: dimOperands.begin(), in_end: dimOperands.end());
4290 result.operands.append(in_start: symOperands.begin(), in_end: symOperands.end());
4291
4292 Builder &builder = parser.getBuilder();
4293 auto flatMap = AffineMap::get(dimCount: totalNumDims, symbolCount: totalNumSyms, results: flatExprs,
4294 context: parser.getContext());
4295 flatMap = flatMap.replaceDimsAndSymbols(
4296 dimReplacements: dimRplacements, symReplacements: symRepacements, numResultDims: dimOperands.size(), numResultSyms: symOperands.size());
4297
4298 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4299 result.addAttribute(groupsName, builder.getI32TensorAttr(values: numMapsPerGroup));
4300 return success();
4301}
4302
4303//
4304// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4305// `to` parallel-bound steps? region attr-dict?
4306// steps ::= `steps` `(` integer-literals `)`
4307//
4308ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4309 OperationState &result) {
4310 auto &builder = parser.getBuilder();
4311 auto indexType = builder.getIndexType();
4312 SmallVector<OpAsmParser::Argument, 4> ivs;
4313 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
4314 parser.parseEqual() ||
4315 parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4316 parser.parseKeyword("to") ||
4317 parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4318 return failure();
4319
4320 AffineMapAttr stepsMapAttr;
4321 NamedAttrList stepsAttrs;
4322 SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4323 if (failed(parser.parseOptionalKeyword("step"))) {
4324 SmallVector<int64_t, 4> steps(ivs.size(), 1);
4325 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4326 builder.getI64ArrayAttr(steps));
4327 } else {
4328 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4329 AffineParallelOp::getStepsAttrStrName(),
4330 stepsAttrs,
4331 OpAsmParser::Delimiter::Paren))
4332 return failure();
4333
4334 // Convert steps from an AffineMap into an I64ArrayAttr.
4335 SmallVector<int64_t, 4> steps;
4336 auto stepsMap = stepsMapAttr.getValue();
4337 for (const auto &result : stepsMap.getResults()) {
4338 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4339 if (!constExpr)
4340 return parser.emitError(parser.getNameLoc(),
4341 "steps must be constant integers");
4342 steps.push_back(constExpr.getValue());
4343 }
4344 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4345 builder.getI64ArrayAttr(steps));
4346 }
4347
4348 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4349 // quoted strings are a member of the enum AtomicRMWKind.
4350 SmallVector<Attribute, 4> reductions;
4351 if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4352 if (parser.parseLParen())
4353 return failure();
4354 auto parseAttributes = [&]() -> ParseResult {
4355 // Parse a single quoted string via the attribute parsing, and then
4356 // verify it is a member of the enum and convert to it's integer
4357 // representation.
4358 StringAttr attrVal;
4359 NamedAttrList attrStorage;
4360 auto loc = parser.getCurrentLocation();
4361 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4362 attrStorage))
4363 return failure();
4364 std::optional<arith::AtomicRMWKind> reduction =
4365 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4366 if (!reduction)
4367 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4368 reductions.push_back(
4369 builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4370 // While we keep getting commas, keep parsing.
4371 return success();
4372 };
4373 if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4374 return failure();
4375 }
4376 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4377 builder.getArrayAttr(reductions));
4378
4379 // Parse return types of reductions (if any)
4380 if (parser.parseOptionalArrowTypeList(result.types))
4381 return failure();
4382
4383 // Now parse the body.
4384 Region *body = result.addRegion();
4385 for (auto &iv : ivs)
4386 iv.type = indexType;
4387 if (parser.parseRegion(*body, ivs) ||
4388 parser.parseOptionalAttrDict(result.attributes))
4389 return failure();
4390
4391 // Add a terminator if none was parsed.
4392 AffineParallelOp::ensureTerminator(*body, builder, result.location);
4393 return success();
4394}
4395
4396//===----------------------------------------------------------------------===//
4397// AffineYieldOp
4398//===----------------------------------------------------------------------===//
4399
4400LogicalResult AffineYieldOp::verify() {
4401 auto *parentOp = (*this)->getParentOp();
4402 auto results = parentOp->getResults();
4403 auto operands = getOperands();
4404
4405 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4406 return emitOpError() << "only terminates affine.if/for/parallel regions";
4407 if (parentOp->getNumResults() != getNumOperands())
4408 return emitOpError() << "parent of yield must have same number of "
4409 "results as the yield operands";
4410 for (auto it : llvm::zip(results, operands)) {
4411 if (std::get<0>(it).getType() != std::get<1>(it).getType())
4412 return emitOpError() << "types mismatch between yield op and its parent";
4413 }
4414
4415 return success();
4416}
4417
4418//===----------------------------------------------------------------------===//
4419// AffineVectorLoadOp
4420//===----------------------------------------------------------------------===//
4421
4422void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4423 VectorType resultType, AffineMap map,
4424 ValueRange operands) {
4425 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4426 result.addOperands(operands);
4427 if (map)
4428 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4429 result.types.push_back(resultType);
4430}
4431
4432void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4433 VectorType resultType, Value memref,
4434 AffineMap map, ValueRange mapOperands) {
4435 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4436 result.addOperands(memref);
4437 result.addOperands(mapOperands);
4438 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4439 result.types.push_back(resultType);
4440}
4441
4442void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4443 VectorType resultType, Value memref,
4444 ValueRange indices) {
4445 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4446 int64_t rank = memrefType.getRank();
4447 // Create identity map for memrefs with at least one dimension or () -> ()
4448 // for zero-dimensional memrefs.
4449 auto map =
4450 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4451 build(builder, result, resultType, memref, map, indices);
4452}
4453
4454void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4455 MLIRContext *context) {
4456 results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4457}
4458
4459ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4460 OperationState &result) {
4461 auto &builder = parser.getBuilder();
4462 auto indexTy = builder.getIndexType();
4463
4464 MemRefType memrefType;
4465 VectorType resultType;
4466 OpAsmParser::UnresolvedOperand memrefInfo;
4467 AffineMapAttr mapAttr;
4468 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4469 return failure(
4470 parser.parseOperand(memrefInfo) ||
4471 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4472 AffineVectorLoadOp::getMapAttrStrName(),
4473 result.attributes) ||
4474 parser.parseOptionalAttrDict(result.attributes) ||
4475 parser.parseColonType(memrefType) || parser.parseComma() ||
4476 parser.parseType(resultType) ||
4477 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4478 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4479 parser.addTypeToList(resultType, result.types));
4480}
4481
4482void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4483 p << " " << getMemRef() << '[';
4484 if (AffineMapAttr mapAttr =
4485 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4486 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4487 p << ']';
4488 p.printOptionalAttrDict((*this)->getAttrs(),
4489 /*elidedAttrs=*/{getMapAttrStrName()});
4490 p << " : " << getMemRefType() << ", " << getType();
4491}
4492
4493/// Verify common invariants of affine.vector_load and affine.vector_store.
4494static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4495 VectorType vectorType) {
4496 // Check that memref and vector element types match.
4497 if (memrefType.getElementType() != vectorType.getElementType())
4498 return op->emitOpError(
4499 message: "requires memref and vector types of the same elemental type");
4500 return success();
4501}
4502
4503LogicalResult AffineVectorLoadOp::verify() {
4504 MemRefType memrefType = getMemRefType();
4505 if (failed(verifyMemoryOpIndexing(
4506 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4507 getMapOperands(), memrefType,
4508 /*numIndexOperands=*/getNumOperands() - 1)))
4509 return failure();
4510
4511 if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4512 return failure();
4513
4514 return success();
4515}
4516
4517//===----------------------------------------------------------------------===//
4518// AffineVectorStoreOp
4519//===----------------------------------------------------------------------===//
4520
4521void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4522 Value valueToStore, Value memref, AffineMap map,
4523 ValueRange mapOperands) {
4524 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4525 result.addOperands(valueToStore);
4526 result.addOperands(memref);
4527 result.addOperands(mapOperands);
4528 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4529}
4530
4531// Use identity map.
4532void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4533 Value valueToStore, Value memref,
4534 ValueRange indices) {
4535 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4536 int64_t rank = memrefType.getRank();
4537 // Create identity map for memrefs with at least one dimension or () -> ()
4538 // for zero-dimensional memrefs.
4539 auto map =
4540 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4541 build(builder, result, valueToStore, memref, map, indices);
4542}
4543void AffineVectorStoreOp::getCanonicalizationPatterns(
4544 RewritePatternSet &results, MLIRContext *context) {
4545 results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4546}
4547
4548ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4549 OperationState &result) {
4550 auto indexTy = parser.getBuilder().getIndexType();
4551
4552 MemRefType memrefType;
4553 VectorType resultType;
4554 OpAsmParser::UnresolvedOperand storeValueInfo;
4555 OpAsmParser::UnresolvedOperand memrefInfo;
4556 AffineMapAttr mapAttr;
4557 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4558 return failure(
4559 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4560 parser.parseOperand(memrefInfo) ||
4561 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4562 AffineVectorStoreOp::getMapAttrStrName(),
4563 result.attributes) ||
4564 parser.parseOptionalAttrDict(result.attributes) ||
4565 parser.parseColonType(memrefType) || parser.parseComma() ||
4566 parser.parseType(resultType) ||
4567 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4568 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4569 parser.resolveOperands(mapOperands, indexTy, result.operands));
4570}
4571
4572void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4573 p << " " << getValueToStore();
4574 p << ", " << getMemRef() << '[';
4575 if (AffineMapAttr mapAttr =
4576 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4577 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4578 p << ']';
4579 p.printOptionalAttrDict((*this)->getAttrs(),
4580 /*elidedAttrs=*/{getMapAttrStrName()});
4581 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4582}
4583
4584LogicalResult AffineVectorStoreOp::verify() {
4585 MemRefType memrefType = getMemRefType();
4586 if (failed(verifyMemoryOpIndexing(
4587 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4588 getMapOperands(), memrefType,
4589 /*numIndexOperands=*/getNumOperands() - 2)))
4590 return failure();
4591
4592 if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4593 return failure();
4594
4595 return success();
4596}
4597
4598//===----------------------------------------------------------------------===//
4599// DelinearizeIndexOp
4600//===----------------------------------------------------------------------===//
4601
4602void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4603 OperationState &odsState,
4604 Value linearIndex, ValueRange dynamicBasis,
4605 ArrayRef<int64_t> staticBasis,
4606 bool hasOuterBound) {
4607 SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4608 : staticBasis.size() + 1,
4609 linearIndex.getType());
4610 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4611 staticBasis);
4612}
4613
4614void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4615 OperationState &odsState,
4616 Value linearIndex, ValueRange basis,
4617 bool hasOuterBound) {
4618 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4619 hasOuterBound = false;
4620 basis = basis.drop_front();
4621 }
4622 SmallVector<Value> dynamicBasis;
4623 SmallVector<int64_t> staticBasis;
4624 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4625 staticBasis);
4626 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4627 hasOuterBound);
4628}
4629
4630void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4631 OperationState &odsState,
4632 Value linearIndex,
4633 ArrayRef<OpFoldResult> basis,
4634 bool hasOuterBound) {
4635 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4636 hasOuterBound = false;
4637 basis = basis.drop_front();
4638 }
4639 SmallVector<Value> dynamicBasis;
4640 SmallVector<int64_t> staticBasis;
4641 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4642 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4643 hasOuterBound);
4644}
4645
4646void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4647 OperationState &odsState,
4648 Value linearIndex, ArrayRef<int64_t> basis,
4649 bool hasOuterBound) {
4650 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4651}
4652
4653LogicalResult AffineDelinearizeIndexOp::verify() {
4654 ArrayRef<int64_t> staticBasis = getStaticBasis();
4655 if (getNumResults() != staticBasis.size() &&
4656 getNumResults() != staticBasis.size() + 1)
4657 return emitOpError("should return an index for each basis element and up "
4658 "to one extra index");
4659
4660 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4661 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4662 return emitOpError(
4663 "mismatch between dynamic and static basis (kDynamic marker but no "
4664 "corresponding dynamic basis entry) -- this can only happen due to an "
4665 "incorrect fold/rewrite");
4666
4667 if (!llvm::all_of(staticBasis, [](int64_t v) {
4668 return v > 0 || ShapedType::isDynamic(v);
4669 }))
4670 return emitOpError("no basis element may be statically non-positive");
4671
4672 return success();
4673}
4674
4675/// Given mixed basis of affine.delinearize_index/linearize_index replace
4676/// constant SSA values with the constant integer value and return the new
4677/// static basis. In case no such candidate for replacement exists, this utility
4678/// returns std::nullopt.
4679static std::optional<SmallVector<int64_t>>
4680foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
4681 MutableOperandRange mutableDynamicBasis,
4682 ArrayRef<Attribute> dynamicBasis) {
4683 uint64_t dynamicBasisIndex = 0;
4684 for (OpFoldResult basis : dynamicBasis) {
4685 if (basis) {
4686 mutableDynamicBasis.erase(subStart: dynamicBasisIndex);
4687 } else {
4688 ++dynamicBasisIndex;
4689 }
4690 }
4691
4692 // No constant SSA value exists.
4693 if (dynamicBasisIndex == dynamicBasis.size())
4694 return std::nullopt;
4695
4696 SmallVector<int64_t> staticBasis;
4697 for (OpFoldResult basis : mixedBasis) {
4698 std::optional<int64_t> basisVal = getConstantIntValue(ofr: basis);
4699 if (!basisVal)
4700 staticBasis.push_back(ShapedType::kDynamic);
4701 else
4702 staticBasis.push_back(Elt: *basisVal);
4703 }
4704
4705 return staticBasis;
4706}
4707
4708LogicalResult
4709AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4710 SmallVectorImpl<OpFoldResult> &result) {
4711 std::optional<SmallVector<int64_t>> maybeStaticBasis =
4712 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4713 adaptor.getDynamicBasis());
4714 if (maybeStaticBasis) {
4715 setStaticBasis(*maybeStaticBasis);
4716 return success();
4717 }
4718 // If we won't be doing any division or modulo (no basis or the one basis
4719 // element is purely advisory), simply return the input value.
4720 if (getNumResults() == 1) {
4721 result.push_back(getLinearIndex());
4722 return success();
4723 }
4724
4725 if (adaptor.getLinearIndex() == nullptr)
4726 return failure();
4727
4728 if (!adaptor.getDynamicBasis().empty())
4729 return failure();
4730
4731 int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4732 Type attrType = getLinearIndex().getType();
4733
4734 ArrayRef<int64_t> staticBasis = getStaticBasis();
4735 if (hasOuterBound())
4736 staticBasis = staticBasis.drop_front();
4737 for (int64_t modulus : llvm::reverse(staticBasis)) {
4738 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4739 highPart = llvm::divideFloorSigned(highPart, modulus);
4740 }
4741 result.push_back(IntegerAttr::get(attrType, highPart));
4742 std::reverse(result.begin(), result.end());
4743 return success();
4744}
4745
4746SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4747 OpBuilder builder(getContext());
4748 if (hasOuterBound()) {
4749 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4750 return getMixedValues(getStaticBasis().drop_front(),
4751 getDynamicBasis().drop_front(), builder);
4752
4753 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4754 builder);
4755 }
4756
4757 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4758}
4759
4760SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4761 SmallVector<OpFoldResult> ret = getMixedBasis();
4762 if (!hasOuterBound())
4763 ret.insert(ret.begin(), OpFoldResult());
4764 return ret;
4765}
4766
4767namespace {
4768
4769// Drops delinearization indices that correspond to unit-extent basis
4770struct DropUnitExtentBasis
4771 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4772 using OpRewritePattern::OpRewritePattern;
4773
4774 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4775 PatternRewriter &rewriter) const override {
4776 SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4777 std::optional<Value> zero = std::nullopt;
4778 Location loc = delinearizeOp->getLoc();
4779 auto getZero = [&]() -> Value {
4780 if (!zero)
4781 zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
4782 return zero.value();
4783 };
4784
4785 // Replace all indices corresponding to unit-extent basis with 0.
4786 // Remaining basis can be used to get a new `affine.delinearize_index` op.
4787 SmallVector<OpFoldResult> newBasis;
4788 for (auto [index, basis] :
4789 llvm::enumerate(delinearizeOp.getPaddedBasis())) {
4790 std::optional<int64_t> basisVal =
4791 basis ? getConstantIntValue(basis) : std::nullopt;
4792 if (basisVal && *basisVal == 1)
4793 replacements[index] = getZero();
4794 else
4795 newBasis.push_back(basis);
4796 }
4797
4798 if (newBasis.size() == delinearizeOp.getNumResults())
4799 return rewriter.notifyMatchFailure(delinearizeOp,
4800 "no unit basis elements");
4801
4802 if (!newBasis.empty()) {
4803 // Will drop the leading nullptr from `basis` if there was no outer bound.
4804 auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4805 loc, delinearizeOp.getLinearIndex(), newBasis);
4806 int newIndex = 0;
4807 // Map back the new delinearized indices to the values they replace.
4808 for (auto &replacement : replacements) {
4809 if (replacement)
4810 continue;
4811 replacement = newDelinearizeOp->getResult(newIndex++);
4812 }
4813 }
4814
4815 rewriter.replaceOp(delinearizeOp, replacements);
4816 return success();
4817 }
4818};
4819
4820/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4821/// disjoint` and the two operations end with the same basis elements,
4822/// cancel those parts of the operations out because they are inverses
4823/// of each other.
4824///
4825/// If the operations have the same basis, cancel them entirely.
4826///
4827/// The `disjoint` flag is needed on the `affine.linearize_index` because
4828/// otherwise, there is no guarantee that the inputs to the linearization are
4829/// in-bounds the way the outputs of the delinearization would be.
4830struct CancelDelinearizeOfLinearizeDisjointExactTail
4831 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4832 using OpRewritePattern::OpRewritePattern;
4833
4834 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4835 PatternRewriter &rewriter) const override {
4836 auto linearizeOp = delinearizeOp.getLinearIndex()
4837 .getDefiningOp<affine::AffineLinearizeIndexOp>();
4838 if (!linearizeOp)
4839 return rewriter.notifyMatchFailure(delinearizeOp,
4840 "index doesn't come from linearize");
4841
4842 if (!linearizeOp.getDisjoint())
4843 return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4844
4845 ValueRange linearizeIns = linearizeOp.getMultiIndex();
4846 // Note: we use the full basis so we don't lose outer bounds later.
4847 SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4848 SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4849 size_t numMatches = 0;
4850 for (auto [linSize, delinSize] : llvm::zip(
4851 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4852 if (linSize != delinSize)
4853 break;
4854 ++numMatches;
4855 }
4856
4857 if (numMatches == 0)
4858 return rewriter.notifyMatchFailure(
4859 delinearizeOp, "final basis element doesn't match linearize");
4860
4861 // The easy case: everything lines up and the basis match sup completely.
4862 if (numMatches == linearizeBasis.size() &&
4863 numMatches == delinearizeBasis.size() &&
4864 linearizeIns.size() == delinearizeOp.getNumResults()) {
4865 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4866 return success();
4867 }
4868
4869 Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4870 linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4871 ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
4872 linearizeOp.getDisjoint());
4873 auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4874 delinearizeOp.getLoc(), newLinearize,
4875 ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
4876 delinearizeOp.hasOuterBound());
4877 SmallVector<Value> mergedResults(newDelinearize.getResults());
4878 mergedResults.append(in_start: linearizeIns.take_back(n: numMatches).begin(),
4879 in_end: linearizeIns.take_back(n: numMatches).end());
4880 rewriter.replaceOp(delinearizeOp, mergedResults);
4881 return success();
4882 }
4883};
4884
4885/// If the input to a delinearization is a disjoint linearization, and the
4886/// last k > 1 components of the delinearization basis multiply to the
4887/// last component of the linearization basis, break the linearization and
4888/// delinearization into two parts, peeling off the last input to linearization.
4889///
4890/// For example:
4891/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4892/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4893/// becomes
4894/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
4895/// %1:2 = affine.delinearize_index %0 by (2, 3) : index
4896/// %2:2 = affine.delinearize_index %x by (8, 4) : index
4897/// where the original %1:4 is replaced by %1:2 ++ %2:2
4898struct SplitDelinearizeSpanningLastLinearizeArg final
4899 : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4900 using OpRewritePattern::OpRewritePattern;
4901
4902 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4903 PatternRewriter &rewriter) const override {
4904 auto linearizeOp = delinearizeOp.getLinearIndex()
4905 .getDefiningOp<affine::AffineLinearizeIndexOp>();
4906 if (!linearizeOp)
4907 return rewriter.notifyMatchFailure(delinearizeOp,
4908 "index doesn't come from linearize");
4909
4910 if (!linearizeOp.getDisjoint())
4911 return rewriter.notifyMatchFailure(linearizeOp,
4912 "linearize isn't disjoint");
4913
4914 int64_t target = linearizeOp.getStaticBasis().back();
4915 if (ShapedType::isDynamic(target))
4916 return rewriter.notifyMatchFailure(
4917 linearizeOp, "linearize ends with dynamic basis value");
4918
4919 int64_t sizeToSplit = 1;
4920 size_t elemsToSplit = 0;
4921 ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
4922 for (int64_t basisElem : llvm::reverse(basis)) {
4923 if (ShapedType::isDynamic(basisElem))
4924 return rewriter.notifyMatchFailure(
4925 delinearizeOp, "dynamic basis element while scanning for split");
4926 sizeToSplit *= basisElem;
4927 elemsToSplit += 1;
4928
4929 if (sizeToSplit > target)
4930 return rewriter.notifyMatchFailure(delinearizeOp,
4931 "overshot last argument size");
4932 if (sizeToSplit == target)
4933 break;
4934 }
4935
4936 if (sizeToSplit < target)
4937 return rewriter.notifyMatchFailure(
4938 delinearizeOp, "product of known basis elements doesn't exceed last "
4939 "linearize argument");
4940
4941 if (elemsToSplit < 2)
4942 return rewriter.notifyMatchFailure(
4943 delinearizeOp,
4944 "need at least two elements to form the basis product");
4945
4946 Value linearizeWithoutBack =
4947 rewriter.create<affine::AffineLinearizeIndexOp>(
4948 linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4949 linearizeOp.getDynamicBasis(),
4950 linearizeOp.getStaticBasis().drop_back(),
4951 linearizeOp.getDisjoint());
4952 auto delinearizeWithoutSplitPart =
4953 rewriter.create<affine::AffineDelinearizeIndexOp>(
4954 delinearizeOp.getLoc(), linearizeWithoutBack,
4955 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4956 delinearizeOp.hasOuterBound());
4957 auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4958 delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4959 basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
4960 SmallVector<Value> results = llvm::to_vector(
4961 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4962 delinearizeBack.getResults()));
4963 rewriter.replaceOp(delinearizeOp, results);
4964
4965 return success();
4966 }
4967};
4968} // namespace
4969
4970void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4971 RewritePatternSet &patterns, MLIRContext *context) {
4972 patterns
4973 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4974 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4975 context);
4976}
4977
4978//===----------------------------------------------------------------------===//
4979// LinearizeIndexOp
4980//===----------------------------------------------------------------------===//
4981
4982void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4983 OperationState &odsState,
4984 ValueRange multiIndex, ValueRange basis,
4985 bool disjoint) {
4986 if (!basis.empty() && basis.front() == Value())
4987 basis = basis.drop_front();
4988 SmallVector<Value> dynamicBasis;
4989 SmallVector<int64_t> staticBasis;
4990 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4991 staticBasis);
4992 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4993}
4994
4995void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4996 OperationState &odsState,
4997 ValueRange multiIndex,
4998 ArrayRef<OpFoldResult> basis,
4999 bool disjoint) {
5000 if (!basis.empty() && basis.front() == OpFoldResult())
5001 basis = basis.drop_front();
5002 SmallVector<Value> dynamicBasis;
5003 SmallVector<int64_t> staticBasis;
5004 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5005 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5006}
5007
5008void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5009 OperationState &odsState,
5010 ValueRange multiIndex,
5011 ArrayRef<int64_t> basis, bool disjoint) {
5012 build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5013}
5014
5015LogicalResult AffineLinearizeIndexOp::verify() {
5016 size_t numIndexes = getMultiIndex().size();
5017 size_t numBasisElems = getStaticBasis().size();
5018 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5019 return emitOpError("should be passed a basis element for each index except "
5020 "possibly the first");
5021
5022 auto dynamicMarkersCount =
5023 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5024 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5025 return emitOpError(
5026 "mismatch between dynamic and static basis (kDynamic marker but no "
5027 "corresponding dynamic basis entry) -- this can only happen due to an "
5028 "incorrect fold/rewrite");
5029
5030 return success();
5031}
5032
5033OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5034 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5035 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5036 adaptor.getDynamicBasis());
5037 if (maybeStaticBasis) {
5038 setStaticBasis(*maybeStaticBasis);
5039 return getResult();
5040 }
5041 // No indices linearizes to zero.
5042 if (getMultiIndex().empty())
5043 return IntegerAttr::get(getResult().getType(), 0);
5044
5045 // One single index linearizes to itself.
5046 if (getMultiIndex().size() == 1)
5047 return getMultiIndex().front();
5048
5049 if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
5050 return nullptr;
5051
5052 if (!adaptor.getDynamicBasis().empty())
5053 return nullptr;
5054
5055 int64_t result = 0;
5056 int64_t stride = 1;
5057 for (auto [length, indexAttr] :
5058 llvm::zip_first(llvm::reverse(getStaticBasis()),
5059 llvm::reverse(adaptor.getMultiIndex()))) {
5060 result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5061 stride = stride * length;
5062 }
5063 // Handle the index element with no basis element.
5064 if (!hasOuterBound())
5065 result =
5066 result +
5067 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
5068
5069 return IntegerAttr::get(getResult().getType(), result);
5070}
5071
5072SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5073 OpBuilder builder(getContext());
5074 if (hasOuterBound()) {
5075 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5076 return getMixedValues(getStaticBasis().drop_front(),
5077 getDynamicBasis().drop_front(), builder);
5078
5079 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5080 builder);
5081 }
5082
5083 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5084}
5085
5086SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5087 SmallVector<OpFoldResult> ret = getMixedBasis();
5088 if (!hasOuterBound())
5089 ret.insert(ret.begin(), OpFoldResult());
5090 return ret;
5091}
5092
5093namespace {
5094/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5095/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5096/// %...d)`.
5097
5098/// Note that `disjoint` is required here, because, without it, we could have
5099/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5100/// is a valid operation where the `%c64` cannot be trivially dropped.
5101///
5102/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5103/// the operation isn't asserted to be `disjoint`.
5104struct DropLinearizeUnitComponentsIfDisjointOrZero final
5105 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5106 using OpRewritePattern::OpRewritePattern;
5107
5108 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5109 PatternRewriter &rewriter) const override {
5110 ValueRange multiIndex = op.getMultiIndex();
5111 size_t numIndices = multiIndex.size();
5112 SmallVector<Value> newIndices;
5113 newIndices.reserve(N: numIndices);
5114 SmallVector<OpFoldResult> newBasis;
5115 newBasis.reserve(N: numIndices);
5116
5117 if (!op.hasOuterBound()) {
5118 newIndices.push_back(Elt: multiIndex.front());
5119 multiIndex = multiIndex.drop_front();
5120 }
5121
5122 SmallVector<OpFoldResult> basis = op.getMixedBasis();
5123 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5124 std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5125 if (!basisEntry || *basisEntry != 1) {
5126 newIndices.push_back(index);
5127 newBasis.push_back(basisElem);
5128 continue;
5129 }
5130
5131 std::optional<int64_t> indexValue = getConstantIntValue(index);
5132 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5133 newIndices.push_back(index);
5134 newBasis.push_back(basisElem);
5135 continue;
5136 }
5137 }
5138 if (newIndices.size() == numIndices)
5139 return rewriter.notifyMatchFailure(op,
5140 "no unit basis entries to replace");
5141
5142 if (newIndices.size() == 0) {
5143 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5144 return success();
5145 }
5146 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5147 op, newIndices, newBasis, op.getDisjoint());
5148 return success();
5149 }
5150};
5151
5152OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5153 ArrayRef<OpFoldResult> terms) {
5154 int64_t nDynamic = 0;
5155 SmallVector<Value> dynamicPart;
5156 AffineExpr result = builder.getAffineConstantExpr(constant: 1);
5157 for (OpFoldResult term : terms) {
5158 if (!term)
5159 return term;
5160 std::optional<int64_t> maybeConst = getConstantIntValue(ofr: term);
5161 if (maybeConst) {
5162 result = result * builder.getAffineConstantExpr(constant: *maybeConst);
5163 } else {
5164 dynamicPart.push_back(Elt: cast<Value>(Val&: term));
5165 result = result * builder.getAffineSymbolExpr(position: nDynamic++);
5166 }
5167 }
5168 if (auto constant = dyn_cast<AffineConstantExpr>(Val&: result))
5169 return getAsIndexOpFoldResult(ctx: builder.getContext(), val: constant.getValue());
5170 return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5171}
5172
5173/// If conseceutive outputs of a delinearize_index are linearized with the same
5174/// bounds, canonicalize away the redundant arithmetic.
5175///
5176/// That is, if we have
5177/// ```
5178/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5179/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5180/// by (...e, B1, B2, ..., BK, ...f)
5181/// ```
5182///
5183/// We can rewrite this to
5184/// ```
5185/// B = B1 * B2 ... BK
5186/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
5187/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5188/// ```
5189/// where we replace all results of %s unaffected by the change with results
5190/// from %sMerged.
5191///
5192/// As a special case, if all results of the delinearize are merged in this way
5193/// we can replace those usages with %x, thus cancelling the delinearization
5194/// entirely, as in
5195/// ```
5196/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5197/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5198/// ```
5199/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
5200struct CancelLinearizeOfDelinearizePortion final
5201 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5202 using OpRewritePattern::OpRewritePattern;
5203
5204private:
5205 // Struct representing a case where the cancellation pattern
5206 // applies. A `Match` means that `length` inputs to the linearize operation
5207 // starting at `linStart` can be cancelled with `length` outputs of
5208 // `delinearize`, starting from `delinStart`.
5209 struct Match {
5210 AffineDelinearizeIndexOp delinearize;
5211 unsigned linStart = 0;
5212 unsigned delinStart = 0;
5213 unsigned length = 0;
5214 };
5215
5216public:
5217 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5218 PatternRewriter &rewriter) const override {
5219 SmallVector<Match> matches;
5220
5221 const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5222 ArrayRef<OpFoldResult> linBasisRef = linBasis;
5223
5224 ValueRange multiIndex = linearizeOp.getMultiIndex();
5225 unsigned numLinArgs = multiIndex.size();
5226 unsigned linArgIdx = 0;
5227 // We only want to replace one run from the same delinearize op per
5228 // pattern invocation lest we run into invalidation issues.
5229 llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5230 while (linArgIdx < numLinArgs) {
5231 auto asResult = dyn_cast<OpResult>(Val: multiIndex[linArgIdx]);
5232 if (!asResult) {
5233 linArgIdx++;
5234 continue;
5235 }
5236
5237 auto delinearizeOp =
5238 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5239 if (!delinearizeOp) {
5240 linArgIdx++;
5241 continue;
5242 }
5243
5244 /// Result 0 of the delinearize and argument 0 of the linearize can
5245 /// leave their maximum value unspecified. However, even if this happens
5246 /// we can still sometimes start the match process. Specifically, if
5247 /// - The argument we're matching is result 0 and argument 0 (so the
5248 /// bounds don't matter). For example,
5249 ///
5250 /// %0:2 = affine.delinearize_index %x into (8) : index, index
5251 /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
5252 /// allows cancellation
5253 /// - The delinearization doesn't specify a bound, but the linearization
5254 /// is `disjoint`, which asserts that the bound on the linearization is
5255 /// correct.
5256 unsigned delinArgIdx = asResult.getResultNumber();
5257 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5258 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5259 OpFoldResult firstLinBound = linBasis[linArgIdx];
5260 bool boundsMatch = firstDelinBound == firstLinBound;
5261 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5262 bool knownByDisjoint =
5263 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5264 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5265 linArgIdx++;
5266 continue;
5267 }
5268
5269 unsigned j = 1;
5270 unsigned numDelinOuts = delinearizeOp.getNumResults();
5271 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5272 ++j) {
5273 if (multiIndex[linArgIdx + j] !=
5274 delinearizeOp.getResult(delinArgIdx + j))
5275 break;
5276 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5277 break;
5278 }
5279 // If there're multiple matches against the same delinearize_index,
5280 // only rewrite the first one we find to prevent invalidations. The next
5281 // ones will be taken care of by subsequent pattern invocations.
5282 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5283 linArgIdx++;
5284 continue;
5285 }
5286 matches.push_back(Elt: Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5287 linArgIdx += j;
5288 }
5289
5290 if (matches.empty())
5291 return rewriter.notifyMatchFailure(
5292 linearizeOp, "no run of delinearize outputs to deal with");
5293
5294 // Record all the delinearize replacements so we can do them after creating
5295 // the new linearization operation, since the new operation might use
5296 // outputs of something we're replacing.
5297 SmallVector<SmallVector<Value>> delinearizeReplacements;
5298
5299 SmallVector<Value> newIndex;
5300 newIndex.reserve(N: numLinArgs);
5301 SmallVector<OpFoldResult> newBasis;
5302 newBasis.reserve(N: numLinArgs);
5303 unsigned prevMatchEnd = 0;
5304 for (Match m : matches) {
5305 unsigned gap = m.linStart - prevMatchEnd;
5306 llvm::append_range(C&: newIndex, R: multiIndex.slice(n: prevMatchEnd, m: gap));
5307 llvm::append_range(C&: newBasis, R: linBasisRef.slice(N: prevMatchEnd, M: gap));
5308 // Update here so we don't forget this during early continues
5309 prevMatchEnd = m.linStart + m.length;
5310
5311 PatternRewriter::InsertionGuard g(rewriter);
5312 rewriter.setInsertionPoint(m.delinearize);
5313
5314 ArrayRef<OpFoldResult> basisToMerge =
5315 linBasisRef.slice(N: m.linStart, M: m.length);
5316 // We use the slice from the linearize's basis above because of the
5317 // "bounds inferred from `disjoint`" case above.
5318 OpFoldResult newSize =
5319 computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5320
5321 // Trivial case where we can just skip past the delinearize all together
5322 if (m.length == m.delinearize.getNumResults()) {
5323 newIndex.push_back(Elt: m.delinearize.getLinearIndex());
5324 newBasis.push_back(Elt: newSize);
5325 // Pad out set of replacements so we don't do anything with this one.
5326 delinearizeReplacements.push_back(Elt: SmallVector<Value>());
5327 continue;
5328 }
5329
5330 SmallVector<Value> newDelinResults;
5331 SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5332 newDelinBasis.erase(CS: newDelinBasis.begin() + m.delinStart,
5333 CE: newDelinBasis.begin() + m.delinStart + m.length);
5334 newDelinBasis.insert(I: newDelinBasis.begin() + m.delinStart, Elt: newSize);
5335 auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5336 m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5337 newDelinBasis);
5338
5339 // Since there may be other uses of the indices we just merged together,
5340 // create a residual affine.delinearize_index that delinearizes the
5341 // merged output into its component parts.
5342 Value combinedElem = newDelinearize.getResult(m.delinStart);
5343 auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5344 m.delinearize.getLoc(), combinedElem, basisToMerge);
5345
5346 // Swap all the uses of the unaffected delinearize outputs to the new
5347 // delinearization so that the old code can be removed if this
5348 // linearize_index is the only user of the merged results.
5349 llvm::append_range(newDelinResults,
5350 newDelinearize.getResults().take_front(m.delinStart));
5351 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5352 llvm::append_range(
5353 newDelinResults,
5354 newDelinearize.getResults().drop_front(m.delinStart + 1));
5355
5356 delinearizeReplacements.push_back(Elt: newDelinResults);
5357 newIndex.push_back(Elt: combinedElem);
5358 newBasis.push_back(Elt: newSize);
5359 }
5360 llvm::append_range(C&: newIndex, R: multiIndex.drop_front(n: prevMatchEnd));
5361 llvm::append_range(C&: newBasis, R: linBasisRef.drop_front(N: prevMatchEnd));
5362 rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
5363 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5364
5365 for (auto [m, newResults] :
5366 llvm::zip_equal(t&: matches, u&: delinearizeReplacements)) {
5367 if (newResults.empty())
5368 continue;
5369 rewriter.replaceOp(m.delinearize, newResults);
5370 }
5371
5372 return success();
5373 }
5374};
5375
5376/// Strip leading zero from affine.linearize_index.
5377///
5378/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5379/// to `affine.linearize_index [...a] by (...b)` in all cases.
5380struct DropLinearizeLeadingZero final
5381 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5382 using OpRewritePattern::OpRewritePattern;
5383
5384 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5385 PatternRewriter &rewriter) const override {
5386 Value leadingIdx = op.getMultiIndex().front();
5387 if (!matchPattern(value: leadingIdx, pattern: m_Zero()))
5388 return failure();
5389
5390 if (op.getMultiIndex().size() == 1) {
5391 rewriter.replaceOp(op, leadingIdx);
5392 return success();
5393 }
5394
5395 SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5396 ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5397 if (op.hasOuterBound())
5398 newMixedBasis = newMixedBasis.drop_front();
5399
5400 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5401 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5402 return success();
5403 }
5404};
5405} // namespace
5406
5407void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5408 RewritePatternSet &patterns, MLIRContext *context) {
5409 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5410 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5411}
5412
5413//===----------------------------------------------------------------------===//
5414// TableGen'd op method definitions
5415//===----------------------------------------------------------------------===//
5416
5417#define GET_OP_CLASSES
5418#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
5419

source code of mlir/lib/Dialect/Affine/IR/AffineOps.cpp