//===- IndependenceTransforms.cpp - Make ops independent of values --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::memref;

/// Make the given OpFoldResult independent of all independencies.
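/// If `ofr` is an attribute, it is already independent and returned as is.
/// Otherwise, an upper bound of `ofr` that does not depend on any of the
/// `independencies` is computed and materialized as new IR. Illustrative
/// sketch (assuming such a bound exists): a dynamic size `%sz` that is
/// bounded by `4 * %n`, where `%n` is not an independency, may be replaced
/// by `affine.apply affine_map<()[s0] -> (s0 * 4)>()[%n]`.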
static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                               OpFoldResult ofr,
                                               ValueRange independencies) {
  if (ofr.is<Attribute>())
    return ofr;
  AffineMap boundMap;
  ValueDimList mapOperands;
  if (failed(ValueBoundsConstraintSet::computeIndependentBound(
          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
          /*closedUB=*/true)))
    return failure();
  return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}

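// Illustrative rewrite performed by buildIndependentOp (a sketch; the exact
// bound and result layout depend on the value bounds analysis). Assuming the
// size %sz has a constant upper bound of 128 w.r.t. the independencies:
//   %alloca = memref.alloca(%sz) : memref<?xf32>
// becomes:
//   %new = memref.alloca() : memref<128xf32>
//   %sv = memref.subview %new[0] [%sz] [1]
//       : memref<128xf32> to memref<?xf32, strided<[1]>>
// The subview keeps the original (dependent) sizes, so the returned value has
// the same logical shape as the original alloca.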
FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
                                            memref::AllocaOp allocaOp,
                                            ValueRange independencies) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(allocaOp);
  Location loc = allocaOp.getLoc();

  SmallVector<OpFoldResult> newSizes;
  for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
    auto ub = makeIndependent(b, loc, ofr, independencies);
    if (failed(ub))
      return failure();
    newSizes.push_back(*ub);
  }

  // Return existing memref::AllocaOp if nothing has changed.
  if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
    return allocaOp.getResult();

  // Create a new memref::AllocaOp.
  Value newAllocaOp =
      b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());

  // Create a memref::SubViewOp.
  SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
  return b
      .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(),
                         strides)
      .getResult();
}

/// Push down an UnrealizedConversionCastOp past a SubViewOp.
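/// Sketch of the rewrite (types elided):
///   %cast = builtin.unrealized_conversion_cast %new : ... to ...
///   %sv = memref.subview %cast ...
/// becomes:
///   %sv_new = memref.subview %new ...
///   %cast_new = builtin.unrealized_conversion_cast %sv_new : ... to ...
/// All uses of the original subview are replaced with the new cast, so the
/// cast now wraps the subview result instead of its source.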
static UnrealizedConversionCastOp
propagateSubViewOp(RewriterBase &rewriter,
                   UnrealizedConversionCastOp conversionOp, SubViewOp op) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
      op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
      op.getMixedSizes(), op.getMixedStrides()));
  Value newSubview = rewriter.create<SubViewOp>(
      op.getLoc(), newResultType, conversionOp.getOperand(0),
      op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
  auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
      op.getLoc(), op.getType(), newSubview);
  rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
  return newConversionOp;
}

/// Given an original op and a new, modified op with the same number of results,
/// whose memref return types may differ, replace all uses of the original op
/// with the new op and propagate the new memref types through the IR.
///
/// Example:
/// %from = memref.alloca(%sz) : memref<?xf32>
/// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
/// memref.store %cst, %from[%c0] : memref<?xf32>
///
/// In the above example, all uses of %from are replaced with %to. This can be
/// done directly for ops such as memref.store. For ops that have memref results
/// (e.g., memref.subview), the result type may depend on the operand type, so
/// we cannot just replace all uses. There is special handling for common memref
/// ops. For all other ops, unrealized_conversion_cast is inserted.
static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
                                          Operation *from, Operation *to) {
  assert(from->getNumResults() == to->getNumResults() &&
         "expected same number of results");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(to);

  // Wrap new results in unrealized_conversion_cast and replace all uses of the
  // original op.
  SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
  for (const auto &it :
       llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
    unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
        to->getLoc(), std::get<0>(it.value()).getType(),
        std::get<1>(it.value())));
    rewriter.replaceAllUsesWith(from->getResult(it.index()),
                                unrealizedConversions.back()->getResult(0));
  }

  // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
  // wrap results instead of operands in a cast.
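  // Note that `unrealizedConversions` may grow while this loop runs (e.g.,
  // when a subview is rewritten via propagateSubViewOp below), so iterate by
  // index rather than with a range-based loop.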
  for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
    UnrealizedConversionCastOp conversion = unrealizedConversions[i];
    assert(conversion->getNumOperands() == 1 &&
           conversion->getNumResults() == 1 &&
           "expected single operand and single result");
    SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers());
    for (Operation *user : users) {
      // Handle common memref dialect ops that produce new memrefs and must
      // be recreated with the new result type.
      if (auto subviewOp = dyn_cast<SubViewOp>(user)) {
        unrealizedConversions.push_back(
            propagateSubViewOp(rewriter, conversion, subviewOp));
        continue;
      }

      // TODO: Other memref ops such as memref.collapse_shape/expand_shape
      // should also be handled here.

      // Skip any ops that produce a MemRef result or have MemRef region block
      // arguments. These may need special handling (e.g., scf.for).
      if (llvm::any_of(user->getResultTypes(),
                       [](Type t) { return isa<MemRefType>(t); }))
        continue;
      if (llvm::any_of(user->getRegions(), [](Region &r) {
            return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
              return isa<MemRefType>(bbArg.getType());
            });
          }))
        continue;

      // For all other ops, we assume that we can directly replace the operand.
      // This may have to be revised in the future; e.g., there may be ops that
      // do not support non-identity layout maps.
      for (OpOperand &operand : user->getOpOperands()) {
        if ([[maybe_unused]] auto castOp =
                operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
          rewriter.modifyOpInPlace(
              user, [&]() { operand.set(conversion->getOperand(0)); });
        }
      }
    }
  }

  // Erase all unrealized_conversion_cast ops without uses.
  for (auto op : unrealizedConversions)
    if (op->getUses().empty())
      rewriter.eraseOp(op);
}

FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
                                                  memref::AllocaOp allocaOp,
                                                  ValueRange independencies) {
  auto replacement =
      memref::buildIndependentOp(rewriter, allocaOp, independencies);
  if (failed(replacement))
    return failure();
  replaceAndPropagateMemRefType(rewriter, allocaOp,
                                replacement->getDefiningOp());
  return replacement;
}

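// Illustrative rewrite performed by allocToAlloca (a sketch; assumes a
// matching dealloc exists later in the same block and that the filter, if
// provided, accepts the pair):
//   %buf = memref.alloc(%sz) : memref<?xf32>
//   ...
//   memref.dealloc %buf : memref<?xf32>
// becomes:
//   %buf = memref.alloca(%sz) : memref<?xf32>
// with the dealloc erased. If no matching dealloc is found, nullptr is
// returned and the IR is left unchanged.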
memref::AllocaOp memref::allocToAlloca(
    RewriterBase &rewriter, memref::AllocOp alloc,
    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
  memref::DeallocOp dealloc = nullptr;
  for (Operation &candidate :
       llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
    dealloc = dyn_cast<memref::DeallocOp>(candidate);
    if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
        (!filter || filter(alloc, dealloc))) {
      break;
    }
  }

  if (!dealloc)
    return nullptr;

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(alloc);
  auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
      alloc, alloc.getMemref().getType(), alloc.getOperands());
  rewriter.eraseOp(dealloc);
  return alloca;
}