//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation utilities for the Affine
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Utils.h"

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>

#define DEBUG_TYPE "affine-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;
namespace {
/// Visits affine expressions recursively and builds the sequence of operations
/// that corresponds to them. Visitation functions return a Value of the
/// expression subtree they visited or `nullptr` on error.
class AffineApplyExpander
    : public AffineExprVisitor<AffineApplyExpander, Value> {
public:
  /// This internal class expects arguments to be non-null; checks must be
  /// performed at the call site.
  AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
                      ValueRange symbolValues, Location loc)
      : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
        loc(loc) {}

  template <typename OpTy>
  Value buildBinaryExpr(AffineBinaryOpExpr expr) {
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    if (!lhs || !rhs)
      return nullptr;
    auto op = builder.create<OpTy>(loc, lhs, rhs);
    return op.getResult();
  }

  Value visitAddExpr(AffineBinaryOpExpr expr) {
    return buildBinaryExpr<arith::AddIOp>(expr);
  }

  Value visitMulExpr(AffineBinaryOpExpr expr) {
    return buildBinaryExpr<arith::MulIOp>(expr);
  }

  /// Euclidean modulo operation: negative RHS is not allowed.
  /// Remainder of the Euclidean integer division is always non-negative.
  ///
  /// Implemented as
  ///
  ///     a mod b =
  ///         let remainder = srem a, b;
  ///             negative = a < 0 in
  ///         select negative, remainder + b, remainder.
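  ///
  /// As an illustration (schematic IR; SSA names are made up), `d0 mod 4`
  /// expands to roughly:
  ///
  ///     %c4   = arith.constant 4 : index
  ///     %rem  = arith.remsi %d0, %c4 : index
  ///     %c0   = arith.constant 0 : index
  ///     %neg  = arith.cmpi slt, %rem, %c0 : index
  ///     %corr = arith.addi %rem, %c4 : index
  ///     %res  = arith.select %neg, %corr, %rem : index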
  Value visitModExpr(AffineBinaryOpExpr expr) {
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      if (rhsConst.getValue() <= 0) {
        emitError(loc, "modulo by non-positive value is not supported");
        return nullptr;
      }
    }

    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    assert(lhs && rhs && "unexpected affine expr lowering failure");

    Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value isRemainderNegative = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, remainder, zeroCst);
    Value correctedRemainder =
        builder.create<arith::AddIOp>(loc, remainder, rhs);
    Value result = builder.create<arith::SelectOp>(
        loc, isRemainderNegative, correctedRemainder, remainder);
    return result;
  }

  /// Floor division operation (rounds towards negative infinity).
  ///
  /// For positive divisors, it can be implemented without branching and with a
  /// single division operation as
  ///
  ///     a floordiv b =
  ///         let negative = a < 0 in
  ///         let absolute = negative ? -a - 1 : a in
  ///         let quotient = absolute / b in
  ///             negative ? -quotient - 1 : quotient
  ///
  /// Note: this lowering does not use arith.floordivsi because the lowering of
  /// that to arith.divsi (see populateCeilFloorDivExpandOpsPatterns) generates
  /// not one but two arith.divsi. That could be changed to one divsi, but one
  /// way or another, going through arith.floordivsi will result in more
  /// complex IR because arith.floordivsi is more general than affine floordiv
  /// in that it supports negative RHS.
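  ///
  /// As a quick check of the formula: (-1) floordiv 4 takes the negative path
  /// with absolute = 0, quotient = 0, and result -0 - 1 = -1, while
  /// 5 floordiv 4 takes the other path and yields 1, both matching floor(a/b).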
  Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      if (rhsConst.getValue() <= 0) {
        emitError(loc, "division by non-positive value is not supported");
        return nullptr;
      }
    }
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    assert(lhs && rhs && "unexpected affine expr lowering failure");

    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
    Value negative = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, lhs, zeroCst);
    Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
    Value dividend =
        builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
    Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
    Value correctedQuotient =
        builder.create<arith::SubIOp>(loc, noneCst, quotient);
    Value result = builder.create<arith::SelectOp>(loc, negative,
                                                   correctedQuotient, quotient);
    return result;
  }

  /// Ceiling division operation (rounds towards positive infinity).
  ///
  /// For positive divisors, it can be implemented without branching and with a
  /// single division operation as
  ///
  ///     a ceildiv b =
  ///         let negative = a <= 0 in
  ///         let absolute = negative ? -a : a - 1 in
  ///         let quotient = absolute / b in
  ///             negative ? -quotient : quotient + 1
  ///
  /// Note: not using arith.ceildivsi for the same reason as explained in the
  /// visitFloorDivExpr comment.
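  ///
  /// As a quick check of the formula: 3 ceildiv 4 takes the positive path with
  /// absolute = 2, quotient = 0, and result 0 + 1 = 1, while (-3) ceildiv 4
  /// takes the non-positive path with absolute = 3, quotient = 0, and result
  /// -0 = 0, both matching ceil(a/b).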
  Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      if (rhsConst.getValue() <= 0) {
        emitError(loc, "division by non-positive value is not supported");
        return nullptr;
      }
    }
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    assert(lhs && rhs && "unexpected affine expr lowering failure");

    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
    Value nonPositive = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, lhs, zeroCst);
    Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
    Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
    Value dividend =
        builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
    Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
    Value negatedQuotient =
        builder.create<arith::SubIOp>(loc, zeroCst, quotient);
    Value incrementedQuotient =
        builder.create<arith::AddIOp>(loc, quotient, oneCst);
    Value result = builder.create<arith::SelectOp>(
        loc, nonPositive, negatedQuotient, incrementedQuotient);
    return result;
  }

  Value visitConstantExpr(AffineConstantExpr expr) {
    auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
    return op.getResult();
  }

  Value visitDimExpr(AffineDimExpr expr) {
    assert(expr.getPosition() < dimValues.size() &&
           "affine dim position out of range");
    return dimValues[expr.getPosition()];
  }

  Value visitSymbolExpr(AffineSymbolExpr expr) {
    assert(expr.getPosition() < symbolValues.size() &&
           "symbol dim position out of range");
    return symbolValues[expr.getPosition()];
  }

private:
  OpBuilder &builder;
  ValueRange dimValues;
  ValueRange symbolValues;

  Location loc;
};
} // namespace

/// Create a sequence of operations that implement the `expr` applied to the
/// given dimension and symbol values.
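/// For example (schematic; SSA names are illustrative), expanding the
/// expression `d0 + s0 * 2` with dim value %d0 and symbol value %s0 emits
/// roughly:
///   %c2  = arith.constant 2 : index
///   %mul = arith.muli %s0, %c2 : index
///   %sum = arith.addi %d0, %mul : index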
mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
                                           AffineExpr expr,
                                           ValueRange dimValues,
                                           ValueRange symbolValues) {
  return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
}

/// Create a sequence of operations that implement the `affineMap` applied to
/// the given `operands` (as if it were an AffineApplyOp).
std::optional<SmallVector<Value, 8>>
mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
                              AffineMap affineMap, ValueRange operands) {
  auto numDims = affineMap.getNumDims();
  auto expanded = llvm::to_vector<8>(
      llvm::map_range(affineMap.getResults(),
                      [numDims, &builder, loc, operands](AffineExpr expr) {
                        return expandAffineExpr(builder, loc, expr,
                                                operands.take_front(numDims),
                                                operands.drop_front(numDims));
                      }));
  if (llvm::all_of(expanded, [](Value v) { return v; }))
    return expanded;
  return std::nullopt;
}

/// Promotes the `then` or the `else` block of `ifOp` (depending on whether
/// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
/// the rest of the op.
static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
  if (elseBlock)
    assert(ifOp.hasElse() && "else block expected");

  Block *destBlock = ifOp->getBlock();
  Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
  destBlock->getOperations().splice(
      Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
      std::prev(srcBlock->end()));
  ifOp.erase();
}

/// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
/// on. The `ifOp` could be hoisted and placed right before such an operation.
/// This method assumes that the ifOp has been canonicalized (to be correct and
/// effective).
static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
  // Walk up the parents past all for ops this conditional is invariant on.
  auto ifOperands = ifOp.getOperands();
  auto *res = ifOp.getOperation();
  while (!isa<func::FuncOp>(res->getParentOp())) {
    auto *parentOp = res->getParentOp();
    if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
      if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
        break;
    } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
      for (auto iv : parallelOp.getIVs())
        if (llvm::is_contained(ifOperands, iv))
          break;
    } else if (!isa<AffineIfOp>(parentOp)) {
      // Won't walk up past anything other than affine.for/if ops.
      break;
    }
    // You can always hoist up past any affine.if ops.
    res = parentOp;
  }
  return res;
}

/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
/// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
/// otherwise the same `ifOp`.
static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
  // No hoisting to do.
  if (hoistOverOp == ifOp)
    return ifOp;

  // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
  // the else block. Then drop the else block of the original 'if' in the 'then'
  // branch while promoting its then block, and analogously drop the 'then'
  // block of the original 'if' from the 'else' branch while promoting its else
  // block.
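  //
  // Schematically (illustrative, assuming a single enclosing affine.for is
  // hoisted over):
  //
  //   affine.for %i ... {
  //     affine.if #set(...) { <then ops> } else { <else ops> }
  //   }
  //
  // becomes
  //
  //   affine.if #set(...) {
  //     affine.for %i ... { <then ops> }
  //   } else {
  //     affine.for %i ... { <else ops> }
  //   }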
  IRMapping operandMap;
  OpBuilder b(hoistOverOp);
  auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
                                          ifOp.getOperands(),
                                          /*elseBlock=*/true);

  // Create a clone of hoistOverOp to use for the else branch of the hoisted
  // conditional. The else block may get optimized away if empty.
  Operation *hoistOverOpClone = nullptr;
  // We use this unique name to identify/find `ifOp`'s clone in the else
  // version.
  StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
  operandMap.clear();
  b.setInsertionPointAfter(hoistOverOp);
  // We'll set an attribute to identify this op in a clone of this sub-tree.
  ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
  hoistOverOpClone = b.clone(*hoistOverOp, operandMap);

  // Promote the 'then' block of the original affine.if in the then version.
  promoteIfBlock(ifOp, /*elseBlock=*/false);

  // Move the then version to the hoisted if op's 'then' block.
  auto *thenBlock = hoistedIfOp.getThenBlock();
  thenBlock->getOperations().splice(thenBlock->begin(),
                                    hoistOverOp->getBlock()->getOperations(),
                                    Block::iterator(hoistOverOp));

  // Find the clone of the original affine.if op in the else version.
  AffineIfOp ifCloneInElse;
  hoistOverOpClone->walk([&](AffineIfOp ifClone) {
    if (!ifClone->getAttr(idForIfOp))
      return WalkResult::advance();
    ifCloneInElse = ifClone;
    return WalkResult::interrupt();
  });
  assert(ifCloneInElse && "if op clone should exist");
  // For the else block, promote the else block of the original 'if' if it had
  // one; otherwise, the op itself is to be erased.
  if (!ifCloneInElse.hasElse())
    ifCloneInElse.erase();
  else
    promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);

  // Move the else version into the else block of the hoisted if op.
  auto *elseBlock = hoistedIfOp.getElseBlock();
  elseBlock->getOperations().splice(
      elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
      Block::iterator(hoistOverOpClone));

  return hoistedIfOp;
}

LogicalResult
mlir::affine::affineParallelize(AffineForOp forOp,
                                ArrayRef<LoopReduction> parallelReductions,
                                AffineParallelOp *resOp) {
  // Fail early if there are iter arguments that are not reductions.
  unsigned numReductions = parallelReductions.size();
  if (numReductions != forOp.getNumIterOperands())
    return failure();

  Location loc = forOp.getLoc();
  OpBuilder outsideBuilder(forOp);
  AffineMap lowerBoundMap = forOp.getLowerBoundMap();
  ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
  AffineMap upperBoundMap = forOp.getUpperBoundMap();
  ValueRange upperBoundOperands = forOp.getUpperBoundOperands();

  // Create an empty 1-D affine.parallel op.
  auto reducedValues = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.value; }));
  auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.kind; }));
  AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
      loc, ValueRange(reducedValues).getTypes(), reductionKinds,
      llvm::ArrayRef(lowerBoundMap), lowerBoundOperands,
      llvm::ArrayRef(upperBoundMap), upperBoundOperands,
      llvm::ArrayRef(forOp.getStepAsInt()));
  // Steal the body of the old affine for op.
  newPloop.getRegion().takeBody(forOp.getRegion());
  Operation *yieldOp = &newPloop.getBody()->back();

  // Handle the initial values of reductions because the parallel loop always
  // starts from the neutral value.
  SmallVector<Value> newResults;
  newResults.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    Value init = forOp.getInits()[i];
    // This works because we are only handling single-op reductions at the
    // moment. A switch on reduction kind or a mechanism to collect operations
    // participating in the reduction will be necessary for multi-op reductions.
    Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
    assert(reductionOp && "yielded value is expected to be produced by an op");
    outsideBuilder.getInsertionBlock()->getOperations().splice(
        outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
        reductionOp);
    reductionOp->setOperands({init, newPloop->getResult(i)});
    forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
  }

  // Update the loop terminator to yield reduced values bypassing the reduction
  // operation itself (now moved outside of the loop) and erase the block
  // arguments that correspond to reductions. Note that the loop always has one
  // "main" induction variable when coming from a non-parallel for.
  unsigned numIVs = 1;
  yieldOp->setOperands(reducedValues);
  newPloop.getBody()->eraseArguments(numIVs, numReductions);

  forOp.erase();
  if (resOp)
    *resOp = newPloop;
  return success();
}

// Returns success if any hoisting happened.
LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
  // Bail out early if the ifOp returns a result. TODO: Consider how to
  // properly support this case.
  if (ifOp.getNumResults() != 0)
    return failure();

  // Apply canonicalization patterns and folding - this is necessary for the
  // hoisting check to be correct (operands should be composed), and to be more
  // effective (no unused operands). Since the pattern rewriter's folding is
  // entangled with application of patterns, we may fold/end up erasing the op,
  // in which case we return with `folded` being set.
  RewritePatternSet patterns(ifOp.getContext());
  AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
  GreedyRewriteConfig config;
  config.strictMode = GreedyRewriteStrictness::ExistingOps;
  bool erased;
  (void)applyOpPatternsAndFold(ifOp.getOperation(), frozenPatterns, config,
                               /*changed=*/nullptr, &erased);
  if (erased) {
    if (folded)
      *folded = true;
    return failure();
  }
  if (folded)
    *folded = false;

  // The folding above should have ensured this, but the affine.if's
  // canonicalization is missing composition of affine.applys into it.
  assert(llvm::all_of(ifOp.getOperands(),
                      [](Value v) {
                        return isTopLevelValue(v) || isAffineForInductionVar(v);
                      }) &&
         "operands not composed");

  // We are going to hoist as high as possible.
  // TODO: this could be customized in the future.
  auto *hoistOverOp = getOutermostInvariantForOp(ifOp);

  AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
  // Nothing to hoist over.
  if (hoistedIfOp == ifOp)
    return failure();

  // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
  // a sequence of affine.fors that are all perfectly nested).
  (void)applyPatternsAndFoldGreedily(
      hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
      frozenPatterns);

  return success();
}

// Return the min expr after replacing the given dim: `dim` is substituted
// with `min` along positive paths and with `max` along negated paths (the
// roles swap when `positivePath` is false).
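// For example (illustrative), substituting d0 with the range [lb, ub] in
// `d0 * -2 + 8` on the positive path yields `ub * -2 + 8`, since the negative
// coefficient flips min and max for that subexpression.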
AffineExpr mlir::affine::substWithMin(AffineExpr e, AffineExpr dim,
                                      AffineExpr min, AffineExpr max,
                                      bool positivePath) {
  if (e == dim)
    return positivePath ? min : max;
  if (auto bin = dyn_cast<AffineBinaryOpExpr>(e)) {
    AffineExpr lhs = bin.getLHS();
    AffineExpr rhs = bin.getRHS();
    if (bin.getKind() == mlir::AffineExprKind::Add)
      return substWithMin(lhs, dim, min, max, positivePath) +
             substWithMin(rhs, dim, min, max, positivePath);

    auto c1 = dyn_cast<AffineConstantExpr>(bin.getLHS());
    auto c2 = dyn_cast<AffineConstantExpr>(bin.getRHS());
    if (c1 && c1.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
    if (c2 && c2.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
    return getAffineBinaryOpExpr(
        bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
        substWithMin(rhs, dim, min, max, positivePath));
  }
  return e;
}

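/// For example (illustrative), a loop such as
///   affine.parallel (%i) = (2) to (18) step (4)
/// is rewritten below to roughly
///   affine.parallel (%i) = (0) to (4)
/// with uses of the original IV routed through
///   affine.apply affine_map<(d0) -> (d0 * 4 + 2)>(%i)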
void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
  // Loops with min/max in bounds are not normalized at the moment.
  if (op.hasMinMaxBounds())
    return;

  AffineMap lbMap = op.getLowerBoundsMap();
  SmallVector<int64_t, 8> steps = op.getSteps();
  // No need to do any work if the parallel op is already normalized.
  bool isAlreadyNormalized =
      llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
        int64_t step = std::get<0>(tuple);
        auto lbExpr = dyn_cast<AffineConstantExpr>(std::get<1>(tuple));
        return lbExpr && lbExpr.getValue() == 0 && step == 1;
      });
  if (isAlreadyNormalized)
    return;

  AffineValueMap ranges;
  AffineValueMap::difference(op.getUpperBoundsValueMap(),
                             op.getLowerBoundsValueMap(), &ranges);
  auto builder = OpBuilder::atBlockBegin(op.getBody());
  auto zeroExpr = builder.getAffineConstantExpr(0);
  SmallVector<AffineExpr, 8> lbExprs;
  SmallVector<AffineExpr, 8> ubExprs;
  for (unsigned i = 0, e = steps.size(); i < e; ++i) {
    int64_t step = steps[i];

    // Adjust the lower bound to be 0.
    lbExprs.push_back(zeroExpr);

    // Adjust the upper bound expression: 'range / step'.
    AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
    ubExprs.push_back(ubExpr);

    // Adjust the corresponding IV: 'lb + i * step'.
    BlockArgument iv = op.getBody()->getArgument(i);
    AffineExpr lbExpr = lbMap.getResult(i);
    unsigned nDims = lbMap.getNumDims();
    auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
    auto map = AffineMap::get(/*dimCount=*/nDims + 1,
                              /*symbolCount=*/lbMap.getNumSymbols(), expr);

    // Use an 'affine.apply' op that will be simplified later in subsequent
    // canonicalizations.
    OperandRange lbOperands = op.getLowerBoundsOperands();
    OperandRange dimOperands = lbOperands.take_front(nDims);
    OperandRange symbolOperands = lbOperands.drop_front(nDims);
    SmallVector<Value, 8> applyOperands{dimOperands};
    applyOperands.push_back(iv);
    applyOperands.append(symbolOperands.begin(), symbolOperands.end());
    auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
    iv.replaceAllUsesExcept(apply, apply);
  }

  SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
  op.setSteps(newSteps);
  auto newLowerMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
  op.setLowerBounds({}, newLowerMap);
  auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
                                    ubExprs, op.getContext());
  op.setUpperBounds(ranges.getOperands(), newUpperMap);
}

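/// For example (illustrative), a loop such as
///   affine.for %i = 5 to %N step 3
/// is rewritten below to roughly
///   affine.for %i = 0 to affine_map<()[s0] -> ((s0 - 5) ceildiv 3)>()[%N]
/// with the original IV recovered in the body as
///   affine.apply affine_map<(d0) -> (d0 * 3 + 5)>(%i)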
LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
                                               bool promoteSingleIter) {
  if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
    return success();

  // Check if the for op is already normalized.
  if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
      (op.getStep() == 1))
    return success();

  // Check if the lower bound has a single result only. Loops with a max lower
  // bound can't be normalized without additional support like
  // affine.execute_region's. If the lower bound does not have a single result
  // then skip this op.
  if (op.getLowerBoundMap().getNumResults() != 1)
    return failure();

  Location loc = op.getLoc();
  OpBuilder opBuilder(op);
  int64_t origLoopStep = op.getStepAsInt();

  // Construct the new upper bound value map.
  AffineMap oldLbMap = op.getLowerBoundMap();
  // The upper bound can have multiple results. To use
  // AffineValueMap::difference, we need to have the same number of results in
  // both lower and upper bound maps. So, we just create a value map for the
  // lower bound with the only available lower bound result repeated to pad up
  // to the number of upper bound results.
  SmallVector<AffineExpr> lbExprs(op.getUpperBoundMap().getNumResults(),
                                  op.getLowerBoundMap().getResult(0));
  AffineValueMap lbMap(oldLbMap, op.getLowerBoundOperands());
  AffineMap paddedLbMap =
      AffineMap::get(oldLbMap.getNumDims(), oldLbMap.getNumSymbols(), lbExprs,
                     op.getContext());
  AffineValueMap paddedLbValueMap(paddedLbMap, op.getLowerBoundOperands());
  AffineValueMap ubValueMap(op.getUpperBoundMap(), op.getUpperBoundOperands());
  AffineValueMap newUbValueMap;
  // Compute the `upper bound - lower bound`.
  AffineValueMap::difference(ubValueMap, paddedLbValueMap, &newUbValueMap);
  (void)newUbValueMap.canonicalize();

  // Scale down the upper bound value map by the loop step.
  unsigned numResult = newUbValueMap.getNumResults();
  SmallVector<AffineExpr> scaleDownExprs(numResult);
  for (unsigned i = 0; i < numResult; ++i)
    scaleDownExprs[i] = opBuilder.getAffineDimExpr(i).ceilDiv(origLoopStep);
  // `scaleDownMap` is (d0, d1, ..., d_n) -> (d0 / step, d1 / step, ..., d_n /
  // step), where `n` is the number of results in the upper bound map.
  AffineMap scaleDownMap =
      AffineMap::get(numResult, 0, scaleDownExprs, op.getContext());
  AffineMap newUbMap = scaleDownMap.compose(newUbValueMap.getAffineMap());

  // Set the newly created upper bound map and operands.
  op.setUpperBound(newUbValueMap.getOperands(), newUbMap);
  op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
  op.setStep(1);

  // Calculate the value of the old loop IV by constructing an affine.apply op
  // that maps the new IV to the old IV inside the normalized loop body.
  opBuilder.setInsertionPointToStart(op.getBody());
  AffineMap scaleIvMap =
      AffineMap::get(1, 0, -opBuilder.getAffineDimExpr(0) * origLoopStep);
  AffineValueMap scaleIvValueMap(scaleIvMap, ValueRange{op.getInductionVar()});
  AffineValueMap newIvToOldIvMap;
  AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap);
  (void)newIvToOldIvMap.canonicalize();
  auto newIV = opBuilder.create<AffineApplyOp>(
      loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands());
  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
  return success();
}

/// Returns true if the memory operation of `destAccess` depends on `srcAccess`
/// inside of the innermost common surrounding affine loop between the two
/// accesses.
static bool mustReachAtInnermost(const MemRefAccess &srcAccess,
                                 const MemRefAccess &destAccess) {
  // Affine dependence analysis is possible only if both ops are in the same
  // AffineScope.
  if (getAffineScope(srcAccess.opInst) != getAffineScope(destAccess.opInst))
    return false;

  unsigned nsLoops =
      getNumCommonSurroundingLoops(*srcAccess.opInst, *destAccess.opInst);
  DependenceResult result =
      checkMemrefAccessDependence(srcAccess, destAccess, nsLoops + 1);
  return hasDependence(result);
}

/// Returns true if `srcMemOp` may have an effect on `destMemOp` within the
/// scope of the outermost `minSurroundingLoops` loops that surround them.
/// `srcMemOp` and `destMemOp` are expected to be affine read/write ops.
static bool mayHaveEffect(Operation *srcMemOp, Operation *destMemOp,
                          unsigned minSurroundingLoops) {
  MemRefAccess srcAccess(srcMemOp);
  MemRefAccess destAccess(destMemOp);

  // Affine dependence analysis here is applicable only if both ops operate on
  // the same memref and if `srcMemOp` and `destMemOp` are in the same
  // AffineScope. Also, we can only check if our affine scope is isolated from
  // above; otherwise, values can come from outside of the affine scope that
  // the check below cannot analyze.
  Region *srcScope = getAffineScope(srcMemOp);
  if (srcAccess.memref == destAccess.memref &&
      srcScope == getAffineScope(destMemOp)) {
    unsigned nsLoops = getNumCommonSurroundingLoops(*srcMemOp, *destMemOp);
    FlatAffineValueConstraints dependenceConstraints;
    for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, destAccess, d, &dependenceConstraints,
          /*dependenceComponents=*/nullptr);
      // A dependence failure or the presence of a dependence implies a
      // side effect.
      if (!noDependence(result))
        return true;
    }
    // No side effect was seen.
    return false;
  }
  // TODO: Check here if the memrefs alias: there is no side effect if
  // `srcAccess.memref` and `destAccess.memref` don't alias.
  return true;
}

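/// Returns true if no operation that may have an effect of type `EffectType`
/// on the memref accessed by `memOp` appears on any path between `start` and
/// `memOp`. The traversal below is conservative: any op whose effect on that
/// memref cannot be ruled out causes a `false` result.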
template <typename EffectType, typename T>
bool mlir::affine::hasNoInterveningEffect(Operation *start, T memOp) {
  auto isLocallyAllocated = [](Value memref) {
    auto *defOp = memref.getDefiningOp();
    return defOp && hasSingleEffect<MemoryEffects::Allocate>(defOp, memref);
  };

  // A boolean representing whether an intervening operation could have
  // impacted memOp.
  bool hasSideEffect = false;

  // Check whether the effect on memOp can be caused by a given operation op.
  Value memref = memOp.getMemRef();
  std::function<void(Operation *)> checkOperation = [&](Operation *op) {
    // If the effect has already been found, early exit.
    if (hasSideEffect)
      return;

    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      memEffect.getEffects(effects);

      bool opMayHaveEffect = false;
      for (auto effect : effects) {
        // If op causes EffectType on a potentially aliasing location for
        // memOp, mark as having the effect.
        if (isa<EffectType>(effect.getEffect())) {
          // TODO: This should be replaced with a check for no aliasing.
          // Aliasing information should be passed to this method.
          if (effect.getValue() && effect.getValue() != memref &&
              isLocallyAllocated(memref) &&
              isLocallyAllocated(effect.getValue()))
            continue;
          opMayHaveEffect = true;
          break;
        }
      }

      if (!opMayHaveEffect)
        return;

      // If the side effect comes from an affine read or write, try to
      // prove the side effecting `op` cannot reach `memOp`.
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        // For ease, let's consider the case that `op` is a store and
        // we're looking for other potential stores that overwrite memory after
        // `start`, and before being read in `memOp`. In this case, we only
        // need to consider other potential stores with depth >
        // minSurroundingLoops since `start` would overwrite any store with a
        // smaller number of surrounding loops before.
        unsigned minSurroundingLoops =
            getNumCommonSurroundingLoops(*start, *memOp);
        if (mayHaveEffect(op, memOp, minSurroundingLoops))
          hasSideEffect = true;
        return;
      }

      // We have an op with a memory effect and we cannot prove if it
      // intervenes.
      hasSideEffect = true;
      return;
    }

    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
      // Recurse into the regions for this op and check whether the internal
      // operations may have the side effect `EffectType` on memOp.
      for (Region &region : op->getRegions())
        for (Block &block : region)
          for (Operation &op : block)
            checkOperation(&op);
      return;
    }

    // Otherwise, conservatively assume generic operations have the effect
    // on memOp.
    hasSideEffect = true;
  };

  // Check all paths from ancestor op `parent` to the operation `to` for the
  // effect. It is known that `to` must be contained within `parent`.
  auto until = [&](Operation *parent, Operation *to) {
    // TODO: check only the paths from `parent` to `to`.
    // Currently we fall back and check the entire parent op, rather than
    // just the paths from `parent` to `to`, stopping after reaching `to`.
    // This is conservatively correct, but could be made more aggressive.
    assert(parent->isAncestor(to));
    checkOperation(parent);
  };

  // Check for all paths from operation `from` to operation `untilOp` for the
  // given memory effect.
  std::function<void(Operation *, Operation *)> recur =
      [&](Operation *from, Operation *untilOp) {
        assert(
            from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
            "Checking for side effect between two operations without a common "
            "ancestor");

        // If the operations are in different regions, recursively consider all
        // paths from `from` to the parent of `to` and all paths from the
        // parent of `to` to `to`.
        if (from->getParentRegion() != untilOp->getParentRegion()) {
          recur(from, untilOp->getParentOp());
          until(untilOp->getParentOp(), untilOp);
          return;
        }

        // Now, assuming that `from` and `to` exist in the same region, perform
        // a CFG traversal to check all the relevant operations.

        // Additional blocks to consider.
        SmallVector<Block *, 2> todoBlocks;
        {
          // First consider the parent block of `from` and check all operations
          // after `from`.
          for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
               iter != end && &*iter != untilOp; ++iter) {
            checkOperation(&*iter);
          }

          // If the block of `from` doesn't also contain `to`, add the block's
          // successors to the list of blocks to check.
          if (untilOp->getBlock() != from->getBlock())
            for (Block *succ : from->getBlock()->getSuccessors())
              todoBlocks.push_back(succ);
        }

        SmallPtrSet<Block *, 4> done;
        // Traverse the CFG until hitting `to`.
        while (!todoBlocks.empty()) {
          Block *blk = todoBlocks.pop_back_val();
          if (done.count(blk))
            continue;
          done.insert(blk);
          for (auto &op : *blk) {
            if (&op == untilOp)
              break;
            checkOperation(&op);
            if (&op == blk->getTerminator())
              for (Block *succ : blk->getSuccessors())
                todoBlocks.push_back(succ);
          }
        }
      };
  recur(start, memOp);
  return !hasSideEffect;
}

/// Attempt to eliminate loadOp by replacing it with a value stored into memory
/// which the load is guaranteed to retrieve. This check involves three
/// components: 1) the store and load must be to the same location, 2) the
/// store must dominate (and therefore must always occur prior to) the load,
/// and 3) no other operation may overwrite the memory loaded from between the
/// given load and store. If such a value exists, the replaced `loadOp` will be
/// added to `loadOpsToErase` and its memref will be added to `memrefsToErase`.
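///
/// For example (illustrative), in
///   affine.store %v, %A[%i] : memref<100xf32>
///   %x = affine.load %A[%i] : memref<100xf32>
/// all uses of %x are replaced by %v, provided no other write to %A[%i] may
/// occur between the two operations.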
static void forwardStoreToLoad(AffineReadOpInterface loadOp,
                               SmallVectorImpl<Operation *> &loadOpsToErase,
                               SmallPtrSetImpl<Value> &memrefsToErase,
                               DominanceInfo &domInfo) {

  // The store op candidate for forwarding that satisfies all conditions
  // to replace the load, if any.
  Operation *lastWriteStoreOp = nullptr;

  for (auto *user : loadOp.getMemRef().getUsers()) {
    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
    if (!storeOp)
      continue;
    MemRefAccess srcAccess(storeOp);
    MemRefAccess destAccess(loadOp);

    // 1. Check if the store and the load have mathematically equivalent
    // affine access functions; this implies that they statically refer to the
    // same single memref element. As an example this filters out cases like:
    //     store %A[%i0 + 1]
    //     load %A[%i0]
    //     store %A[%M]
    //     load %A[%N]
    // Use the AffineValueMap difference based memref access equality checking.
    if (srcAccess != destAccess)
      continue;

    // 2. The store has to dominate the load op to be a candidate.
    if (!domInfo.dominates(storeOp, loadOp))
      continue;

    // 3. The store must reach the load. Access function equivalence only
    // guarantees this for accesses in the same block. The load could be in a
    // nested block that is unreachable.
    if (storeOp->getBlock() != loadOp->getBlock() &&
        !mustReachAtInnermost(srcAccess, destAccess))
      continue;

    // 4. Ensure there is no intermediate operation which could replace the
    // value in memory.
    if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp))
      continue;

    // We now have a candidate for forwarding.
    assert(lastWriteStoreOp == nullptr &&
           "multiple simultaneous replacement stores");
    lastWriteStoreOp = storeOp;
  }

  if (!lastWriteStoreOp)
    return;

  // Perform the actual store to load forwarding.
  Value storeVal =
      cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
  // Check that the two values have the same type. This is needed for affine
  // vector loads and stores.
  if (storeVal.getType() != loadOp.getValue().getType())
    return;
  loadOp.getValue().replaceAllUsesWith(storeVal);
  // Record the memref for a later sweep to optimize away.
  memrefsToErase.insert(loadOp.getMemRef());
  // Record this to erase later.
  loadOpsToErase.push_back(loadOp);
}

template bool
mlir::affine::hasNoInterveningEffect<mlir::MemoryEffects::Read,
                                     affine::AffineReadOpInterface>(
    mlir::Operation *, affine::AffineReadOpInterface);

// This attempts to find stores which have no impact on the final result.
// A store op `writeA` will be eliminated if there exists a store op `writeB`
// such that:
// 1) writeA and writeB have mathematically equivalent affine access functions.
// 2) writeB postdominates writeA.
// 3) There is no potential read between writeA and writeB.
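//
// For example (illustrative), in
//   affine.store %v0, %A[%i] : memref<100xf32>
//   affine.store %v1, %A[%i] : memref<100xf32>
// the first store is dead: the second store to the same location postdominates
// it with no intervening read, so it is queued in `opsToErase`.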
static void findUnusedStore(AffineWriteOpInterface writeA,
                            SmallVectorImpl<Operation *> &opsToErase,
                            PostDominanceInfo &postDominanceInfo) {

  for (Operation *user : writeA.getMemRef().getUsers()) {
    // Only consider writing operations.
    auto writeB = dyn_cast<AffineWriteOpInterface>(user);
    if (!writeB)
      continue;

    // The operations must be distinct.
    if (writeB == writeA)
      continue;

    // Both operations must lie in the same region.
    if (writeB->getParentRegion() != writeA->getParentRegion())
      continue;

    // Both operations must write to the same memory.
    MemRefAccess srcAccess(writeB);
    MemRefAccess destAccess(writeA);

    if (srcAccess != destAccess)
      continue;

    // writeB must postdominate writeA.
    if (!postDominanceInfo.postDominates(writeB, writeA))
      continue;

    // There cannot be an operation which reads from memory between
    // the two writes.
    if (!affine::hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
      continue;

    opsToErase.push_back(writeA);
    break;
  }
}

// The load to load forwarding / redundant load elimination is similar to the
// store to load forwarding.
// loadA will be replaced with loadB if:
// 1) loadA and loadB have mathematically equivalent affine access functions.
// 2) loadB dominates loadA.
// 3) There is no write between loadA and loadB.
static void loadCSE(AffineReadOpInterface loadA,
                    SmallVectorImpl<Operation *> &loadOpsToErase,
                    DominanceInfo &domInfo) {
  SmallVector<AffineReadOpInterface, 4> loadCandidates;
  for (auto *user : loadA.getMemRef().getUsers()) {
    auto loadB = dyn_cast<AffineReadOpInterface>(user);
    if (!loadB || loadB == loadA)
      continue;

    MemRefAccess srcAccess(loadB);
    MemRefAccess destAccess(loadA);

    // 1. The accesses should be to the same location.
    if (srcAccess != destAccess) {
      continue;
    }

    // 2. loadB should dominate loadA.
    if (!domInfo.dominates(loadB, loadA))
      continue;

    // 3. There should not be a write between loadA and loadB.
    if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
            loadB.getOperation(), loadA))
      continue;

    // Check that the two values have the same type. This is needed for affine
    // vector loads.
    if (loadB.getValue().getType() != loadA.getValue().getType())
      continue;

    loadCandidates.push_back(loadB);
  }

  // Of the legal load candidates, use the one that dominates all others
  // to minimize the subsequent need for load CSE.
  Value loadB;
  for (AffineReadOpInterface option : loadCandidates) {
    if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
          return depStore == option ||
                 domInfo.dominates(option.getOperation(),
                                   depStore.getOperation());
        })) {
      loadB = option.getValue();
      break;
    }
  }

  if (loadB) {
    loadA.getValue().replaceAllUsesWith(loadB);
    // Record this to erase later.
    loadOpsToErase.push_back(loadA);
  }
}

// The store to load forwarding and load CSE rely on three conditions:
//
// 1) store/load providing a replacement value and load being replaced need to
// have mathematically equivalent affine access functions (checked after full
// composition of load/store operands); this implies that they access the same
// single memref element for all iterations of the common surrounding loop,
//
// 2) the store/load op should dominate the load op,
//
// 3) no operation that may write to memory read by the load being replaced can
// occur after executing the instruction (load or store) providing the
// replacement value and before the load being replaced (thus potentially
// allowing overwriting the memory read by the load).
//
// The above conditions are simple to check and powerful for most cases in
// practice; they are sufficient, but not necessary, since they don't reason
// about loops that are guaranteed to execute at least once or about multiple
// sources to forward from.
//
// TODO: more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
// TODO: do general dead store elimination for memref's. This pass
// currently eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
                                       PostDominanceInfo &postDomInfo) {
  // Load op's whose results were replaced by those forwarded from stores.
  SmallVector<Operation *, 8> opsToErase;

  // A list of memref's that are potentially dead / could be eliminated.
  SmallPtrSet<Value, 4> memrefsToErase;

  // Walk all load's and perform store to load forwarding.
  f.walk([&](AffineReadOpInterface loadOp) {
    forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo);
  });
  for (auto *op : opsToErase)
    op->erase();
  opsToErase.clear();

  // Walk all store's and perform unused store elimination.
  f.walk([&](AffineWriteOpInterface storeOp) {
    findUnusedStore(storeOp, opsToErase, postDomInfo);
  });
  for (auto *op : opsToErase)
    op->erase();
  opsToErase.clear();

  // Check if the store fwd'ed memrefs are now left with only stores and
  // deallocs and can thus be completely deleted. Note: the canonicalize pass
  // should be able to do this as well, but we'll do it here since we collected
  // these anyway.
  for (auto memref : memrefsToErase) {
    // If the memref hasn't been locally alloc'ed, skip.
    Operation *defOp = memref.getDefiningOp();
    if (!defOp || !hasSingleEffect<MemoryEffects::Allocate>(defOp, memref))
      // TODO: if the memref was returned by a 'call' operation, we
      // could still erase it if the call had no side-effects.
      continue;
    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
          return !isa<AffineWriteOpInterface>(ownerOp) &&
                 !hasSingleEffect<MemoryEffects::Free>(ownerOp, memref);
        }))
      continue;

    // Erase all stores, the dealloc, and the alloc on the memref.
    for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
      user->erase();
    defOp->erase();
  }

  // To eliminate as many loads as possible, run load CSE after eliminating
  // stores. Otherwise, some stores are wrongly seen as having an intervening
  // effect.
  f.walk([&](AffineReadOpInterface loadOp) {
    loadCSE(loadOp, opsToErase, domInfo);
  });
  for (auto *op : opsToErase)
    op->erase();
}

// Perform the replacement in `op`.
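//
// For example (illustrative), with `indexRemap = (d0) -> (d0 floordiv 16,
// d0 mod 16)`, no extra indices, and no extra/symbol operands, an access
// `affine.load %old[%i]` on a `memref<64xf32>` would be rewritten to
// `affine.load %new[%i floordiv 16, %i mod 16]` on a `memref<4x16xf32>`.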
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, Operation *op,
    ArrayRef<Value> extraIndices, AffineMap indexRemap,
    ArrayRef<Value> extraOperands, ArrayRef<Value> symbolOperands,
    bool allowNonDereferencingOps) {
  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
  (void)oldMemRefRank; // unused in opt mode
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
         cast<MemRefType>(newMemRef.getType()).getElementType());

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If memref doesn't appear, nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO: extend it for this case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  // The following checks if op is dereferencing memref and performs the access
  // index rewrites.
  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
  if (!affMapAccInterface) {
    if (!allowNonDereferencingOps) {
      // Failure: memref used in a non-dereferencing context (potentially
      // escapes); no replacement in these cases unless
      // allowNonDereferencingOps is set.
      return failure();
    }
    op->setOperand(memRefOperandPos, newMemRef);
    return success();
  }
  // Perform index rewrites for the dereferencing op and then replace the op.
  NamedAttribute oldMapAttrPair =
      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value, 4> oldMemRefOperands;
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e.,
  // at position memRefOperandPos + 1.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());

  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (Value extraIndex : extraIndices) {
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create new fully composed AffineMap for new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (Value value : affineApplyOps)
    if (value.use_empty())
      value.getDefiningOp()->erase();

  OperationState state(op->getLoc(), op->getName());
  // Construct the new operation using this memref.
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memref's are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());

  // Add attribute for 'newMap', other Attributes do not change.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.getName() == oldMapAttrPair.getName())
      state.attributes.push_back({namedAttr.getName(), newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }

  // Create the new operation.
  auto *repOp = builder.create(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}

LogicalResult mlir::affine::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
    AffineMap indexRemap, ArrayRef<Value> extraOperands,
    ArrayRef<Value> symbolOperands, Operation *domOpFilter,
    Operation *postDomOpFilter, bool allowNonDereferencingOps,
    bool replaceInDeallocOp) {
  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
         cast<MemRefType>(newMemRef.getType()).getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domOpFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domOpFilter->getParentOfType<func::FuncOp>());

  if (postDomOpFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomOpFilter->getParentOfType<func::FuncOp>());

  // Walk all uses of old memref; collect ops to perform replacement. We use a
  // DenseSet since an operation could potentially have multiple uses of a
  // memref (although rare), and the replacement later is going to erase ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef.getUsers()) {
    // Skip this use if it's not dominated by domOpFilter.
    if (domOpFilter && !domInfo->dominates(domOpFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomOpFilter.
    if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
      continue;

    // Skip dealloc's - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these dealloc's.
    if (hasSingleEffect<MemoryEffects::Free>(op, oldMemRef) &&
        !replaceInDeallocOp)
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isa<AffineMapAccessInterface>(*op)) {
      if (!allowNonDereferencingOps) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Memref replacement failed: non-dereferencing memref op: \n"
                   << *op << '\n');
        return failure();
      }
      // Non-dereferencing ops with the MemRefsNormalizable trait are
      // supported for replacement.
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
        LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
                                   "memrefs normalizable trait: \n"
                                << *op << '\n');
        return failure();
      }
    }

    // We'll first collect and then replace --- since replacement erases the op
    // that has the use, and that op could be postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}

/// Given an operation, inserts one or more single result affine apply
/// operations, results of which are exclusively used by this operation.
/// The operands of these newly created affine apply ops are guaranteed to be
/// loop iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (for
/// eg. different shifts/delays).
///
/// Fills the output argument `sliceOps` with the affine.apply operations
/// created. Nothing is created either if none of opInst's operands were the
/// result of an affine.apply (and thus there was no affine computation slice
/// to create), or if all the affine.apply op's supplying operands to this
/// opInst did not have any uses other than this opInst.
1383void mlir::affine::createAffineComputationSlice(
1384 Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
1385 // Collect all operands that are results of affine apply ops.
1386 SmallVector<Value, 4> subOperands;
1387 subOperands.reserve(N: opInst->getNumOperands());
1388 for (auto operand : opInst->getOperands())
1389 if (isa_and_nonnull<AffineApplyOp>(Val: operand.getDefiningOp()))
1390 subOperands.push_back(Elt: operand);
1391
1392 // Gather sequence of AffineApplyOps reachable from 'subOperands'.
1393 SmallVector<Operation *, 4> affineApplyOps;
1394 getReachableAffineApplyOps(operands: subOperands, affineApplyOps);
1395 // Skip transforming if there are no affine maps to compose.
1396 if (affineApplyOps.empty())
1397 return;
1398
1399 // Check if all uses of the affine apply op's lie only in this op op, in
1400 // which case there would be nothing to do.
1401 bool localized = true;
1402 for (auto *op : affineApplyOps) {
1403 for (auto result : op->getResults()) {
1404 for (auto *user : result.getUsers()) {
1405 if (user != opInst) {
1406 localized = false;
1407 break;
1408 }
1409 }
1410 }
1411 }
1412 if (localized)
1413 return;
1414
1415 OpBuilder builder(opInst);
1416 SmallVector<Value, 4> composedOpOperands(subOperands);
1417 auto composedMap = builder.getMultiDimIdentityMap(rank: composedOpOperands.size());
1418 fullyComposeAffineMapAndOperands(map: &composedMap, operands: &composedOpOperands);
1419
1420 // Create an affine.apply for each of the map results.
1421 sliceOps->reserve(composedMap.getNumResults());
1422 for (auto resultExpr : composedMap.getResults()) {
1423 auto singleResMap = AffineMap::get(dimCount: composedMap.getNumDims(),
1424 symbolCount: composedMap.getNumSymbols(), result: resultExpr);
1425 sliceOps->push_back(builder.create<AffineApplyOp>(
1426 opInst->getLoc(), singleResMap, composedOpOperands));
1427 }
1428
1429 // Construct the new operands that include the results from the composed
1430 // affine apply op above instead of existing ones (subOperands). So, they
1431 // differ from opInst's operands only for those operands in 'subOperands', for
1432 // which they will be replaced by the corresponding one from 'sliceOps'.
1433 SmallVector<Value, 4> newOperands(opInst->getOperands());
1434 for (Value &operand : newOperands) {
1435 // Replace the subOperands from among the new operands.
1436 unsigned j, f;
1437 for (j = 0, f = subOperands.size(); j < f; j++) {
1438 if (operand == subOperands[j])
1439 break;
1440 }
1441 if (j < subOperands.size())
1442 operand = (*sliceOps)[j];
1443 }
1444 for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
1445 opInst->setOperand(idx, value: newOperands[idx]);
1446}
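// Illustrative usage (a sketch; `dmaStartOp` is a stand-in for any operation,
// such as the "send" in the example above, whose index operands are produced
// by affine.apply ops):
//
//   SmallVector<AffineApplyOp, 4> sliceOps;
//   createAffineComputationSlice(dmaStartOp, &sliceOps);
//   if (!sliceOps.empty()) {
//     // The affine.apply ops in `sliceOps` feed only `dmaStartOp`, so they
//     // can now be shifted/delayed independently of other users.
//   }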

/// Enum to set patterns of affine expr in tiled-layout map.
/// TileFloorDiv: <dim expr> div <tile size>
/// TileMod: <dim expr> mod <tile size>
/// TileNone: None of the above
/// Example:
/// #tiled_2d_128x256 = affine_map<(d0, d1)
///                 -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
/// "d0 div 128" and "d1 div 256" ==> TileFloorDiv
/// "d0 mod 128" and "d1 mod 256" ==> TileMod
enum TileExprPattern { TileFloorDiv, TileMod, TileNone };

/// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
/// being floordiv'ed by respective tile sizes appear in a mod with the same
/// tile sizes, and no other expression involves those k dimensions. This
/// function stores a vector of tuples (`tileSizePos`) including the AffineExpr
/// for the tile size and the positions of the corresponding `floordiv` and
/// `mod`. If `map` is not a tiled layout, an empty vector is returned.
static LogicalResult getTileSizePos(
    AffineMap map,
    SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
  // Create `floordivExprs` which is a vector of tuples including LHS and RHS of
  // `floordiv` and its position in `map` output.
  // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
  //                -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
  // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
  SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
  unsigned pos = 0;
  for (AffineExpr expr : map.getResults()) {
    if (expr.getKind() == AffineExprKind::FloorDiv) {
      AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
      if (isa<AffineConstantExpr>(binaryExpr.getRHS()))
        floordivExprs.emplace_back(
            std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
    }
    pos++;
  }
  // Not a tiled layout if `floordivExprs` is empty.
  if (floordivExprs.empty()) {
    tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
    return success();
  }

  // Check if the LHS of each `floordiv` is used in the LHS of a `mod`. If not,
  // `map` is not a tiled layout.
  for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
    AffineExpr floordivExprLHS = std::get<0>(fexpr);
    AffineExpr floordivExprRHS = std::get<1>(fexpr);
    unsigned floordivPos = std::get<2>(fexpr);

    // Walk each affine expr of `map` output except `fexpr`, and check if the
    // LHS and RHS of `fexpr` are used in the LHS and RHS of a `mod`. If the LHS
    // of `fexpr` is used in another expr, the map is not a tiled layout.
    // Examples of non-tiled layouts:
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod
    //   256)>
    bool found = false;
    pos = 0;
    for (AffineExpr expr : map.getResults()) {
      bool notTiled = false;
      if (pos != floordivPos) {
        expr.walk([&](AffineExpr e) {
          if (e == floordivExprLHS) {
            if (expr.getKind() == AffineExprKind::Mod) {
              AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
              // If LHS and RHS of `mod` are the same as those of `floordiv`.
              if (floordivExprLHS == binaryExpr.getLHS() &&
                  floordivExprRHS == binaryExpr.getRHS()) {
                // Save tile size (RHS of `mod`), and position of `floordiv` and
                // `mod` if the same expr with `mod` is not found yet.
                if (!found) {
                  tileSizePos.emplace_back(
                      std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
                  found = true;
                } else {
                  // Non-tiled layout: multiple `mod` with the same LHS.
                  // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                  // d2 mod 256, d2 mod 256)>
                  notTiled = true;
                }
              } else {
                // Non-tiled layout: RHS of `mod` is different from `floordiv`.
                // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
                // d2 mod 128)>
                notTiled = true;
              }
            } else {
              // Non-tiled layout: LHS is the same, but not `mod`.
              // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
              // d2 floordiv 256)>
              notTiled = true;
            }
          }
        });
      }
      if (notTiled) {
        tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
        return success();
      }
      pos++;
    }
  }
  return success();
}
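// Worked example of the contract documented above: for the tiled map
//   #tiled_2d_128x256 = affine_map<(d0, d1)
//                   -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
// getTileSizePos() fills `tileSizePos` with two tuples:
//   {128, /*floordiv pos=*/0, /*mod pos=*/2} and
//   {256, /*floordiv pos=*/1, /*mod pos=*/3}.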

/// Check if the `dim` dimension of memrefType with `layoutMap` becomes dynamic
/// after normalization. Dimensions that include dynamic dimensions in the map
/// output will become dynamic dimensions. Return true if `dim` is a dynamic
/// dimension.
///
/// Example:
/// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///
/// If d1 is a dynamic dimension, the 2nd and 3rd dimensions of the map output
/// are dynamic.
/// memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
  AffineExpr expr = layoutMap.getResults()[dim];
  // Check if the affine expr of the dimension includes a dynamic dimension of
  // the input memrefType.
  MLIRContext *context = layoutMap.getContext();
  return expr
      .walk([&](AffineExpr e) {
        if (isa<AffineDimExpr>(e) &&
            llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
              return e == getAffineDimExpr(dim, context);
            }))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}

/// Create an affine expr to calculate the dimension size for a tiled-layout
/// map.
static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
                                                  TileExprPattern pat) {
  // Create map output for the patterns.
  // "floordiv <tile size>" ==> "ceildiv <tile size>"
  // "mod <tile size>" ==> "<tile size>"
  AffineExpr newMapOutput;
  AffineBinaryOpExpr binaryExpr = nullptr;
  switch (pat) {
  case TileExprPattern::TileMod:
    binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
    newMapOutput = binaryExpr.getRHS();
    break;
  case TileExprPattern::TileFloorDiv:
    binaryExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
    newMapOutput = getAffineBinaryOpExpr(
        AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
    break;
  default:
    newMapOutput = oldMapOutput;
  }
  return newMapOutput;
}
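// For example, with the tiled layout #map0 above, the per-dimension size
// expressions produced are:
//   "d0"             (TileNone)     ==>  "d0"
//   "d1 floordiv 32" (TileFloorDiv) ==>  "d1 ceildiv 32"
//   "d1 mod 32"      (TileMod)      ==>  "32"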

/// Create new maps to calculate each dimension size of `newMemRefType`, and
/// create `newDynamicSizes` from them by using AffineApplyOp.
///
/// Steps for normalizing dynamic memrefs for a tiled layout map
/// Example:
///    #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///    %0 = dim %arg0, %c1 : memref<4x?xf32>
///    %1 = alloc(%0) : memref<4x?xf32, #map0>
///
/// (Before this function)
/// 1. Check if `map`(#map0) is a tiled layout using `getTileSizePos()`. Only a
///    single layout map is supported.
///
/// 2. Create the normalized memrefType using `isNormalizedMemRefDynamicDim()`.
///    It is memref<4x?x?xf32> in the above example.
///
/// (In this function)
/// 3. Create new maps to calculate each dimension of the normalized memrefType
///    using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
///    dimension size can be calculated by replacing "floordiv <tile size>" with
///    "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
///    - New maps in the above example
///      #map0 = affine_map<(d0, d1) -> (d0)>
///      #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
///      #map2 = affine_map<(d0, d1) -> (32)>
///
/// 4. Create AffineApplyOps to apply the new maps. The outputs of the
///    AffineApplyOps are used in dynamicSizes of the new AllocOp.
///      %0 = dim %arg0, %c1 : memref<4x?xf32>
///      %c4 = arith.constant 4 : index
///      %1 = affine.apply #map1(%c4, %0)
///      %2 = affine.apply #map2(%c4, %0)
static void createNewDynamicSizes(MemRefType oldMemRefType,
                                  MemRefType newMemRefType, AffineMap map,
                                  memref::AllocOp *allocOp, OpBuilder b,
                                  SmallVectorImpl<Value> &newDynamicSizes) {
  // Create new inputs for the AffineApplyOps.
  SmallVector<Value, 4> inAffineApply;
  ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
  unsigned dynIdx = 0;
  for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
    if (oldMemRefShape[d] < 0) {
      // Use dynamicSizes of allocOp for dynamic dimension.
      inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx]);
      dynIdx++;
    } else {
      // Create ConstantOp for static dimension.
      auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
      inAffineApply.emplace_back(
          b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
    }
  }

  // Create a new map to calculate each dimension size of the new memref for
  // each original map output. Only for dynamic dimensions of `newMemRefType`.
  unsigned newDimIdx = 0;
  ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(map, tileSizePos);
  for (AffineExpr expr : map.getResults()) {
    if (newMemRefShape[newDimIdx] < 0) {
      // Create new maps to calculate each dimension size of the new memref.
      enum TileExprPattern pat = TileExprPattern::TileNone;
      for (auto pos : tileSizePos) {
        if (newDimIdx == std::get<1>(pos))
          pat = TileExprPattern::TileFloorDiv;
        else if (newDimIdx == std::get<2>(pos))
          pat = TileExprPattern::TileMod;
      }
      AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
      AffineMap newMap =
          AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
      Value affineApp =
          b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
      newDynamicSizes.emplace_back(affineApp);
    }
    newDimIdx++;
  }
}

// TODO: Currently works for static memrefs with a single layout map.
LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
  MemRefType memrefType = allocOp->getType();
  OpBuilder b(*allocOp);

  // Fetch a new memref type after normalizing the old memref to have an
  // identity map layout.
  MemRefType newMemRefType = normalizeMemRefType(memrefType);
  if (newMemRefType == memrefType)
    // Either memrefType already had an identity map or the map couldn't be
    // transformed to an identity map.
    return failure();

  Value oldMemRef = allocOp->getResult();

  SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
  memref::AllocOp newAlloc;
  // Check if `layoutMap` is a tiled layout. Only a single layout map is
  // supported for normalizing dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
    MemRefType oldMemRefType = cast<MemRefType>(oldMemRef.getType());
    SmallVector<Value, 4> newDynamicSizes;
    createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
                          newDynamicSizes);
    // Add the new dynamic sizes in the new AllocOp.
    newAlloc =
        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                  newDynamicSizes, allocOp->getAlignmentAttr());
  } else {
    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                         allocOp->getAlignmentAttr());
  }
  // Replace all uses of the old memref.
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domOpFilter=*/nullptr,
                                      /*postDomOpFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // If it failed (due to escapes for example), bail out.
    newAlloc.erase();
    return failure();
  }
  // Replace any uses of the original alloc op and erase it. All remaining uses
  // have to be dealloc's; RAMUW above would've failed otherwise.
  assert(llvm::all_of(oldMemRef.getUsers(), [&](Operation *op) {
    return hasSingleEffect<MemoryEffects::Free>(op, oldMemRef);
  }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp->erase();
  return success();
}
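// End-to-end sketch of what normalizeMemRef() does for the tiled layout used
// in the examples above (illustrative IR only; value names are made up, and
// the rewritten access maps are shown in simplified form):
//
//   #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
//   %A = memref.alloc() : memref<64x128xf32, #map0>
//   ... = affine.load %A[%i, %j] : memref<64x128xf32, #map0>
//
// becomes
//
//   %A = memref.alloc() : memref<64x4x32xf32>
//   ... = affine.load %A[%i, %j floordiv 32, %j mod 32] : memref<64x4x32xf32>
//
// i.e. the layout map is folded into the allocated shape and into every
// dereferencing use via replaceAllMemRefUsesWith() with `indexRemap` = #map0.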

MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return memrefType;

  if (memrefType.getLayout().isIdentity()) {
    // Either no map is associated with this memref or this memref has
    // a trivial (identity) map.
    return memrefType;
  }
  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
  unsigned numSymbolicOperands = layoutMap.getNumSymbols();

  // We don't do any checks for one-to-one'ness; we assume that it is
  // one-to-one.

  // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
  // for now.
  // TODO: Normalize the other types of dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
    return memrefType;

  // We have a single map that is not an identity map. Create a new memref
  // with the right shape and an identity layout map.
  ArrayRef<int64_t> shape = memrefType.getShape();
  // FlatAffineValueConstraints may later on use symbolicOperands.
  FlatAffineValueConstraints fac(rank, numSymbolicOperands);
  SmallVector<unsigned, 4> memrefTypeDynDims;
  for (unsigned d = 0; d < rank; ++d) {
    // Use the constraint system only for static dimensions.
    if (shape[d] > 0) {
      fac.addBound(BoundType::LB, d, 0);
      fac.addBound(BoundType::UB, d, shape[d] - 1);
    } else {
      memrefTypeDynDims.emplace_back(d);
    }
  }
  // We compose this map with the original index (logical) space to derive
  // the upper bounds for the new index space.
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    return memrefType;
  // TODO: Handle semi-affine maps.
  // Project out the old data dimensions.
  fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
  SmallVector<int64_t, 4> newShape(newRank);
  MLIRContext *context = memrefType.getContext();
  for (unsigned d = 0; d < newRank; ++d) {
    // Check if this dimension is dynamic.
    if (isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
      newShape[d] = ShapedType::kDynamic;
      continue;
    }
    // The lower bound for the shape is always zero.
    std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
    // For a static memref and an affine map with no symbols, this is
    // always bounded. However, when we have symbols, we may not be able to
    // obtain a constant upper bound. Also, mapping to a negative space is
    // invalid for normalization.
    if (!ubConst.has_value() || *ubConst < 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "can't normalize map due to unknown/invalid upper bound");
      return memrefType;
    }
    // Dynamic dimensions were set to ShapedType::kDynamic above; for static
    // dimensions, the size is the inclusive upper bound plus one.
    newShape[d] = *ubConst + 1;
  }

  // Create the new memref type after trivializing the old layout map.
  auto newMemRefType =
      MemRefType::Builder(memrefType)
          .setShape(newShape)
          .setLayout(AffineMapAttr::get(
              AffineMap::getMultiDimIdentityMap(newRank, context)));
  return newMemRefType;
}
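// Worked example of the bound computation above (static case, no symbols): for
//   memref<64x128xf32, affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>>
// the composed constraint system gives
//   0 <= d0 <= 63             ==> new dim 0 has size 64
//   0 <= d1 floordiv 32 <= 3  ==> new dim 1 has size 4
//   0 <= d1 mod 32 <= 31      ==> new dim 2 has size 32
// so normalizeMemRefType() returns memref<64x4x32xf32> with an identity layout.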

DivModValue mlir::affine::getDivMod(OpBuilder &b, Location loc, Value lhs,
                                    Value rhs) {
  DivModValue result;
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  result.quotient =
      affine::makeComposedAffineApply(b, loc, d0.floorDiv(d1), {lhs, rhs});
  result.remainder =
      affine::makeComposedAffineApply(b, loc, d0 % d1, {lhs, rhs});
  return result;
}
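// Roughly, for index values %a and %b this materializes something like the
// following (the expressions are semi-affine since the divisor is a dim, and
// makeComposedAffineApply may further compose and fold them):
//   %q = affine.apply affine_map<(d0, d1) -> (d0 floordiv d1)>(%a, %b)
//   %r = affine.apply affine_map<(d0, d1) -> (d0 mod d1)>(%a, %b)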

/// Create IR that computes the product of all elements in the set.
static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
                                               ArrayRef<Value> set) {
  if (set.empty())
    return failure();
  OpFoldResult result = set[0];
  AffineExpr s0, s1;
  bindSymbols(b.getContext(), s0, s1);
  for (unsigned i = 1, e = set.size(); i < e; i++)
    result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
  return result;
}
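// E.g. getIndexProduct(b, loc, {%x, %y, %z}) folds the chain ((x * y) * z)
// into composed affine.apply ops (or into an attribute when all operands are
// constants) and returns the final product as an OpFoldResult.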

FailureOr<SmallVector<Value>>
mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                               ArrayRef<Value> basis) {
  unsigned numDims = basis.size();

  SmallVector<Value> divisors;
  for (unsigned i = 1; i < numDims; i++) {
    ArrayRef<Value> slice = basis.drop_front(i);
    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
    if (failed(prod))
      return failure();
    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
  }

  SmallVector<Value> results;
  results.reserve(divisors.size() + 1);
  Value residual = linearIndex;
  for (Value divisor : divisors) {
    DivModValue divMod = getDivMod(b, loc, residual, divisor);
    results.push_back(divMod.quotient);
    residual = divMod.remainder;
  }
  results.push_back(residual);
  return results;
}
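// For a basis (%b0, %b1, %b2), the divisors computed above are (b1 * b2) and
// b2, so a linear index %l is (conceptually) delinearized into
//   i0 = l floordiv (b1 * b2)
//   i1 = (l mod (b1 * b2)) floordiv b2
//   i2 = (l mod (b1 * b2)) mod b2
// Note that %b0 only bounds i0 and does not appear in the computation itself.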

OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                          ArrayRef<OpFoldResult> basis,
                                          ImplicitLocOpBuilder &builder) {
  assert(multiIndex.size() == basis.size());
  SmallVector<AffineExpr> basisAffine;
  for (size_t i = 0; i < basis.size(); ++i) {
    basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
  }

  SmallVector<AffineExpr> stridesAffine = computeStrides(basisAffine);
  SmallVector<OpFoldResult> strides;
  strides.reserve(stridesAffine.size());
  llvm::transform(stridesAffine, std::back_inserter(strides),
                  [&builder, &basis](AffineExpr strideExpr) {
                    return affine::makeComposedFoldedAffineApply(
                        builder, builder.getLoc(), strideExpr, basis);
                  });

  auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
  return affine::makeComposedFoldedAffineApply(
      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
}
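// Conceptually the inverse of delinearizeIndex: for multiIndex (i0, i1, i2)
// and basis (b0, b1, b2), computeStrides() yields the suffix products
// (b1 * b2, b2, 1), and the returned value folds to
//   i0 * b1 * b2 + i1 * b2 + i2.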
