//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for flattening multi-rank memref-related ops
// into 1-D memref ops.
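//
// For example (illustrative IR only; the shapes, SSA names, and the exact ops
// used to materialize the linearized index are hypothetical), a 2-D load such
// as
//
//   %v = memref.load %mem[%i, %j] : memref<4x8xf32>
//
// is rewritten to operate on a rank-1 view of the same buffer with a
// linearized index:
//
//   %flat = memref.reinterpret_cast %mem to offset: [0], sizes: [32],
//           strides: [1] : memref<4x8xf32> to memref<32xf32>
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
//   %v = memref.load %flat[%idx] : memref<32xf32>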
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

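/// Materializes the given OpFoldResult as a Value, creating an arith constant
/// index op when it holds an attribute.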
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                      OpFoldResult in) {
  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
    return rewriter.create<arith::ConstantIndexOp>(
        loc, cast<IntegerAttr>(offsetAttr).getInt());
  }
  return cast<Value>(in);
}

/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
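///
/// For example (illustrative), for a `source` of type `memref<4x8xf32>`
/// accessed at indices `(%i, %j)`, this returns a reinterpret_cast of
/// `source` to a rank-1 `memref<32xf32>` view together with the linearized
/// index `%i * 8 + %j`.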
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                         Location loc,
                                                         Value source,
                                                         ValueRange indices) {
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto sourceType = cast<MemRefType>(source.getType());
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
    assert(false && "expected source memref to have a strided layout");
  }

  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);

  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
  OpFoldResult linearizedIndices;
  memref::LinearizedMemRefInfo linearizedInfo;
  std::tie(linearizedInfo, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          rewriter, loc, typeBit, typeBit,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(),
          getAsOpFoldResult(indices));

  return std::make_pair(
      rewriter.create<memref::ReinterpretCastOp>(
          loc, source,
          /* offset = */ linearizedInfo.linearizedOffset,
          /* shapes = */
          ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
          /* strides = */
          ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
}

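/// Returns true if the memref `val` has rank > 1 and therefore needs to be
/// flattened.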
static bool needFlattening(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() > 1;
}

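/// Returns true if the memref `val` has a layout supported by flattening,
/// i.e. an identity layout or a strided layout.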
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
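/// Returns the memref operand that `op` reads from, writes to, or allocates,
/// or a null Value if the op is not handled.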
static Value getTargetMemref(Operation *op) {
  return llvm::TypeSwitch<Operation *, Value>(op)
      .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
                     memref::AllocOp>([](auto op) { return op.getMemref(); })
      .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
                     vector::MaskedStoreOp, vector::TransferReadOp,
                     vector::TransferWriteOp>(
          [](auto op) { return op.getBase(); })
      .Default([](auto) { return Value{}; });
}

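/// Replaces the alloc-like op `oper` with a reinterpret_cast of the flattened
/// allocation `newOper` back to the original multi-rank type, so existing
/// users of `oper` keep seeing the original type.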
template <typename T>
static void castAllocResult(T oper, T newOper, Location loc,
                            PatternRewriter &rewriter) {
  memref::ExtractStridedMetadataOp stridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      oper, cast<MemRefType>(oper.getType()), newOper,
      /*offset=*/rewriter.getIndexAttr(0),
      stridedMetadata.getConstifiedMixedSizes(),
      stridedMetadata.getConstifiedMixedStrides());
}

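/// Rewrites `op` to an equivalent op operating on the flattened 1-D memref
/// `flatMemref` with the linearized index `offset`. Alloc-like results are
/// cast back to their original type via castAllocResult.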
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      .template Case<memref::AllocOp>([&](auto oper) {
        auto newAlloc = rewriter.create<memref::AllocOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloc, loc, rewriter);
      })
      .template Case<memref::AllocaOp>([&](auto oper) {
        auto newAlloca = rewriter.create<memref::AllocaOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloca, loc, rewriter);
      })
      .template Case<memref::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<memref::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<memref::StoreOp>([&](auto op) {
        auto newStore = rewriter.create<memref::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<vector::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<vector::StoreOp>([&](auto op) {
        auto newStore = rewriter.create<vector::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::MaskedLoadOp>([&](auto op) {
        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
            op.getPassThru());
        newMaskedLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedLoad.getResult());
      })
      .template Case<vector::MaskedStoreOp>([&](auto op) {
        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
            loc, flatMemref, ValueRange{offset}, op.getMask(),
            op.getValueToStore());
        newMaskedStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedStore);
      })
      .template Case<vector::TransferReadOp>([&](auto op) {
        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset},
            op.getPadding());
        rewriter.replaceOp(op, newTransferRead.getResult());
      })
      .template Case<vector::TransferWriteOp>([&](auto op) {
        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
            loc, op.getVector(), flatMemref, ValueRange{offset});
        rewriter.replaceOp(op, newTransferWrite);
      })
      .Default([&](auto op) {
        op->emitOpError("unimplemented: do not know how to replace op.");
      });
}

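/// Returns the access indices of `op`; alloc-like ops have no indices.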
template <typename T>
static ValueRange getIndices(T op) {
  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
                std::is_same_v<T, memref::AllocOp>) {
    return ValueRange{};
  } else {
    return op.getIndices();
  }
}

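/// Returns success if `op` is supported by the flattening patterns. Only
/// vector.transfer_read/write impose additional constraints; all other
/// handled ops are accepted unconditionally.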
template <typename T>
static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
          [&](auto oper) {
            // For vector.transfer_read/write, we must make sure that:
            // 1. all accesses are in-bounds, and
            // 2. the permutation map is an identity or minor identity map.
            auto permutationMap = oper.getPermutationMap();
            if (!permutationMap.isIdentity() &&
                !permutationMap.isMinorIdentity()) {
              return rewriter.notifyMatchFailure(
                  oper,
                  "only identity or minor identity permutation maps are "
                  "supported");
            }
            mlir::ArrayAttr inbounds = oper.getInBounds();
            if (llvm::any_of(inbounds, [](Attribute attr) {
                  return !cast<BoolAttr>(attr).getValue();
                })) {
              return rewriter.notifyMatchFailure(
                  oper, "only in-bounds accesses are supported");
            }
            return success();
          })
      .Default([&](auto op) { return success(); });
}

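/// Rewrites a multi-rank memref access (or allocation) of type `T` so that it
/// operates on a flattened, rank-1 view of the memref with a linearized index.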
template <typename T>
struct MemRefRewritePattern : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    LogicalResult canFlatten = canBeFlattened(op, rewriter);
    if (failed(canFlatten)) {
      return canFlatten;
    }

    Value memref = getTargetMemref(op);
    if (!needFlattening(memref) || !checkLayout(memref))
      return failure();
    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
        rewriter, op->getLoc(), memref, getIndices<T>(op));
    replaceOp<T>(op, rewriter, flatMemref, offset);
    return success();
  }
};

struct FlattenMemrefsPass
    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    memref::MemRefDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    memref::populateFlattenMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                  MemRefRewritePattern<memref::StoreOp>,
                  MemRefRewritePattern<memref::AllocOp>,
                  MemRefRewritePattern<memref::AllocaOp>,
                  MemRefRewritePattern<vector::LoadOp>,
                  MemRefRewritePattern<vector::StoreOp>,
                  MemRefRewritePattern<vector::TransferReadOp>,
                  MemRefRewritePattern<vector::TransferWriteOp>,
                  MemRefRewritePattern<vector::MaskedLoadOp>,
                  MemRefRewritePattern<vector::MaskedStoreOp>>(
      patterns.getContext());
}