1//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous transformation utilities for the Affine
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/Utils.h"
15
16#include "mlir/Dialect/Affine/Analysis/Utils.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19#include "mlir/Dialect/Affine/LoopUtils.h"
20#include "mlir/Dialect/Arith/Utils/Utils.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/Utils/IndexingUtils.h"
24#include "mlir/IR/AffineExprVisitor.h"
25#include "mlir/IR/Dominance.h"
26#include "mlir/IR/IRMapping.h"
27#include "mlir/IR/ImplicitLocOpBuilder.h"
28#include "mlir/IR/IntegerSet.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "llvm/Support/LogicalResult.h"
31#include <optional>
32
33#define DEBUG_TYPE "affine-utils"
34
35using namespace mlir;
36using namespace affine;
37using namespace presburger;
38
39namespace {
/// Visit affine expressions recursively and build the sequence of operations
/// that corresponds to them. Visitation functions return a Value for the
/// expression subtree they visited or `nullptr` on error.
43class AffineApplyExpander
44 : public AffineExprVisitor<AffineApplyExpander, Value> {
45public:
  /// This internal class expects arguments to be non-null; checks must be
  /// performed at the call site.
48 AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
49 ValueRange symbolValues, Location loc)
50 : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
51 loc(loc) {}
52
53 template <typename OpTy>
54 Value buildBinaryExpr(AffineBinaryOpExpr expr,
55 arith::IntegerOverflowFlags overflowFlags =
56 arith::IntegerOverflowFlags::none) {
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
59 if (!lhs || !rhs)
60 return nullptr;
61 auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags);
62 return op.getResult();
63 }
64
65 Value visitAddExpr(AffineBinaryOpExpr expr) {
66 return buildBinaryExpr<arith::AddIOp>(expr);
67 }
68
69 Value visitMulExpr(AffineBinaryOpExpr expr) {
70 return buildBinaryExpr<arith::MulIOp>(expr,
71 arith::IntegerOverflowFlags::nsw);
72 }
73
74 /// Euclidean modulo operation: negative RHS is not allowed.
  /// Remainder of the Euclidean integer division is always non-negative.
76 ///
77 /// Implemented as
78 ///
79 /// a mod b =
80 /// let remainder = srem a, b;
81 /// negative = a < 0 in
82 /// select negative, remainder + b, remainder.
83 Value visitModExpr(AffineBinaryOpExpr expr) {
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      if (rhsConst.getValue() <= 0) {
        emitError(loc, "modulo by non-positive value is not supported");
        return nullptr;
      }
    }

    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    assert(lhs && rhs && "unexpected affine expr lowering failure");

    Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
97 Value isRemainderNegative = builder.create<arith::CmpIOp>(
98 loc, arith::CmpIPredicate::slt, remainder, zeroCst);
99 Value correctedRemainder =
100 builder.create<arith::AddIOp>(loc, remainder, rhs);
101 Value result = builder.create<arith::SelectOp>(
102 loc, isRemainderNegative, correctedRemainder, remainder);
103 return result;
104 }
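
  // For illustration (SSA names are schematic, not produced verbatim by this
  // expander): `d0 mod 4` applied to %arg0 expands along the scheme above into
  //
  //   %c4  = arith.constant 4 : index
  //   %rem = arith.remsi %arg0, %c4 : index
  //   %c0  = arith.constant 0 : index
  //   %neg = arith.cmpi slt, %rem, %c0 : index
  //   %fix = arith.addi %rem, %c4 : index
  //   %res = arith.select %neg, %fix, %rem : index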
105
106 /// Floor division operation (rounds towards negative infinity).
107 ///
108 /// For positive divisors, it can be implemented without branching and with a
109 /// single division operation as
110 ///
111 /// a floordiv b =
112 /// let negative = a < 0 in
113 /// let absolute = negative ? -a - 1 : a in
114 /// let quotient = absolute / b in
115 /// negative ? -quotient - 1 : quotient
116 ///
117 /// Note: this lowering does not use arith.floordivsi because the lowering of
118 /// that to arith.divsi (see populateCeilFloorDivExpandOpsPatterns) generates
119 /// not one but two arith.divsi. That could be changed to one divsi, but one
120 /// way or another, going through arith.floordivsi will result in more complex
121 /// IR because arith.floordivsi is more general than affine floordiv in that
122 /// it supports negative RHS.
123 Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      if (rhsConst.getValue() <= 0) {
        emitError(loc, "division by non-positive value is not supported");
        return nullptr;
      }
    }
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    assert(lhs && rhs && "unexpected affine expr lowering failure");

    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
136 Value negative = builder.create<arith::CmpIOp>(
137 loc, arith::CmpIPredicate::slt, lhs, zeroCst);
138 Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
139 Value dividend =
140 builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
141 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
142 Value correctedQuotient =
143 builder.create<arith::SubIOp>(loc, noneCst, quotient);
144 Value result = builder.create<arith::SelectOp>(loc, negative,
145 correctedQuotient, quotient);
146 return result;
147 }
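
  // For illustration (SSA names are schematic): `d0 floordiv 4` applied to
  // %arg0 expands, following the scheme above, into roughly
  //
  //   %c4   = arith.constant 4 : index
  //   %c0   = arith.constant 0 : index
  //   %cm1  = arith.constant -1 : index
  //   %neg  = arith.cmpi slt, %arg0, %c0 : index
  //   %nd   = arith.subi %cm1, %arg0 : index       // -a - 1
  //   %div  = arith.select %neg, %nd, %arg0 : index
  //   %quot = arith.divsi %div, %c4 : index
  //   %corr = arith.subi %cm1, %quot : index       // -quotient - 1
  //   %res  = arith.select %neg, %corr, %quot : index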
148
149 /// Ceiling division operation (rounds towards positive infinity).
150 ///
151 /// For positive divisors, it can be implemented without branching and with a
152 /// single division operation as
153 ///
154 /// a ceildiv b =
155 /// let negative = a <= 0 in
156 /// let absolute = negative ? -a : a - 1 in
157 /// let quotient = absolute / b in
158 /// negative ? -quotient : quotient + 1
159 ///
160 /// Note: not using arith.ceildivsi for the same reason as explained in the
161 /// visitFloorDivExpr comment.
162 Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
    if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
      if (rhsConst.getValue() <= 0) {
        emitError(loc, "division by non-positive value is not supported");
        return nullptr;
      }
    }
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    assert(lhs && rhs && "unexpected affine expr lowering failure");

    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
175 Value nonPositive = builder.create<arith::CmpIOp>(
176 loc, arith::CmpIPredicate::sle, lhs, zeroCst);
177 Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
178 Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
179 Value dividend =
180 builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
181 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
182 Value negatedQuotient =
183 builder.create<arith::SubIOp>(loc, zeroCst, quotient);
184 Value incrementedQuotient =
185 builder.create<arith::AddIOp>(loc, quotient, oneCst);
186 Value result = builder.create<arith::SelectOp>(
187 loc, nonPositive, negatedQuotient, incrementedQuotient);
188 return result;
189 }
190
191 Value visitConstantExpr(AffineConstantExpr expr) {
    auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
193 return op.getResult();
194 }
195
196 Value visitDimExpr(AffineDimExpr expr) {
197 assert(expr.getPosition() < dimValues.size() &&
198 "affine dim position out of range");
199 return dimValues[expr.getPosition()];
200 }
201
202 Value visitSymbolExpr(AffineSymbolExpr expr) {
    assert(expr.getPosition() < symbolValues.size() &&
           "affine symbol position out of range");
205 return symbolValues[expr.getPosition()];
206 }
207
208private:
209 OpBuilder &builder;
210 ValueRange dimValues;
211 ValueRange symbolValues;
212
213 Location loc;
214};
215} // namespace
216
217/// Create a sequence of operations that implement the `expr` applied to the
218/// given dimension and symbol values.
219mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
220 AffineExpr expr,
221 ValueRange dimValues,
222 ValueRange symbolValues) {
223 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
224}
225
226/// Create a sequence of operations that implement the `affineMap` applied to
/// the given `operands` (as if it were an AffineApplyOp).
228std::optional<SmallVector<Value, 8>>
229mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
230 AffineMap affineMap, ValueRange operands) {
231 auto numDims = affineMap.getNumDims();
  auto expanded = llvm::to_vector<8>(
      llvm::map_range(affineMap.getResults(),
                      [numDims, &builder, loc, operands](AffineExpr expr) {
                        return expandAffineExpr(builder, loc, expr,
                                                operands.take_front(numDims),
                                                operands.drop_front(numDims));
                      }));
  if (llvm::all_of(expanded, [](Value v) { return v; }))
240 return expanded;
241 return std::nullopt;
242}
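
// For illustration, a hypothetical caller (the op, map, and operand names
// below are illustrative, not part of this file) could expand a map and use
// the per-result Values as follows:
//
//   OpBuilder b(someOp);
//   if (std::optional<SmallVector<Value, 8>> expanded =
//           expandAffineMap(b, someOp->getLoc(), map, mapOperands))
//     for (Value v : *expanded)
//       ...; // One Value per map result, materialized with arith ops.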
243
244/// Promotes the `then` or the `else` block of `ifOp` (depending on whether
245/// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
246/// the rest of the op.
247static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
248 if (elseBlock)
249 assert(ifOp.hasElse() && "else block expected");
250
251 Block *destBlock = ifOp->getBlock();
252 Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
  destBlock->getOperations().splice(
      Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
      std::prev(srcBlock->end()));
256 ifOp.erase();
257}
258
259/// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
260/// on. The `ifOp` could be hoisted and placed right before such an operation.
261/// This method assumes that the ifOp has been canonicalized (to be correct and
262/// effective).
263static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
  // Walk up the parents past all for ops that this conditional is invariant
  // on.
265 auto ifOperands = ifOp.getOperands();
266 Operation *res = ifOp;
267 while (!res->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
268 auto *parentOp = res->getParentOp();
269 if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
270 if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
271 break;
272 } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
273 if (llvm::any_of(parallelOp.getIVs(), [&](Value iv) {
274 return llvm::is_contained(ifOperands, iv);
275 }))
276 break;
277 } else if (!isa<AffineIfOp>(parentOp)) {
278 // Won't walk up past anything other than affine.for/if ops.
279 break;
280 }
281 // You can always hoist up past any affine.if ops.
282 res = parentOp;
283 }
284 return res;
285}
286
287/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
288/// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
289/// otherwise the same `ifOp`.
290static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
291 // No hoisting to do.
292 if (hoistOverOp == ifOp)
293 return ifOp;
294
295 // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
296 // the else block. Then drop the else block of the original 'if' in the 'then'
297 // branch while promoting its then block, and analogously drop the 'then'
298 // block of the original 'if' from the 'else' branch while promoting its else
299 // block.
300 IRMapping operandMap;
301 OpBuilder b(hoistOverOp);
302 auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
303 ifOp.getOperands(),
304 /*elseBlock=*/true);
305
306 // Create a clone of hoistOverOp to use for the else branch of the hoisted
307 // conditional. The else block may get optimized away if empty.
308 Operation *hoistOverOpClone = nullptr;
309 // We use this unique name to identify/find `ifOp`'s clone in the else
310 // version.
311 StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
312 operandMap.clear();
313 b.setInsertionPointAfter(hoistOverOp);
314 // We'll set an attribute to identify this op in a clone of this sub-tree.
  ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
  hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
317
318 // Promote the 'then' block of the original affine.if in the then version.
319 promoteIfBlock(ifOp, /*elseBlock=*/false);
320
321 // Move the then version to the hoisted if op's 'then' block.
322 auto *thenBlock = hoistedIfOp.getThenBlock();
323 thenBlock->getOperations().splice(thenBlock->begin(),
324 hoistOverOp->getBlock()->getOperations(),
325 Block::iterator(hoistOverOp));
326
327 // Find the clone of the original affine.if op in the else version.
328 AffineIfOp ifCloneInElse;
329 hoistOverOpClone->walk([&](AffineIfOp ifClone) {
330 if (!ifClone->getAttr(idForIfOp))
331 return WalkResult::advance();
332 ifCloneInElse = ifClone;
333 return WalkResult::interrupt();
334 });
335 assert(ifCloneInElse && "if op clone should exist");
336 // For the else block, promote the else block of the original 'if' if it had
337 // one; otherwise, the op itself is to be erased.
338 if (!ifCloneInElse.hasElse())
339 ifCloneInElse.erase();
340 else
341 promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
342
343 // Move the else version into the else block of the hoisted if op.
344 auto *elseBlock = hoistedIfOp.getElseBlock();
345 elseBlock->getOperations().splice(
346 elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
347 Block::iterator(hoistOverOpClone));
348
349 return hoistedIfOp;
350}
351
352LogicalResult
353mlir::affine::affineParallelize(AffineForOp forOp,
354 ArrayRef<LoopReduction> parallelReductions,
355 AffineParallelOp *resOp) {
356 // Fail early if there are iter arguments that are not reductions.
357 unsigned numReductions = parallelReductions.size();
358 if (numReductions != forOp.getNumIterOperands())
359 return failure();
360
361 Location loc = forOp.getLoc();
362 OpBuilder outsideBuilder(forOp);
363 AffineMap lowerBoundMap = forOp.getLowerBoundMap();
364 ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
365 AffineMap upperBoundMap = forOp.getUpperBoundMap();
366 ValueRange upperBoundOperands = forOp.getUpperBoundOperands();
367
  // Create an empty 1-D affine.parallel op.
  auto reducedValues = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.value; }));
371 auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
372 parallelReductions, [](const LoopReduction &red) { return red.kind; }));
373 AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
374 loc, ValueRange(reducedValues).getTypes(), reductionKinds,
375 llvm::ArrayRef(lowerBoundMap), lowerBoundOperands,
376 llvm::ArrayRef(upperBoundMap), upperBoundOperands,
377 llvm::ArrayRef(forOp.getStepAsInt()));
378 // Steal the body of the old affine for op.
379 newPloop.getRegion().takeBody(forOp.getRegion());
380 Operation *yieldOp = &newPloop.getBody()->back();
381
382 // Handle the initial values of reductions because the parallel loop always
383 // starts from the neutral value.
384 SmallVector<Value> newResults;
  newResults.reserve(numReductions);
386 for (unsigned i = 0; i < numReductions; ++i) {
387 Value init = forOp.getInits()[i];
388 // This works because we are only handling single-op reductions at the
389 // moment. A switch on reduction kind or a mechanism to collect operations
390 // participating in the reduction will be necessary for multi-op reductions.
    Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
    assert(reductionOp && "yielded value is expected to be produced by an op");
    outsideBuilder.getInsertionBlock()->getOperations().splice(
        outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
        reductionOp);
    reductionOp->setOperands({init, newPloop->getResult(i)});
    forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
398 }
399
400 // Update the loop terminator to yield reduced values bypassing the reduction
401 // operation itself (now moved outside of the loop) and erase the block
  // arguments that correspond to reductions. Note that the loop always has one
  // "main" induction variable when coming from a non-parallel for.
404 unsigned numIVs = 1;
405 yieldOp->setOperands(reducedValues);
406 newPloop.getBody()->eraseArguments(numIVs, numReductions);
407
408 forOp.erase();
409 if (resOp)
410 *resOp = newPloop;
411 return success();
412}
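
// For illustration (schematic IR, assuming a single floating-point add
// reduction): a loop such as
//
//   %res = affine.for %i = 0 to 128 iter_args(%acc = %init) -> (f32) {
//     %v = ...
//     %sum = arith.addf %acc, %v : f32
//     affine.yield %sum : f32
//   }
//
// is rewritten by affineParallelize into roughly
//
//   %red = affine.parallel (%i) = (0) to (128) reduce ("addf") -> (f32) {
//     %v = ...
//     affine.yield %v : f32
//   }
//   %res = arith.addf %init, %red : f32
//
// with the reduction op moved out of the loop and combined with the initial
// value.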
413
414// Returns success if any hoisting happened.
415LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
416 // Bail out early if the ifOp returns a result. TODO: Consider how to
417 // properly support this case.
418 if (ifOp.getNumResults() != 0)
419 return failure();
420
421 // Apply canonicalization patterns and folding - this is necessary for the
422 // hoisting check to be correct (operands should be composed), and to be more
423 // effective (no unused operands). Since the pattern rewriter's folding is
424 // entangled with application of patterns, we may fold/end up erasing the op,
425 // in which case we return with `folded` being set.
426 RewritePatternSet patterns(ifOp.getContext());
427 AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
428 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
429 bool erased;
430 (void)applyOpPatternsGreedily(
431 ifOp.getOperation(), frozenPatterns,
432 GreedyRewriteConfig().setStrictness(GreedyRewriteStrictness::ExistingOps),
433 /*changed=*/nullptr, &erased);
434 if (erased) {
435 if (folded)
436 *folded = true;
437 return failure();
438 }
439 if (folded)
440 *folded = false;
441
442 // The folding above should have ensured this.
443 assert(llvm::all_of(ifOp.getOperands(),
444 [](Value v) {
445 return isTopLevelValue(v) || isAffineInductionVar(v);
446 }) &&
447 "operands not composed");
448
  // We are going to hoist as high as possible.
450 // TODO: this could be customized in the future.
451 auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
452
453 AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
454 // Nothing to hoist over.
455 if (hoistedIfOp == ifOp)
456 return failure();
457
458 // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
459 // a sequence of affine.fors that are all perfectly nested).
460 (void)applyPatternsGreedily(
461 hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
462 frozenPatterns);
463
464 return success();
465}
466
467// Return the min expr after replacing the given dim.
468AffineExpr mlir::affine::substWithMin(AffineExpr e, AffineExpr dim,
469 AffineExpr min, AffineExpr max,
470 bool positivePath) {
471 if (e == dim)
472 return positivePath ? min : max;
  if (auto bin = dyn_cast<AffineBinaryOpExpr>(e)) {
    AffineExpr lhs = bin.getLHS();
    AffineExpr rhs = bin.getRHS();
    if (bin.getKind() == mlir::AffineExprKind::Add)
      return substWithMin(lhs, dim, min, max, positivePath) +
             substWithMin(rhs, dim, min, max, positivePath);

    auto c1 = dyn_cast<AffineConstantExpr>(bin.getLHS());
    auto c2 = dyn_cast<AffineConstantExpr>(bin.getRHS());
    if (c1 && c1.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
    if (c2 && c2.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
    return getAffineBinaryOpExpr(
        bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
        substWithMin(rhs, dim, min, max, positivePath));
491 }
492 return e;
493}
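
// For illustration: substWithMin(d0 * -2 + d0, d0, min, max,
// /*positivePath=*/true) yields roughly `max * -2 + min`; the occurrence
// multiplied by a negative constant is substituted on the flipped path (max),
// while the plain occurrence takes min.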
494
495void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
496 // Loops with min/max in bounds are not normalized at the moment.
497 if (op.hasMinMaxBounds())
498 return;
499
500 AffineMap lbMap = op.getLowerBoundsMap();
501 SmallVector<int64_t, 8> steps = op.getSteps();
502 // No need to do any work if the parallel op is already normalized.
  bool isAlreadyNormalized =
      llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
        int64_t step = std::get<0>(tuple);
        auto lbExpr = dyn_cast<AffineConstantExpr>(std::get<1>(tuple));
        return lbExpr && lbExpr.getValue() == 0 && step == 1;
      });
509 if (isAlreadyNormalized)
510 return;
511
512 AffineValueMap ranges;
  AffineValueMap::difference(op.getUpperBoundsValueMap(),
                             op.getLowerBoundsValueMap(), &ranges);
  auto builder = OpBuilder::atBlockBegin(op.getBody());
516 auto zeroExpr = builder.getAffineConstantExpr(0);
517 SmallVector<AffineExpr, 8> lbExprs;
518 SmallVector<AffineExpr, 8> ubExprs;
519 for (unsigned i = 0, e = steps.size(); i < e; ++i) {
520 int64_t step = steps[i];
521
522 // Adjust the lower bound to be 0.
    lbExprs.push_back(zeroExpr);

    // Adjust the upper bound expression: 'range ceildiv step'.
    AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
    ubExprs.push_back(ubExpr);
528
529 // Adjust the corresponding IV: 'lb + i * step'.
530 BlockArgument iv = op.getBody()->getArgument(i);
    AffineExpr lbExpr = lbMap.getResult(i);
532 unsigned nDims = lbMap.getNumDims();
533 auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
534 auto map = AffineMap::get(/*dimCount=*/nDims + 1,
535 /*symbolCount=*/lbMap.getNumSymbols(), expr);
536
537 // Use an 'affine.apply' op that will be simplified later in subsequent
538 // canonicalizations.
539 OperandRange lbOperands = op.getLowerBoundsOperands();
    OperandRange dimOperands = lbOperands.take_front(nDims);
    OperandRange symbolOperands = lbOperands.drop_front(nDims);
    SmallVector<Value, 8> applyOperands{dimOperands};
    applyOperands.push_back(iv);
    applyOperands.append(symbolOperands.begin(), symbolOperands.end());
545 auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
546 iv.replaceAllUsesExcept(apply, apply);
547 }
548
549 SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
550 op.setSteps(newSteps);
551 auto newLowerMap = AffineMap::get(
552 /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
553 op.setLowerBounds({}, newLowerMap);
554 auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
555 ubExprs, op.getContext());
556 op.setUpperBounds(ranges.getOperands(), newUpperMap);
557}
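
// For illustration (schematic; the actual bounds stay in affine-map form over
// the original operands): a loop such as
//
//   affine.parallel (%i) = (%lb) to (%ub) step (4) { ... use %i ... }
//
// is normalized into roughly
//
//   affine.parallel (%ii) = (0) to ((%ub - %lb) ceildiv 4) {
//     %i = affine.apply affine_map<(d0)[s0] -> (s0 + d0 * 4)>(%ii)[%lb]
//     ... use %i ...
//   }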
558
559LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
560 bool promoteSingleIter) {
561 if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
562 return success();
563
  // Check if the for op is already normalized.
565 if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
566 (op.getStep() == 1))
567 return success();
568
569 // Check if the lower bound has a single result only. Loops with a max lower
570 // bound can't be normalized without additional support like
571 // affine.execute_region's. If the lower bound does not have a single result
572 // then skip this op.
573 if (op.getLowerBoundMap().getNumResults() != 1)
574 return failure();
575
576 Location loc = op.getLoc();
577 OpBuilder opBuilder(op);
578 int64_t origLoopStep = op.getStepAsInt();
579
580 // Construct the new upper bound value map.
581 AffineMap oldLbMap = op.getLowerBoundMap();
582 // The upper bound can have multiple results. To use
583 // AffineValueMap::difference, we need to have the same number of results in
584 // both lower and upper bound maps. So, we just create a value map for the
585 // lower bound with the only available lower bound result repeated to pad up
586 // to the number of upper bound results.
587 SmallVector<AffineExpr> lbExprs(op.getUpperBoundMap().getNumResults(),
588 op.getLowerBoundMap().getResult(0));
589 AffineValueMap lbMap(oldLbMap, op.getLowerBoundOperands());
590 AffineMap paddedLbMap =
591 AffineMap::get(oldLbMap.getNumDims(), oldLbMap.getNumSymbols(), lbExprs,
592 op.getContext());
593 AffineValueMap paddedLbValueMap(paddedLbMap, op.getLowerBoundOperands());
594 AffineValueMap ubValueMap(op.getUpperBoundMap(), op.getUpperBoundOperands());
595 AffineValueMap newUbValueMap;
596 // Compute the `upper bound - lower bound`.
  AffineValueMap::difference(ubValueMap, paddedLbValueMap, &newUbValueMap);
598 (void)newUbValueMap.canonicalize();
599
600 // Scale down the upper bound value map by the loop step.
601 unsigned numResult = newUbValueMap.getNumResults();
602 SmallVector<AffineExpr> scaleDownExprs(numResult);
603 for (unsigned i = 0; i < numResult; ++i)
    scaleDownExprs[i] = opBuilder.getAffineDimExpr(i).ceilDiv(origLoopStep);
  // `scaleDownMap` is (d0, d1, ..., d_n) -> (d0 ceildiv step, d1 ceildiv step,
  // ..., d_n ceildiv step), where `n` is the number of results in the upper
  // bound map.
  AffineMap scaleDownMap =
      AffineMap::get(numResult, 0, scaleDownExprs, op.getContext());
  AffineMap newUbMap = scaleDownMap.compose(newUbValueMap.getAffineMap());
610
  // Set the newly created upper bound map and operands.
  op.setUpperBound(newUbValueMap.getOperands(), newUbMap);
  op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
614 op.setStep(1);
615
  // Calculate the value of the old loop IV in terms of the new one: construct
  // an affine.apply op mapping the new IV back to the old IV.
  opBuilder.setInsertionPointToStart(op.getBody());
  AffineMap scaleIvMap =
      AffineMap::get(1, 0, -opBuilder.getAffineDimExpr(0) * origLoopStep);
  AffineValueMap scaleIvValueMap(scaleIvMap, ValueRange{op.getInductionVar()});
  AffineValueMap newIvToOldIvMap;
  AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap);
625 (void)newIvToOldIvMap.canonicalize();
626 auto newIV = opBuilder.create<AffineApplyOp>(
627 loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands());
628 op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
629 return success();
630}
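
// For illustration (schematic; #ub and #iv stand for the maps built above):
//
//   affine.for %i = %lb to %ub step 4 { ... use %i ... }
//
// is normalized into roughly
//
//   affine.for %ii = 0 to #ub()[%lb, %ub] {  // #ub = ()[s0, s1] -> ((s1 - s0) ceildiv 4)
//     %i = affine.apply #iv(%ii)[%lb]        // #iv = (d0)[s0] -> (d0 * 4 + s0)
//     ... use %i ...
//   }
//
// with the original IV reconstructed by the affine.apply at the top of the
// body (the exact maps depend on the original bound maps and operands).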
631
632/// Returns true if the memory operation of `destAccess` depends on `srcAccess`
633/// inside of the innermost common surrounding affine loop between the two
634/// accesses.
635static bool mustReachAtInnermost(const MemRefAccess &srcAccess,
636 const MemRefAccess &destAccess) {
  // Affine dependence analysis is possible only if both ops are in the same
  // AffineScope.
  if (getAffineAnalysisScope(srcAccess.opInst) !=
      getAffineAnalysisScope(destAccess.opInst))
    return false;

  unsigned nsLoops =
      getNumCommonSurroundingLoops(*srcAccess.opInst, *destAccess.opInst);
  DependenceResult result =
      checkMemrefAccessDependence(srcAccess, destAccess, nsLoops + 1);
647 return hasDependence(result);
648}
649
650/// Returns true if `srcMemOp` may have an effect on `destMemOp` within the
651/// scope of the outermost `minSurroundingLoops` loops that surround them.
652/// `srcMemOp` and `destMemOp` are expected to be affine read/write ops.
653static bool mayHaveEffect(Operation *srcMemOp, Operation *destMemOp,
654 unsigned minSurroundingLoops) {
655 MemRefAccess srcAccess(srcMemOp);
656 MemRefAccess destAccess(destMemOp);
657
658 // Affine dependence analysis here is applicable only if both ops operate on
659 // the same memref and if `srcMemOp` and `destMemOp` are in the same
  // AffineScope. Also, we can only check if our affine scope is isolated from
  // above; otherwise, values can come from outside of the affine scope that
  // the check below cannot analyze.
  Region *srcScope = getAffineAnalysisScope(srcMemOp);
  if (srcAccess.memref == destAccess.memref &&
      srcScope == getAffineAnalysisScope(destMemOp)) {
    unsigned nsLoops = getNumCommonSurroundingLoops(*srcMemOp, *destMemOp);
    FlatAffineValueConstraints dependenceConstraints;
    for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, destAccess, d, &dependenceConstraints,
          /*dependenceComponents=*/nullptr);
672 // A dependence failure or the presence of a dependence implies a
673 // side effect.
674 if (!noDependence(result))
675 return true;
676 }
677 // No side effect was seen.
678 return false;
679 }
680 // TODO: Check here if the memrefs alias: there is no side effect if
681 // `srcAccess.memref` and `destAccess.memref` don't alias.
682 return true;
683}
684
685template <typename EffectType, typename T>
686bool mlir::affine::hasNoInterveningEffect(
687 Operation *start, T memOp,
688 llvm::function_ref<bool(Value, Value)> mayAlias) {
689 // A boolean representing whether an intervening operation could have impacted
690 // memOp.
691 bool hasSideEffect = false;
692
693 // Check whether the effect on memOp can be caused by a given operation op.
694 Value memref = memOp.getMemRef();
695 std::function<void(Operation *)> checkOperation = [&](Operation *op) {
    // If the effect has already been found, exit early.
697 if (hasSideEffect)
698 return;
699
700 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
701 SmallVector<MemoryEffects::EffectInstance, 1> effects;
702 memEffect.getEffects(effects);
703
704 bool opMayHaveEffect = false;
705 for (auto effect : effects) {
706 // If op causes EffectType on a potentially aliasing location for
707 // memOp, mark as having the effect.
708 if (isa<EffectType>(effect.getEffect())) {
709 if (effect.getValue() && effect.getValue() != memref &&
710 !mayAlias(effect.getValue(), memref))
711 continue;
712 opMayHaveEffect = true;
713 break;
714 }
715 }
716
717 if (!opMayHaveEffect)
718 return;
719
720 // If the side effect comes from an affine read or write, try to
721 // prove the side effecting `op` cannot reach `memOp`.
722 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
723 // For ease, let's consider the case that `op` is a store and
724 // we're looking for other potential stores that overwrite memory after
725 // `start`, and before being read in `memOp`. In this case, we only
726 // need to consider other potential stores with depth >
727 // minSurroundingLoops since `start` would overwrite any store with a
728 // smaller number of surrounding loops before.
729 unsigned minSurroundingLoops =
730 getNumCommonSurroundingLoops(*start, *memOp);
731 if (mayHaveEffect(op, memOp, minSurroundingLoops))
732 hasSideEffect = true;
733 return;
734 }
735
736 // We have an op with a memory effect and we cannot prove if it
737 // intervenes.
738 hasSideEffect = true;
739 return;
740 }
741
742 if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
743 // Recurse into the regions for this op and check whether the internal
744 // operations may have the side effect `EffectType` on memOp.
745 for (Region &region : op->getRegions())
746 for (Block &block : region)
747 for (Operation &op : block)
748 checkOperation(&op);
749 return;
750 }
751
    // Otherwise, conservatively assume generic operations have the effect
    // on `memOp`.
754 hasSideEffect = true;
755 };
756
757 // Check all paths from ancestor op `parent` to the operation `to` for the
758 // effect. It is known that `to` must be contained within `parent`.
759 auto until = [&](Operation *parent, Operation *to) {
760 // TODO check only the paths from `parent` to `to`.
761 // Currently we fallback and check the entire parent op, rather than
762 // just the paths from the parent path, stopping after reaching `to`.
763 // This is conservatively correct, but could be made more aggressive.
764 assert(parent->isAncestor(to));
765 checkOperation(parent);
766 };
767
768 // Check for all paths from operation `from` to operation `untilOp` for the
769 // given memory effect.
770 std::function<void(Operation *, Operation *)> recur =
771 [&](Operation *from, Operation *untilOp) {
772 assert(
773 from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
774 "Checking for side effect between two operations without a common "
775 "ancestor");
776
777 // If the operations are in different regions, recursively consider all
778 // path from `from` to the parent of `to` and all paths from the parent
779 // of `to` to `to`.
780 if (from->getParentRegion() != untilOp->getParentRegion()) {
781 recur(from, untilOp->getParentOp());
782 until(untilOp->getParentOp(), untilOp);
783 return;
784 }
785
786 // Now, assuming that `from` and `to` exist in the same region, perform
787 // a CFG traversal to check all the relevant operations.
788
789 // Additional blocks to consider.
790 SmallVector<Block *, 2> todoBlocks;
791 {
          // First consider the parent block of `from` and check all operations
          // after `from`.
794 for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
795 iter != end && &*iter != untilOp; ++iter) {
796 checkOperation(&*iter);
797 }
798
799 // If the parent of `from` doesn't contain `to`, add the successors
800 // to the list of blocks to check.
801 if (untilOp->getBlock() != from->getBlock())
802 for (Block *succ : from->getBlock()->getSuccessors())
            todoBlocks.push_back(succ);
804 }
805
806 SmallPtrSet<Block *, 4> done;
807 // Traverse the CFG until hitting `to`.
808 while (!todoBlocks.empty()) {
809 Block *blk = todoBlocks.pop_back_val();
          if (done.count(blk))
            continue;
          done.insert(blk);
          for (auto &op : *blk) {
            if (&op == untilOp)
              break;
            checkOperation(&op);
            if (&op == blk->getTerminator())
              for (Block *succ : blk->getSuccessors())
                todoBlocks.push_back(succ);
820 }
821 }
822 };
823 recur(start, memOp);
824 return !hasSideEffect;
825}
826
/// Attempt to eliminate loadOp by replacing it with a value stored into memory
/// which the load is guaranteed to retrieve. This check involves three
/// components: 1) the store and load must be to the same location, 2) the
/// store must dominate (and therefore must always occur prior to) the load,
/// and 3) no other operation may overwrite the memory read by the load between
/// the given load and store. If such a value exists, the replaced `loadOp`
/// will be added to `loadOpsToErase` and its memref will be added to
/// `memrefsToErase`.
834static void forwardStoreToLoad(
835 AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
836 SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo,
837 llvm::function_ref<bool(Value, Value)> mayAlias) {
838
839 // The store op candidate for forwarding that satisfies all conditions
840 // to replace the load, if any.
841 Operation *lastWriteStoreOp = nullptr;
842
843 for (auto *user : loadOp.getMemRef().getUsers()) {
844 auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
845 if (!storeOp)
846 continue;
847 MemRefAccess srcAccess(storeOp);
848 MemRefAccess destAccess(loadOp);
849
850 // 1. Check if the store and the load have mathematically equivalent
851 // affine access functions; this implies that they statically refer to the
852 // same single memref element. As an example this filters out cases like:
853 // store %A[%i0 + 1]
854 // load %A[%i0]
855 // store %A[%M]
856 // load %A[%N]
857 // Use the AffineValueMap difference based memref access equality checking.
858 if (srcAccess != destAccess)
859 continue;
860
    // 2. The store has to dominate the load op to be a candidate.
862 if (!domInfo.dominates(storeOp, loadOp))
863 continue;
864
865 // 3. The store must reach the load. Access function equivalence only
866 // guarantees this for accesses in the same block. The load could be in a
867 // nested block that is unreachable.
868 if (!mustReachAtInnermost(srcAccess, destAccess))
869 continue;
870
871 // 4. Ensure there is no intermediate operation which could replace the
872 // value in memory.
873 if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp,
874 mayAlias))
875 continue;
876
877 // We now have a candidate for forwarding.
878 assert(lastWriteStoreOp == nullptr &&
879 "multiple simultaneous replacement stores");
880 lastWriteStoreOp = storeOp;
881 }
882
883 if (!lastWriteStoreOp)
884 return;
885
886 // Perform the actual store to load forwarding.
887 Value storeVal =
888 cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
889 // Check if 2 values have the same shape. This is needed for affine vector
890 // loads and stores.
891 if (storeVal.getType() != loadOp.getValue().getType())
892 return;
893 loadOp.getValue().replaceAllUsesWith(storeVal);
894 // Record the memref for a later sweep to optimize away.
895 memrefsToErase.insert(loadOp.getMemRef());
896 // Record this to erase later.
  loadOpsToErase.push_back(loadOp);
898}
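
// For illustration (schematic): in
//
//   affine.store %v, %A[%i] : memref<128xf32>
//   ...                                   // no intervening write to %A[%i]
//   %x = affine.load %A[%i] : memref<128xf32>
//   "use"(%x) : (f32) -> ()
//
// the load satisfies all four conditions above, so uses of %x are replaced by
// %v, the load is queued in `loadOpsToErase`, and %A is remembered in
// `memrefsToErase` for the later dead-memref sweep.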
899
900template bool
901mlir::affine::hasNoInterveningEffect<mlir::MemoryEffects::Read,
902 affine::AffineReadOpInterface>(
903 mlir::Operation *, affine::AffineReadOpInterface,
904 llvm::function_ref<bool(Value, Value)>);
905
906// This attempts to find stores which have no impact on the final result.
// A writing op writeA will be eliminated if there exists an op writeB such that
908// 1) writeA and writeB have mathematically equivalent affine access functions.
909// 2) writeB postdominates writeA.
910// 3) There is no potential read between writeA and writeB.
911static void findUnusedStore(AffineWriteOpInterface writeA,
912 SmallVectorImpl<Operation *> &opsToErase,
913 PostDominanceInfo &postDominanceInfo,
914 llvm::function_ref<bool(Value, Value)> mayAlias) {
915
916 for (Operation *user : writeA.getMemRef().getUsers()) {
917 // Only consider writing operations.
918 auto writeB = dyn_cast<AffineWriteOpInterface>(user);
919 if (!writeB)
920 continue;
921
922 // The operations must be distinct.
923 if (writeB == writeA)
924 continue;
925
926 // Both operations must lie in the same region.
927 if (writeB->getParentRegion() != writeA->getParentRegion())
928 continue;
929
930 // Both operations must write to the same memory.
931 MemRefAccess srcAccess(writeB);
932 MemRefAccess destAccess(writeA);
933
934 if (srcAccess != destAccess)
935 continue;
936
937 // writeB must postdominate writeA.
938 if (!postDominanceInfo.postDominates(writeB, writeA))
939 continue;
940
941 // There cannot be an operation which reads from memory between
942 // the two writes.
943 if (!affine::hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB,
944 mayAlias))
945 continue;
946
947 opsToErase.push_back(writeA);
948 break;
949 }
950}
951
952// The load to load forwarding / redundant load elimination is similar to the
953// store to load forwarding.
// loadA will be replaced with loadB if:
955// 1) loadA and loadB have mathematically equivalent affine access functions.
956// 2) loadB dominates loadA.
957// 3) There is no write between loadA and loadB.
958static void loadCSE(AffineReadOpInterface loadA,
959 SmallVectorImpl<Operation *> &loadOpsToErase,
960 DominanceInfo &domInfo,
961 llvm::function_ref<bool(Value, Value)> mayAlias) {
962 SmallVector<AffineReadOpInterface, 4> loadCandidates;
963 for (auto *user : loadA.getMemRef().getUsers()) {
964 auto loadB = dyn_cast<AffineReadOpInterface>(user);
965 if (!loadB || loadB == loadA)
966 continue;
967
968 MemRefAccess srcAccess(loadB);
969 MemRefAccess destAccess(loadA);
970
    // 1. The accesses should be to the same location.
972 if (srcAccess != destAccess) {
973 continue;
974 }
975
976 // 2. loadB should dominate loadA.
977 if (!domInfo.dominates(loadB, loadA))
978 continue;
979
980 // 3. There should not be a write between loadA and loadB.
981 if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
982 loadB.getOperation(), loadA, mayAlias))
983 continue;
984
985 // Check if two values have the same shape. This is needed for affine vector
986 // loads.
987 if (loadB.getValue().getType() != loadA.getValue().getType())
988 continue;
989
990 loadCandidates.push_back(loadB);
991 }
992
  // Of the legal load candidates, use the one that dominates all others
  // to minimize the subsequent need for load CSE.
995 Value loadB;
996 for (AffineReadOpInterface option : loadCandidates) {
997 if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
998 return depStore == option ||
999 domInfo.dominates(option.getOperation(),
1000 depStore.getOperation());
1001 })) {
1002 loadB = option.getValue();
1003 break;
1004 }
1005 }
1006
1007 if (loadB) {
1008 loadA.getValue().replaceAllUsesWith(loadB);
1009 // Record this to erase later.
    loadOpsToErase.push_back(loadA);
1011 }
1012}
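
// For illustration (schematic): in
//
//   %x = affine.load %A[%i] : memref<128xf32>
//   ...                                   // no intervening write to %A[%i]
//   %y = affine.load %A[%i] : memref<128xf32>
//
// the second load is redundant: uses of %y are replaced by %x and the second
// load is queued in `loadOpsToErase`.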
1013
1014// The store to load forwarding and load CSE rely on three conditions:
1015//
1016// 1) store/load providing a replacement value and load being replaced need to
1017// have mathematically equivalent affine access functions (checked after full
1018// composition of load/store operands); this implies that they access the same
1019// single memref element for all iterations of the common surrounding loop,
1020//
1021// 2) the store/load op should dominate the load op,
1022//
1023// 3) no operation that may write to memory read by the load being replaced can
1024// occur after executing the instruction (load or store) providing the
1025// replacement value and before the load being replaced (thus potentially
1026// allowing overwriting the memory read by the load).
1027//
// The above conditions are simple to check, sufficient, and powerful for most
// cases in practice. They are sufficient, but not necessary, since they don't
// reason about loops that are guaranteed to execute at least once or about
// multiple sources to forward from.
1032//
1033// TODO: more forwarding can be done when support for
1034// loop/conditional live-out SSA values is available.
// TODO: do general dead store elimination for memref's. This pass
// currently eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
1038//
1039void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
1040 PostDominanceInfo &postDomInfo,
1041 AliasAnalysis &aliasAnalysis) {
1042 // Load op's whose results were replaced by those forwarded from stores.
1043 SmallVector<Operation *, 8> opsToErase;
1044
1045 // A list of memref's that are potentially dead / could be eliminated.
1046 SmallPtrSet<Value, 4> memrefsToErase;
1047
1048 auto mayAlias = [&](Value val1, Value val2) -> bool {
    return !aliasAnalysis.alias(val1, val2).isNo();
1050 };
1051
1052 // Walk all load's and perform store to load forwarding.
1053 f.walk([&](AffineReadOpInterface loadOp) {
1054 forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias);
1055 });
1056 for (auto *op : opsToErase)
1057 op->erase();
1058 opsToErase.clear();
1059
1060 // Walk all store's and perform unused store elimination
1061 f.walk([&](AffineWriteOpInterface storeOp) {
1062 findUnusedStore(storeOp, opsToErase, postDomInfo, mayAlias);
1063 });
1064 for (auto *op : opsToErase)
1065 op->erase();
1066 opsToErase.clear();
1067
1068 // Check if the store fwd'ed memrefs are now left with only stores and
1069 // deallocs and can thus be completely deleted. Note: the canonicalize pass
1070 // should be able to do this as well, but we'll do it here since we collected
1071 // these anyway.
1072 for (auto memref : memrefsToErase) {
1073 // If the memref hasn't been locally alloc'ed, skip.
1074 Operation *defOp = memref.getDefiningOp();
    if (!defOp || !hasSingleEffect<MemoryEffects::Allocate>(defOp, memref))
1076 // TODO: if the memref was returned by a 'call' operation, we
1077 // could still erase it if the call had no side-effects.
1078 continue;
    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
1080 return !isa<AffineWriteOpInterface>(ownerOp) &&
1081 !hasSingleEffect<MemoryEffects::Free>(ownerOp, memref);
1082 }))
1083 continue;
1084
1085 // Erase all stores, the dealloc, and the alloc on the memref.
    for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
1087 user->erase();
1088 defOp->erase();
1089 }
1090
1091 // To eliminate as many loads as possible, run load CSE after eliminating
1092 // stores. Otherwise, some stores are wrongly seen as having an intervening
1093 // effect.
1094 f.walk([&](AffineReadOpInterface loadOp) {
1095 loadCSE(loadOp, opsToErase, domInfo, mayAlias);
1096 });
1097 for (auto *op : opsToErase)
1098 op->erase();
1099}
1100
// Checks if `op` is a dereferencing op, i.e., it accesses memref elements
// through explicit indices.
// TODO: This hardcoded check will be removed once the right interface is
// added.
1103static bool isDereferencingOp(Operation *op) {
1104 return isa<AffineMapAccessInterface, memref::LoadOp, memref::StoreOp>(op);
1105}
1106
1107// Perform the replacement in `op`.
1108LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1109 Value oldMemRef, Value newMemRef, Operation *op,
1110 ArrayRef<Value> extraIndices, AffineMap indexRemap,
1111 ArrayRef<Value> extraOperands, ArrayRef<Value> symbolOperands,
1112 bool allowNonDereferencingOps) {
1113 unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
1114 (void)newMemRefRank; // unused in opt mode
1115 unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
1116 (void)oldMemRefRank; // unused in opt mode
1117 if (indexRemap) {
1118 assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1119 "symbolic operand count mismatch");
1120 assert(indexRemap.getNumInputs() ==
1121 extraOperands.size() + oldMemRefRank + symbolOperands.size());
1122 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1123 } else {
1124 assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1125 }
1126
1127 // Assert same elemental type.
1128 assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
1129 cast<MemRefType>(newMemRef.getType()).getElementType());
1130
1131 SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
1135 }
1136
1137 // If memref doesn't appear, nothing to do.
1138 if (usePositions.empty())
1139 return success();
1140
1141 unsigned memRefOperandPos = usePositions.front();
1142
1143 OpBuilder builder(op);
1144 // The following checks if op is dereferencing memref and performs the access
1145 // index rewrites.
1146 if (!isDereferencingOp(op)) {
1147 if (!allowNonDereferencingOps) {
1148 // Failure: memref used in a non-dereferencing context (potentially
1149 // escapes); no replacement in these cases unless allowNonDereferencingOps
1150 // is set.
1151 return failure();
1152 }
1153 for (unsigned pos : usePositions)
      op->setOperand(pos, newMemRef);
1155 return success();
1156 }
1157
1158 if (usePositions.size() > 1) {
1159 // TODO: extend it for this case when needed (rare).
1160 LLVM_DEBUG(llvm::dbgs()
1161 << "multiple dereferencing uses in a single op not supported");
1162 return failure();
1163 }
1164
1165 // Perform index rewrites for the dereferencing op and then replace the op.
1166 SmallVector<Value, 4> oldMapOperands;
1167 AffineMap oldMap;
1168 unsigned oldMemRefNumIndices = oldMemRefRank;
1169 auto startIdx = op->operand_begin() + memRefOperandPos + 1;
1170 auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1171 if (affMapAccInterface) {
    // If `op` implements AffineMapAccessInterface, we can get the indices by
    // querying the number of map operands from the operand list at a certain
    // offset (`memRefOperandPos` in this case).
1175 NamedAttribute oldMapAttrPair =
1176 affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
1177 oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
1178 oldMemRefNumIndices = oldMap.getNumInputs();
1179 }
  oldMapOperands.assign(startIdx, startIdx + oldMemRefNumIndices);
1181
1182 // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
1183 SmallVector<Value, 4> oldMemRefOperands;
1184 SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (affMapAccInterface &&
      oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1198 }
1199
1200 // Construct new indices as a remap of the old ones if a remapping has been
1201 // provided. The indices of a memref come right after it, i.e.,
1202 // at position memRefOperandPos + 1.
1203 SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1209
1210 SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);
  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1214 // Remapped indices.
1215 for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
1222 }
1223 } else {
1224 // No remapping specified.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }
  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);
1229
1230 // Prepend 'extraIndices' in 'newMapOperands'.
1231 for (Value extraIndex : extraIndices) {
1232 assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1233 "invalid memory op index");
    newMapOperands.push_back(extraIndex);
1235 }
1236
1237 // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1239
1240 // Create new fully composed AffineMap for new op to be created.
1241 assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
1243 fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
1244 newMap = simplifyAffineMap(newMap);
1245 canonicalizeMapAndOperands(&newMap, &newMapOperands);
1246 // Remove any affine.apply's that became dead as a result of composition.
1247 for (Value value : affineApplyOps)
1248 if (value.use_empty())
1249 value.getDefiningOp()->erase();
1250
1251 OperationState state(op->getLoc(), op->getName());
1252 // Construct the new operation using this memref.
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);
1259
1260 // Insert the new memref map operands.
1261 if (affMapAccInterface) {
    state.operands.append(newMapOperands.begin(), newMapOperands.end());
1263 } else {
1264 // In the case of dereferencing ops not implementing
1265 // AffineMapAccessInterface, we need to apply the values of `newMapOperands`
1266 // to the `newMap` to get the correct indices.
1267 for (unsigned i = 0; i < newMemRefRank; i++) {
      state.operands.push_back(builder.create<AffineApplyOp>(
1269 op->getLoc(),
1270 AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(),
1271 newMap.getResult(i)),
1272 newMapOperands));
1273 }
1274 }
1275
1276 // Insert the remaining operands unmodified.
1277 unsigned oldMapNumInputs = oldMapOperands.size();
1278 state.operands.append(in_start: op->operand_begin() + memRefOperandPos + 1 +
1279 oldMapNumInputs,
1280 in_end: op->operand_end());
1281 // Result types don't change. Both memref's are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());
1285
1286 // Add attribute for 'newMap', other Attributes do not change.
1287 auto newMapAttr = AffineMapAttr::get(newMap);
1288 for (auto namedAttr : op->getAttrs()) {
1289 if (affMapAccInterface &&
1290 namedAttr.getName() ==
1291 affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef).getName())
      state.attributes.push_back({namedAttr.getName(), newMapAttr});
    else
      state.attributes.push_back(namedAttr);
1295 }
1296
1297 // Create the new operation.
1298 auto *repOp = builder.create(state);
  op->replaceAllUsesWith(repOp);
1300 op->erase();
1301
1302 return success();
1303}
1304
1305LogicalResult mlir::affine::replaceAllMemRefUsesWith(
1306 Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
1307 AffineMap indexRemap, ArrayRef<Value> extraOperands,
1308 ArrayRef<Value> symbolOperands, Operation *domOpFilter,
1309 Operation *postDomOpFilter, bool allowNonDereferencingOps,
1310 bool replaceInDeallocOp) {
1311 unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
1312 (void)newMemRefRank; // unused in opt mode
1313 unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
1314 (void)oldMemRefRank;
1315 if (indexRemap) {
1316 assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1317 "symbol operand count mismatch");
1318 assert(indexRemap.getNumInputs() ==
1319 extraOperands.size() + oldMemRefRank + symbolOperands.size());
1320 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1321 } else {
1322 assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1323 }
1324
1325 // Assert same elemental type.
1326 assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
1327 cast<MemRefType>(newMemRef.getType()).getElementType());
1328
1329 std::unique_ptr<DominanceInfo> domInfo;
1330 std::unique_ptr<PostDominanceInfo> postDomInfo;
1331 if (domOpFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domOpFilter->getParentOfType<FunctionOpInterface>());

  if (postDomOpFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomOpFilter->getParentOfType<FunctionOpInterface>());
1338
1339 // Walk all uses of old memref; collect ops to perform replacement. We use a
1340 // DenseSet since an operation could potentially have multiple uses of a
1341 // memref (although rare), and the replacement later is going to erase ops.
1342 DenseSet<Operation *> opsToReplace;
1343 for (auto *op : oldMemRef.getUsers()) {
1344 // Skip this use if it's not dominated by domOpFilter.
    if (domOpFilter && !domInfo->dominates(domOpFilter, op))
1346 continue;
1347
1348 // Skip this use if it's not post-dominated by postDomOpFilter.
    if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
1350 continue;
1351
1352 // Skip dealloc's - no replacement is necessary, and a memref replacement
1353 // at other uses doesn't hurt these dealloc's.
    if (hasSingleEffect<MemoryEffects::Free>(op, oldMemRef) &&
1355 !replaceInDeallocOp)
1356 continue;
1357
1358 // Check if the memref was used in a non-dereferencing context. It is fine
1359 // for the memref to be used in a non-dereferencing way outside of the
1360 // region where this replacement is happening.
1361 if (!isa<AffineMapAccessInterface>(*op)) {
1362 if (!allowNonDereferencingOps) {
1363 LLVM_DEBUG(llvm::dbgs()
                 << "Memref replacement failed: non-dereferencing memref op: \n"
1365 << *op << '\n');
1366 return failure();
1367 }
1368 // Non-dereferencing ops with the MemRefsNormalizable trait are
1369 // supported for replacement.
1370 if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
1371 LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
1372 "memrefs normalizable trait: \n"
1373 << *op << '\n');
1374 return failure();
1375 }
1376 }
1377
1378 // We'll first collect and then replace --- since replacement erases the op
1379 // that has the use, and that op could be postDomFilter or domFilter itself!
1380 opsToReplace.insert(V: op);
1381 }
1382
1383 for (auto *op : opsToReplace) {
1384 if (failed(Result: replaceAllMemRefUsesWith(
1385 oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
1386 symbolOperands, allowNonDereferencingOps)))
1387 llvm_unreachable("memref replacement guaranteed to succeed here");
1388 }
1389
1390 return success();
1391}
1392
1393/// Given an operation, inserts one or more single-result affine.apply
1394/// operations whose results are exclusively used by this operation. The
1395/// operands of these newly created affine.apply ops are guaranteed to be loop
1396/// iterators or terminal symbols of a function.
1397///
1398/// Before
1399///
1400/// affine.for %i = 0 to #map(%N)
1401/// %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1402/// "send"(%idx, %A, ...)
1403/// "compute"(%idx)
1404///
1405/// After
1406///
1407/// affine.for %i = 0 to #map(%N)
1408/// %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1409/// "send"(%idx, %A, ...)
1410/// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
1411/// "compute"(%idx_)
1412///
1413/// This allows applying different transformations to the send and compute ops
1414/// (e.g., different shifts/delays).
1415///
1416/// Leaves `sliceOps` untouched either if none of opInst's operands were the
1417/// result of an affine.apply (and thus there was no affine computation slice to
1418/// create), or if all the affine.apply ops supplying operands to this opInst
1419/// had no uses besides this opInst; otherwise, returns the list of affine.apply
1420/// operations created in the output argument `sliceOps`.
1421void mlir::affine::createAffineComputationSlice(
1422 Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
1423 // Collect all operands that are results of affine apply ops.
1424 SmallVector<Value, 4> subOperands;
1425 subOperands.reserve(N: opInst->getNumOperands());
1426 for (auto operand : opInst->getOperands())
1427 if (isa_and_nonnull<AffineApplyOp>(Val: operand.getDefiningOp()))
1428 subOperands.push_back(Elt: operand);
1429
1430 // Gather sequence of AffineApplyOps reachable from 'subOperands'.
1431 SmallVector<Operation *, 4> affineApplyOps;
1432 getReachableAffineApplyOps(operands: subOperands, affineApplyOps);
1433 // Skip transforming if there are no affine maps to compose.
1434 if (affineApplyOps.empty())
1435 return;
1436
1437 // Check if all uses of the affine apply ops lie only in this op, in
1438 // which case there would be nothing to do.
1439 bool localized = true;
1440 for (auto *op : affineApplyOps) {
1441 for (auto result : op->getResults()) {
1442 for (auto *user : result.getUsers()) {
1443 if (user != opInst) {
1444 localized = false;
1445 break;
1446 }
1447 }
1448 }
1449 }
1450 if (localized)
1451 return;
1452
1453 OpBuilder builder(opInst);
1454 SmallVector<Value, 4> composedOpOperands(subOperands);
1455 auto composedMap = builder.getMultiDimIdentityMap(rank: composedOpOperands.size());
1456 fullyComposeAffineMapAndOperands(map: &composedMap, operands: &composedOpOperands);
1457
1458 // Create an affine.apply for each of the map results.
1459 sliceOps->reserve(composedMap.getNumResults());
1460 for (auto resultExpr : composedMap.getResults()) {
1461 auto singleResMap = AffineMap::get(dimCount: composedMap.getNumDims(),
1462 symbolCount: composedMap.getNumSymbols(), result: resultExpr);
1463 sliceOps->push_back(builder.create<AffineApplyOp>(
1464 opInst->getLoc(), singleResMap, composedOpOperands));
1465 }
1466
1467 // Construct the new operands, using the results of the composed affine apply
1468 // ops above in place of the existing ones (subOperands). The new operands
1469 // differ from opInst's operands only for those in 'subOperands', which are
1470 // replaced by the corresponding results from 'sliceOps'.
1471 SmallVector<Value, 4> newOperands(opInst->getOperands());
1472 for (Value &operand : newOperands) {
1473 // Replace the subOperands from among the new operands.
1474 unsigned j, f;
1475 for (j = 0, f = subOperands.size(); j < f; j++) {
1476 if (operand == subOperands[j])
1477 break;
1478 }
1479 if (j < subOperands.size())
1480 operand = (*sliceOps)[j];
1481 }
1482 for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
1483 opInst->setOperand(idx, value: newOperands[idx]);
1484}
1485
1486/// Enum to set patterns of affine expr in a tiled-layout map.
1487/// TileFloorDiv: <dim expr> floordiv <tile size>
1488/// TileMod: <dim expr> mod <tile size>
1489/// TileNone: None of the above
1490/// Example:
1491/// #tiled_2d_128x256 = affine_map<(d0, d1)
1492/// -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
1493/// "d0 floordiv 128" and "d1 floordiv 256" ==> TileFloorDiv
1494/// "d0 mod 128" and "d1 mod 256" ==> TileMod
1495enum TileExprPattern { TileFloorDiv, TileMod, TileNone };
1496
1497/// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
1498/// that are floordiv'ed by their respective tile sizes also appear in a mod
1499/// with the same tile sizes, and no other expression involves those k
1500/// dimensions. This function stores, in `tileSizePos`, a vector of tuples of
1501/// (tile size AffineExpr, position of the `floordiv`, position of the `mod`).
1502/// If `map` is not a tiled layout, `tileSizePos` is left empty.
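/// As an illustrative sketch, for the example tiled map above
/// (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256), `tileSizePos`
/// would hold the tuples (128, /*floordiv pos=*/0, /*mod pos=*/2) and
/// (256, /*floordiv pos=*/1, /*mod pos=*/3).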
1503static LogicalResult getTileSizePos(
1504 AffineMap map,
1505 SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
1506 // Create `floordivExprs` which is a vector of tuples including LHS and RHS of
1507 // `floordiv` and its position in `map` output.
1508 // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
1509 // -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
1510 // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
1511 SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
1512 unsigned pos = 0;
1513 for (AffineExpr expr : map.getResults()) {
1514 if (expr.getKind() == AffineExprKind::FloorDiv) {
1515 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
1516 if (isa<AffineConstantExpr>(Val: binaryExpr.getRHS()))
1517 floordivExprs.emplace_back(
1518 Args: std::make_tuple(args: binaryExpr.getLHS(), args: binaryExpr.getRHS(), args&: pos));
1519 }
1520 pos++;
1521 }
1522 // Not tiled layout if `floordivExprs` is empty.
1523 if (floordivExprs.empty()) {
1524 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1525 return success();
1526 }
1527
1528 // Check if the LHS of each `floordiv` is used in the LHS of a `mod`. If not,
1529 // `map` is not a tiled layout.
1530 for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
1531 AffineExpr floordivExprLHS = std::get<0>(t&: fexpr);
1532 AffineExpr floordivExprRHS = std::get<1>(t&: fexpr);
1533 unsigned floordivPos = std::get<2>(t&: fexpr);
1534
1535 // Walk each affine expr of the `map` output except `fexpr`, and check if the
1536 // LHS and RHS of `fexpr` are used as the LHS and RHS of a `mod`. If the LHS of
1537 // `fexpr` is used in any other expr, the map is not a tiled layout. Examples:
1538 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
1539 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
1540 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod
1541 // 256)>
1542 bool found = false;
1543 pos = 0;
1544 for (AffineExpr expr : map.getResults()) {
1545 bool notTiled = false;
1546 if (pos != floordivPos) {
1547 expr.walk(callback: [&](AffineExpr e) {
1548 if (e == floordivExprLHS) {
1549 if (expr.getKind() == AffineExprKind::Mod) {
1550 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
1551 // If the LHS and RHS of this `mod` are the same as those of the floordiv.
1552 if (floordivExprLHS == binaryExpr.getLHS() &&
1553 floordivExprRHS == binaryExpr.getRHS()) {
1554 // Save the tile size (RHS of `mod`) and the positions of `floordiv` and
1555 // `mod`, if a matching `mod` has not been found yet.
1556 if (!found) {
1557 tileSizePos.emplace_back(
1558 Args: std::make_tuple(args: binaryExpr.getRHS(), args&: floordivPos, args&: pos));
1559 found = true;
1560 } else {
1561 // Non-tiled layout: multiple `mod`s with the same LHS.
1562 // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1563 // mod 256, d2 mod 256)>
1564 notTiled = true;
1565 }
1566 } else {
1567 // Non-tiled layout: the RHS of `mod` differs from that of the `floordiv`.
1568 // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1569 // mod 128)>
1570 notTiled = true;
1571 }
1572 } else {
1573 // Non-tiled layout: the LHS is the same, but the expr is not a `mod`.
1574 // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1575 // floordiv 256)>
1576 notTiled = true;
1577 }
1578 }
1579 });
1580 }
1581 if (notTiled) {
1582 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1583 return success();
1584 }
1585 pos++;
1586 }
1587 }
1588 return success();
1589}
1590
1591/// Check if dimension `dim` of a memrefType with `layoutMap` becomes dynamic
1592/// after normalization. A map result that involves a dynamic input dimension
1593/// becomes a dynamic dimension in the normalized type. Returns true if `dim`
1594/// is such a dynamic dimension.
1595///
1596/// Example:
1597/// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1598///
1599/// If d1 is a dynamic dimension, the 2nd and 3rd dims of the map output are dynamic.
1600/// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
1601static bool
1602isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
1603 SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
1604 AffineExpr expr = layoutMap.getResults()[dim];
1605 // Check if affine expr of the dimension includes dynamic dimension of input
1606 // memrefType.
1607 MLIRContext *context = layoutMap.getContext();
1608 return expr
1609 .walk(callback: [&](AffineExpr e) {
1610 if (isa<AffineDimExpr>(Val: e) &&
1611 llvm::any_of(Range&: inMemrefTypeDynDims, P: [&](unsigned dim) {
1612 return e == getAffineDimExpr(position: dim, context);
1613 }))
1614 return WalkResult::interrupt();
1615 return WalkResult::advance();
1616 })
1617 .wasInterrupted();
1618}
1619
1620/// Create affine expr to calculate dimension size for a tiled-layout map.
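/// For illustration, `d1 floordiv 32` (TileFloorDiv) becomes `d1 ceildiv 32`,
/// and `d1 mod 32` (TileMod) becomes `32`.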
1621static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
1622 TileExprPattern pat) {
1623 // Create map output for the patterns.
1624 // "floordiv <tile size>" ==> "ceildiv <tile size>"
1625 // "mod <tile size>" ==> "<tile size>"
1626 AffineExpr newMapOutput;
1627 AffineBinaryOpExpr binaryExpr = nullptr;
1628 switch (pat) {
1629 case TileExprPattern::TileMod:
1630 binaryExpr = cast<AffineBinaryOpExpr>(Val&: oldMapOutput);
1631 newMapOutput = binaryExpr.getRHS();
1632 break;
1633 case TileExprPattern::TileFloorDiv:
1634 binaryExpr = cast<AffineBinaryOpExpr>(Val&: oldMapOutput);
1635 newMapOutput = getAffineBinaryOpExpr(
1636 kind: AffineExprKind::CeilDiv, lhs: binaryExpr.getLHS(), rhs: binaryExpr.getRHS());
1637 break;
1638 default:
1639 newMapOutput = oldMapOutput;
1640 }
1641 return newMapOutput;
1642}
1643
1644/// Create new maps to calculate each dimension size of `newMemRefType`, and
1645/// create `newDynamicSizes` from them by using AffineApplyOp.
1646///
1647/// Steps for normalizing dynamic memrefs for a tiled layout map
1648/// Example:
1649/// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1650/// %0 = dim %arg0, %c1 :memref<4x?xf32>
1651/// %1 = alloc(%0) : memref<4x?xf32, #map0>
1652///
1653/// (Before this function)
1654/// 1. Check if `map` (#map0) is a tiled layout using `getTileSizePos()`. Only a
1655/// single layout map is supported.
1656///
1657/// 2. Create normalized memrefType using `isNormalizedMemRefDynamicDim()`. It
1658/// is memref<4x?x?xf32> in the above example.
1659///
1660/// (In this function)
1661/// 3. Create new maps to calculate each dimension of the normalized memrefType
1662/// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
1663/// dimension size can be calculated by replacing "floordiv <tile size>" with
1664/// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
1665/// - New map in the above example
1666/// #map0 = affine_map<(d0, d1) -> (d0)>
1667/// #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
1668/// #map2 = affine_map<(d0, d1) -> (32)>
1669///
1670/// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp
1671/// is used in dynamicSizes of new AllocOp.
1672/// %0 = dim %arg0, %c1 : memref<4x?xf32>
1673/// %c4 = arith.constant 4 : index
1674/// %1 = affine.apply #map1(%c4, %0)
1675/// %2 = affine.apply #map2(%c4, %0)
1676template <typename AllocLikeOp>
1677static void createNewDynamicSizes(MemRefType oldMemRefType,
1678 MemRefType newMemRefType, AffineMap map,
1679 AllocLikeOp allocOp, OpBuilder b,
1680 SmallVectorImpl<Value> &newDynamicSizes) {
1681 // Create new input for AffineApplyOp.
1682 SmallVector<Value, 4> inAffineApply;
1683 ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
1684 unsigned dynIdx = 0;
1685 for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
1686 if (oldMemRefShape[d] < 0) {
1687 // Use dynamicSizes of allocOp for dynamic dimension.
1688 inAffineApply.emplace_back(allocOp.getDynamicSizes()[dynIdx]);
1689 dynIdx++;
1690 } else {
1691 // Create ConstantOp for static dimension.
1692 auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
1693 inAffineApply.emplace_back(
1694 b.create<arith::ConstantOp>(allocOp.getLoc(), constantAttr));
1695 }
1696 }
1697
1698 // Create a new map to calculate each dimension size of the new memref for
1699 // each original map output, only for dynamic dimensions of `newMemRefType`.
1700 unsigned newDimIdx = 0;
1701 ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
1702 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1703 (void)getTileSizePos(map, tileSizePos);
1704 for (AffineExpr expr : map.getResults()) {
1705 if (newMemRefShape[newDimIdx] < 0) {
1706 // Create new maps to calculate each dimension size of new memref.
1707 enum TileExprPattern pat = TileExprPattern::TileNone;
1708 for (auto pos : tileSizePos) {
1709 if (newDimIdx == std::get<1>(t&: pos))
1710 pat = TileExprPattern::TileFloorDiv;
1711 else if (newDimIdx == std::get<2>(t&: pos))
1712 pat = TileExprPattern::TileMod;
1713 }
1714 AffineExpr newMapOutput = createDimSizeExprForTiledLayout(oldMapOutput: expr, pat);
1715 AffineMap newMap =
1716 AffineMap::get(dimCount: map.getNumInputs(), symbolCount: map.getNumSymbols(), result: newMapOutput);
1717 Value affineApp =
1718 b.create<AffineApplyOp>(allocOp.getLoc(), newMap, inAffineApply);
1719 newDynamicSizes.emplace_back(Args&: affineApp);
1720 }
1721 newDimIdx++;
1722 }
1723}
1724
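/// Normalizes the memref type of `allocOp` to an identity layout, creates a
/// replacement alloc-like op, and rewrites all uses of the old memref.
/// Illustrative sketch, assuming the hypothetical tiled layout #tile below:
///
///   #tile = affine_map<(d0, d1) -> (d0 floordiv 32, d1, d0 mod 32)>
///   %0 = memref.alloc() : memref<64x128xf32, #tile>
///
/// is rewritten to
///
///   %0 = memref.alloc() : memref<2x128x32xf32>
///
/// with all remaining loads/stores remapped through the old layout map.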
1725template <typename AllocLikeOp>
1726LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
1727 MemRefType memrefType = allocOp.getType();
1728 OpBuilder b(allocOp);
1729
1730 // Fetch a new memref type after normalizing the old memref to have an
1731 // identity map layout.
1732 MemRefType newMemRefType = normalizeMemRefType(memrefType);
1733 if (newMemRefType == memrefType)
1734 // Either memrefType already had an identity map or the map couldn't be
1735 // transformed to an identity map.
1736 return failure();
1737
1738 Value oldMemRef = allocOp.getResult();
1739
1740 SmallVector<Value, 4> symbolOperands(allocOp.getSymbolOperands());
1741 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1742 AllocLikeOp newAlloc;
1743 // Check if `layoutMap` is a tiled layout. Only a single layout map is
1744 // supported for normalizing dynamic memrefs.
1745 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1746 (void)getTileSizePos(map: layoutMap, tileSizePos);
1747 if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
1748 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
1749 SmallVector<Value, 4> newDynamicSizes;
1750 createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
1751 newDynamicSizes);
1752 // Add the new dynamic sizes in new AllocOp.
1753 newAlloc =
1754 b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType, newDynamicSizes,
1755 allocOp.getAlignmentAttr());
1756 } else {
1757 newAlloc = b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType,
1758 allocOp.getAlignmentAttr());
1759 }
1760 // Replace all uses of the old memref.
1761 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
1762 /*extraIndices=*/{},
1763 /*indexRemap=*/layoutMap,
1764 /*extraOperands=*/{},
1765 /*symbolOperands=*/symbolOperands,
1766 /*domOpFilter=*/nullptr,
1767 /*postDomOpFilter=*/nullptr,
1768 /*allowNonDereferencingOps=*/true))) {
1769 // If it failed (due to escapes for example), bail out.
1770 newAlloc.erase();
1771 return failure();
1772 }
1773 // Replace any remaining uses of the old memref and erase the old alloc. All
1774 // remaining uses must be deallocs; replaceAllMemRefUsesWith would've failed otherwise.
1775 assert(llvm::all_of(oldMemRef.getUsers(), [&](Operation *op) {
1776 return hasSingleEffect<MemoryEffects::Free>(op, oldMemRef);
1777 }));
1778 oldMemRef.replaceAllUsesWith(newValue: newAlloc);
1779 allocOp.erase();
1780 return success();
1781}
1782
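/// Normalizes the result type of `reinterpretCastOp` to an identity layout and
/// recomputes its sizes. Illustrative sketch, assuming the layout #map0 from
/// the examples above and a dynamic size %sz:
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
///   memref.reinterpret_cast ... : ... to memref<4x?xf32, #map0>
///
/// becomes a reinterpret_cast to memref<4x?x?xf32> whose dynamic sizes are
/// computed as ((%sz - 1) floordiv 32) + 1 and ((%sz - 1) mod 32) + 1.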
1783LogicalResult
1784mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
1785 MemRefType memrefType = reinterpretCastOp.getType();
1786 AffineMap oldLayoutMap = memrefType.getLayout().getAffineMap();
1787 Value oldMemRef = reinterpretCastOp.getResult();
1788
1789 // If `oldLayoutMap` is identity, `memrefType` is already normalized.
1790 if (oldLayoutMap.isIdentity())
1791 return success();
1792
1793 // Fetch a new memref type after normalizing the old memref to have an
1794 // identity map layout.
1795 MemRefType newMemRefType = normalizeMemRefType(memrefType);
1796 if (newMemRefType == memrefType)
1797 // `oldLayoutMap` couldn't be transformed to an identity map.
1798 return failure();
1799
1800 uint64_t newRank = newMemRefType.getRank();
1801 SmallVector<Value> mapOperands(oldLayoutMap.getNumDims() +
1802 oldLayoutMap.getNumSymbols());
1803 SmallVector<Value> oldStrides = reinterpretCastOp.getStrides();
1804 Location loc = reinterpretCastOp.getLoc();
1805 // As `newMemRefType` is normalized, it is unit strided.
1806 SmallVector<int64_t> newStaticStrides(newRank, 1);
1807 SmallVector<int64_t> newStaticOffsets(newRank, 0);
1808 ArrayRef<int64_t> oldShape = memrefType.getShape();
1809 ValueRange oldSizes = reinterpretCastOp.getSizes();
1810 unsigned idx = 0;
1811 OpBuilder b(reinterpretCastOp);
1812 // Collect the map operands which will be used to compute the new normalized
1813 // memref shape.
1814 for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) {
1815 if (memrefType.isDynamicDim(i))
1816 mapOperands[i] =
1817 b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++],
1818 b.create<arith::ConstantIndexOp>(loc, 1));
1819 else
1820 mapOperands[i] = b.create<arith::ConstantIndexOp>(location: loc, args: oldShape[i] - 1);
1821 }
1822 for (unsigned i = 0, e = oldStrides.size(); i < e; i++)
1823 mapOperands[memrefType.getRank() + i] = oldStrides[i];
1824 SmallVector<Value> newSizes;
1825 ArrayRef<int64_t> newShape = newMemRefType.getShape();
1826 // Compute size along all the dimensions of the new normalized memref.
1827 for (unsigned i = 0; i < newRank; i++) {
1828 if (!newMemRefType.isDynamicDim(i))
1829 continue;
1830 newSizes.push_back(b.create<AffineApplyOp>(
1831 loc,
1832 AffineMap::get(dimCount: oldLayoutMap.getNumDims(), symbolCount: oldLayoutMap.getNumSymbols(),
1833 result: oldLayoutMap.getResult(idx: i)),
1834 mapOperands));
1835 }
1836 for (unsigned i = 0, e = newSizes.size(); i < e; i++) {
1837 newSizes[i] =
1838 b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i],
1839 b.create<arith::ConstantIndexOp>(loc, 1));
1840 }
1841 // Create the new reinterpret_cast op.
1842 auto newReinterpretCast = b.create<memref::ReinterpretCastOp>(
1843 loc, newMemRefType, reinterpretCastOp.getSource(),
1844 /*offsets=*/ValueRange(), newSizes,
1845 /*strides=*/ValueRange(),
1846 /*static_offsets=*/newStaticOffsets,
1847 /*static_sizes=*/newShape,
1848 /*static_strides=*/newStaticStrides);
1849
1850 // Replace all uses of the old memref.
1851 if (failed(replaceAllMemRefUsesWith(oldMemRef,
1852 /*newMemRef=*/newReinterpretCast,
1853 /*extraIndices=*/{},
1854 /*indexRemap=*/oldLayoutMap,
1855 /*extraOperands=*/{},
1856 /*symbolOperands=*/oldStrides,
1857 /*domOpFilter=*/nullptr,
1858 /*postDomOpFilter=*/nullptr,
1859 /*allowNonDereferencingOps=*/true))) {
1860 // If it failed (due to escapes for example), bail out.
1861 newReinterpretCast.erase();
1862 return failure();
1863 }
1864
1865 oldMemRef.replaceAllUsesWith(newValue: newReinterpretCast);
1866 reinterpretCastOp.erase();
1867 return success();
1868}
1869
1870template LogicalResult
1871mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp op);
1872template LogicalResult
1873mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp op);
1874
1875MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
1876 unsigned rank = memrefType.getRank();
1877 if (rank == 0)
1878 return memrefType;
1879
1880 if (memrefType.getLayout().isIdentity()) {
1881 // Either no map is associated with this memref or the memref has a
1882 // trivial (identity) layout map.
1883 return memrefType;
1884 }
1885 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1886 unsigned numSymbolicOperands = layoutMap.getNumSymbols();
1887
1888 // We don't do any checks for one-to-one'ness; we assume that it is
1889 // one-to-one.
1890
1891 // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
1892 // for now.
1893 // TODO: Normalize the other types of dynamic memrefs.
1894 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1895 (void)getTileSizePos(map: layoutMap, tileSizePos);
1896 if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
1897 return memrefType;
1898
1899 // We have a single map that is not an identity map. Create a new memref
1900 // with the right shape and an identity layout map.
1901 ArrayRef<int64_t> shape = memrefType.getShape();
1902 // FlatAffineValueConstraints may later on use the symbolic operands.
1903 FlatAffineValueConstraints fac(rank, numSymbolicOperands);
1904 SmallVector<unsigned, 4> memrefTypeDynDims;
1905 for (unsigned d = 0; d < rank; ++d) {
1906 // Use the constraint system only for static dimensions.
1907 if (shape[d] > 0) {
1908 fac.addBound(type: BoundType::LB, pos: d, value: 0);
1909 fac.addBound(type: BoundType::UB, pos: d, value: shape[d] - 1);
1910 } else {
1911 memrefTypeDynDims.emplace_back(Args&: d);
1912 }
1913 }
1914 // We compose this map with the original index (logical) space to derive
1915 // the upper bounds for the new index space.
1916 unsigned newRank = layoutMap.getNumResults();
1917 if (failed(Result: fac.composeMatchingMap(other: layoutMap)))
1918 return memrefType;
1919 // TODO: Handle semi-affine maps.
1920 // Project out the old data dimensions.
1921 fac.projectOut(pos: newRank, num: fac.getNumVars() - newRank - fac.getNumLocalVars());
1922 SmallVector<int64_t, 4> newShape(newRank);
1923 MLIRContext *context = memrefType.getContext();
1924 for (unsigned d = 0; d < newRank; ++d) {
1925 // Check if this dimension is dynamic.
1926 if (isNormalizedMemRefDynamicDim(dim: d, layoutMap, inMemrefTypeDynDims&: memrefTypeDynDims)) {
1927 newShape[d] = ShapedType::kDynamic;
1928 continue;
1929 }
1930 // The lower bound for the shape is always zero.
1931 std::optional<int64_t> ubConst = fac.getConstantBound64(type: BoundType::UB, pos: d);
1932 // For a static memref and an affine map with no symbols, this is
1933 // always bounded. However, when we have symbols, we may not be able to
1934 // obtain a constant upper bound. Also, mapping to a negative space is
1935 // invalid for normalization.
1936 if (!ubConst.has_value() || *ubConst < 0) {
1937 LLVM_DEBUG(llvm::dbgs()
1938 << "can't normalize map due to unknown/invalid upper bound");
1939 return memrefType;
1940 }
1941 // Otherwise, the dimension is static: its size is the constant upper bound plus one.
1942 newShape[d] = *ubConst + 1;
1943 }
1944
1945 // Create the new memref type after trivializing the old layout map.
1946 auto newMemRefType =
1947 MemRefType::Builder(memrefType)
1948 .setShape(newShape)
1949 .setLayout(AffineMapAttr::get(
1950 AffineMap::getMultiDimIdentityMap(newRank, context)));
1951 return newMemRefType;
1952}
1953
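/// Illustrative sketch of what getDivMod builds for hypothetical index values
/// %v and %d (before any composition or folding of the operands):
///
///   %q = affine.apply affine_map<(d0, d1) -> (d0 floordiv d1)>(%v, %d)
///   %r = affine.apply affine_map<(d0, d1) -> (d0 mod d1)>(%v, %d)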
1954DivModValue mlir::affine::getDivMod(OpBuilder &b, Location loc, Value lhs,
1955 Value rhs) {
1956 DivModValue result;
1957 AffineExpr d0, d1;
1958 bindDims(ctx: b.getContext(), exprs&: d0, exprs&: d1);
1959 result.quotient =
1960 affine::makeComposedAffineApply(b, loc, d0.floorDiv(other: d1), {lhs, rhs});
1961 result.remainder =
1962 affine::makeComposedAffineApply(b, loc, d0 % d1, {lhs, rhs});
1963 return result;
1964}
1965
1966/// Create an affine map that computes `lhs` * `rhs`, composing in any other
1967/// affine maps.
1968static FailureOr<OpFoldResult> composedAffineMultiply(OpBuilder &b,
1969 Location loc,
1970 OpFoldResult lhs,
1971 OpFoldResult rhs) {
1972 AffineExpr s0, s1;
1973 bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1);
1974 return makeComposedFoldedAffineApply(b, loc, expr: s0 * s1, operands: {lhs, rhs});
1975}
1976
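/// Illustrative sketch: delinearizing a linear index %l over a basis
/// (%b0, %b1, %b2), with the outer bound %b0 provided, yields indices
/// equivalent to
///
///   (%l floordiv (%b1 * %b2), (%l mod (%b1 * %b2)) floordiv %b2, %l mod %b2)
///
/// materialized through the div/mod affine.apply ops created by getDivMod.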
1977FailureOr<SmallVector<Value>>
1978mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
1979 ArrayRef<Value> basis, bool hasOuterBound) {
1980 if (hasOuterBound)
1981 basis = basis.drop_front();
1982
1983 // Note: the divisors are backwards due to the scan.
1984 SmallVector<Value> divisors;
1985 OpFoldResult basisProd = b.getIndexAttr(1);
1986 for (OpFoldResult basisElem : llvm::reverse(C&: basis)) {
1987 FailureOr<OpFoldResult> nextProd =
1988 composedAffineMultiply(b, loc, lhs: basisElem, rhs: basisProd);
1989 if (failed(Result: nextProd))
1990 return failure();
1991 basisProd = *nextProd;
1992 divisors.push_back(Elt: getValueOrCreateConstantIndexOp(b, loc, ofr: basisProd));
1993 }
1994
1995 SmallVector<Value> results;
1996 results.reserve(N: divisors.size() + 1);
1997 Value residual = linearIndex;
1998 for (Value divisor : llvm::reverse(C&: divisors)) {
1999 DivModValue divMod = getDivMod(b, loc, lhs: residual, rhs: divisor);
2000 results.push_back(Elt: divMod.quotient);
2001 residual = divMod.remainder;
2002 }
2003 results.push_back(Elt: residual);
2004 return results;
2005}
2006
2007FailureOr<SmallVector<Value>>
2008mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
2009 ArrayRef<OpFoldResult> basis,
2010 bool hasOuterBound) {
2011 if (hasOuterBound)
2012 basis = basis.drop_front();
2013
2014 // Note: the divisors are backwards due to the scan.
2015 SmallVector<Value> divisors;
2016 OpFoldResult basisProd = b.getIndexAttr(1);
2017 for (OpFoldResult basisElem : llvm::reverse(C&: basis)) {
2018 FailureOr<OpFoldResult> nextProd =
2019 composedAffineMultiply(b, loc, lhs: basisElem, rhs: basisProd);
2020 if (failed(Result: nextProd))
2021 return failure();
2022 basisProd = *nextProd;
2023 divisors.push_back(Elt: getValueOrCreateConstantIndexOp(b, loc, ofr: basisProd));
2024 }
2025
2026 SmallVector<Value> results;
2027 results.reserve(N: divisors.size() + 1);
2028 Value residual = linearIndex;
2029 for (Value divisor : llvm::reverse(C&: divisors)) {
2030 DivModValue divMod = getDivMod(b, loc, lhs: residual, rhs: divisor);
2031 results.push_back(Elt: divMod.quotient);
2032 residual = divMod.remainder;
2033 }
2034 results.push_back(Elt: residual);
2035 return results;
2036}
2037
2038OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
2039 ArrayRef<OpFoldResult> basis,
2040 ImplicitLocOpBuilder &builder) {
2041 return linearizeIndex(builder, loc: builder.getLoc(), multiIndex, basis);
2042}
2043
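/// Illustrative sketch: linearizing a multi-index (%i, %j, %k) over a basis
/// (%b0, %b1, %b2) produces a value equivalent to
///
///   %i * %b1 * %b2 + %j * %b2 + %k
///
/// folded into composed affine.apply ops where possible.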
2044OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
2045 ArrayRef<OpFoldResult> multiIndex,
2046 ArrayRef<OpFoldResult> basis) {
2047 assert(multiIndex.size() == basis.size() ||
2048 multiIndex.size() == basis.size() + 1);
2049 SmallVector<AffineExpr> basisAffine;
2050
2051 // Add a fake initial size in order to make the later index linearization
2052 // computations line up if an outer bound is not provided.
2053 if (multiIndex.size() == basis.size() + 1)
2054 basisAffine.push_back(Elt: getAffineConstantExpr(constant: 1, context: builder.getContext()));
2055
2056 for (size_t i = 0; i < basis.size(); ++i) {
2057 basisAffine.push_back(Elt: getAffineSymbolExpr(position: i, context: builder.getContext()));
2058 }
2059
2060 SmallVector<AffineExpr> stridesAffine = computeStrides(sizes: basisAffine);
2061 SmallVector<OpFoldResult> strides;
2062 strides.reserve(N: stridesAffine.size());
2063 llvm::transform(Range&: stridesAffine, d_first: std::back_inserter(x&: strides),
2064 F: [&builder, &basis, loc](AffineExpr strideExpr) {
2065 return affine::makeComposedFoldedAffineApply(
2066 b&: builder, loc, expr: strideExpr, operands: basis);
2067 });
2068
2069 auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
2070 sourceOffset: OpFoldResult(builder.getIndexAttr(0)), strides, indices: multiIndex);
2071 return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
2072 multiIndexAndStrides);
2073}
2074
