1//===- IndependenceTransforms.cpp - Make ops independent of values --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Affine/Transforms/Transforms.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/Utils/StaticValueUtils.h"
15#include "mlir/Interfaces/ValueBoundsOpInterface.h"
16
17using namespace mlir;
18using namespace mlir::memref;
19
20/// Make the given OpFoldResult independent of all independencies.
21static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
22 OpFoldResult ofr,
23 ValueRange independencies) {
24 if (ofr.is<Attribute>())
25 return ofr;
26 AffineMap boundMap;
27 ValueDimList mapOperands;
28 if (failed(result: ValueBoundsConstraintSet::computeIndependentBound(
29 resultMap&: boundMap, mapOperands, type: presburger::BoundType::UB, var: ofr, independencies,
30 /*closedUB=*/true)))
31 return failure();
32 return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
33}
34
35FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
36 memref::AllocaOp allocaOp,
37 ValueRange independencies) {
38 OpBuilder::InsertionGuard g(b);
39 b.setInsertionPoint(allocaOp);
40 Location loc = allocaOp.getLoc();
41
42 SmallVector<OpFoldResult> newSizes;
43 for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
44 auto ub = makeIndependent(b, loc, ofr, independencies);
45 if (failed(ub))
46 return failure();
47 newSizes.push_back(*ub);
48 }
49
50 // Return existing memref::AllocaOp if nothing has changed.
51 if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
52 return allocaOp.getResult();
53
54 // Create a new memref::AllocaOp.
55 Value newAllocaOp =
56 b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());
57
58 // Create a memref::SubViewOp.
59 SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
60 SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
61 return b
62 .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(),
63 strides)
64 .getResult();
65}
66
67/// Push down an UnrealizedConversionCastOp past a SubViewOp.
68static UnrealizedConversionCastOp
69propagateSubViewOp(RewriterBase &rewriter,
70 UnrealizedConversionCastOp conversionOp, SubViewOp op) {
71 OpBuilder::InsertionGuard g(rewriter);
72 rewriter.setInsertionPoint(op);
73 auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
74 op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
75 op.getMixedSizes(), op.getMixedStrides()));
76 Value newSubview = rewriter.create<SubViewOp>(
77 op.getLoc(), newResultType, conversionOp.getOperand(0),
78 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
79 auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
80 op.getLoc(), op.getType(), newSubview);
81 rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
82 return newConversionOp;
83}
84
85/// Given an original op and a new, modified op with the same number of results,
86/// whose memref return types may differ, replace all uses of the original op
87/// with the new op and propagate the new memref types through the IR.
88///
89/// Example:
90/// %from = memref.alloca(%sz) : memref<?xf32>
91/// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
92/// memref.store %cst, %from[%c0] : memref<?xf32>
93///
94/// In the above example, all uses of %from are replaced with %to. This can be
95/// done directly for ops such as memref.store. For ops that have memref results
96/// (e.g., memref.subview), the result type may depend on the operand type, so
97/// we cannot just replace all uses. There is special handling for common memref
98/// ops. For all other ops, unrealized_conversion_cast is inserted.
99static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
100 Operation *from, Operation *to) {
101 assert(from->getNumResults() == to->getNumResults() &&
102 "expected same number of results");
103 OpBuilder::InsertionGuard g(rewriter);
104 rewriter.setInsertionPointAfter(to);
105
106 // Wrap new results in unrealized_conversion_cast and replace all uses of the
107 // original op.
108 SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
109 for (const auto &it :
110 llvm::enumerate(First: llvm::zip(t: from->getResults(), u: to->getResults()))) {
111 unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
112 to->getLoc(), std::get<0>(it.value()).getType(),
113 std::get<1>(it.value())));
114 rewriter.replaceAllUsesWith(from->getResult(idx: it.index()),
115 unrealizedConversions.back()->getResult(0));
116 }
117
118 // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
119 // wrap results instead of operands in a cast.
120 for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
121 UnrealizedConversionCastOp conversion = unrealizedConversions[i];
122 assert(conversion->getNumOperands() == 1 &&
123 conversion->getNumResults() == 1 &&
124 "expected single operand and single result");
125 SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers());
126 for (Operation *user : users) {
127 // Handle common memref dialect ops that produce new memrefs and must
128 // be recreated with the new result type.
129 if (auto subviewOp = dyn_cast<SubViewOp>(user)) {
130 unrealizedConversions.push_back(
131 propagateSubViewOp(rewriter, conversion, subviewOp));
132 continue;
133 }
134
135 // TODO: Other memref ops such as memref.collapse_shape/expand_shape
136 // should also be handled here.
137
138 // Skip any ops that produce MemRef result or have MemRef region block
139 // arguments. These may need special handling (e.g., scf.for).
140 if (llvm::any_of(user->getResultTypes(),
141 [](Type t) { return isa<MemRefType>(t); }))
142 continue;
143 if (llvm::any_of(user->getRegions(), [](Region &r) {
144 return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
145 return isa<MemRefType>(bbArg.getType());
146 });
147 }))
148 continue;
149
150 // For all other ops, we assume that we can directly replace the operand.
151 // This may have to be revised in the future; e.g., there may be ops that
152 // do not support non-identity layout maps.
153 for (OpOperand &operand : user->getOpOperands()) {
154 if ([[maybe_unused]] auto castOp =
155 operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
156 rewriter.modifyOpInPlace(
157 user, [&]() { operand.set(conversion->getOperand(0)); });
158 }
159 }
160 }
161 }
162
163 // Erase all unrealized_conversion_cast ops without uses.
164 for (auto op : unrealizedConversions)
165 if (op->getUses().empty())
166 rewriter.eraseOp(op);
167}
168
169FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
170 memref::AllocaOp allocaOp,
171 ValueRange independencies) {
172 auto replacement =
173 memref::buildIndependentOp(b&: rewriter, allocaOp: allocaOp, independencies);
174 if (failed(replacement))
175 return failure();
176 replaceAndPropagateMemRefType(rewriter, allocaOp,
177 replacement->getDefiningOp());
178 return replacement;
179}
180
181memref::AllocaOp memref::allocToAlloca(
182 RewriterBase &rewriter, memref::AllocOp alloc,
183 function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
184 memref::DeallocOp dealloc = nullptr;
185 for (Operation &candidate :
186 llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
187 dealloc = dyn_cast<memref::DeallocOp>(candidate);
188 if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
189 (!filter || filter(alloc, dealloc))) {
190 break;
191 }
192 }
193
194 if (!dealloc)
195 return nullptr;
196
197 OpBuilder::InsertionGuard guard(rewriter);
198 rewriter.setInsertionPoint(alloc);
199 auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
200 alloc, alloc.getMemref().getType(), alloc.getOperands());
201 rewriter.eraseOp(op: dealloc);
202 return alloca;
203}
204

source code of mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp