1//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains patterns for flattening multi-rank memref-related
10// ops into 1-d memref ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/MemRef/Transforms/Passes.h"
18#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
19#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
20#include "mlir/Dialect/Utils/IndexingUtils.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/OpDefinition.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#include "llvm/ADT/TypeSwitch.h"
30
31namespace mlir {
32namespace memref {
33#define GEN_PASS_DEF_FLATTENMEMREFSPASS
34#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
35} // namespace memref
36} // namespace mlir
37
38using namespace mlir;
39
40static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
41 OpFoldResult in) {
42 if (Attribute offsetAttr = dyn_cast<Attribute>(Val&: in)) {
43 return rewriter.create<arith::ConstantIndexOp>(
44 location: loc, args: cast<IntegerAttr>(Val&: offsetAttr).getInt());
45 }
46 return cast<Value>(Val&: in);
47}
48
49/// Returns a collapsed memref and the linearized index to access the element
50/// at the specified indices.
51static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
52 Location loc,
53 Value source,
54 ValueRange indices) {
55 int64_t sourceOffset;
56 SmallVector<int64_t, 4> sourceStrides;
57 auto sourceType = cast<MemRefType>(Val: source.getType());
58 if (failed(Result: sourceType.getStridesAndOffset(strides&: sourceStrides, offset&: sourceOffset))) {
59 assert(false);
60 }
61
62 memref::ExtractStridedMetadataOp stridedMetadata =
63 rewriter.create<memref::ExtractStridedMetadataOp>(location: loc, args&: source);
64
65 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
66 OpFoldResult linearizedIndices;
67 memref::LinearizedMemRefInfo linearizedInfo;
68 std::tie(args&: linearizedInfo, args&: linearizedIndices) =
69 memref::getLinearizedMemRefOffsetAndSize(
70 builder&: rewriter, loc, srcBits: typeBit, dstBits: typeBit,
71 offset: stridedMetadata.getConstifiedMixedOffset(),
72 sizes: stridedMetadata.getConstifiedMixedSizes(),
73 strides: stridedMetadata.getConstifiedMixedStrides(),
74 indices: getAsOpFoldResult(values: indices));
75
76 return std::make_pair(
77 x: rewriter.create<memref::ReinterpretCastOp>(
78 location: loc, args&: source,
79 /* offset = */ args&: linearizedInfo.linearizedOffset,
80 /* shapes = */
81 args: ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
82 /* strides = */
83 args: ArrayRef<OpFoldResult>{rewriter.getIndexAttr(value: 1)}),
84 y: getValueFromOpFoldResult(rewriter, loc, in: linearizedIndices));
85}
86
87static bool needFlattening(Value val) {
88 auto type = cast<MemRefType>(Val: val.getType());
89 return type.getRank() > 1;
90}
91
92static bool checkLayout(Value val) {
93 auto type = cast<MemRefType>(Val: val.getType());
94 return type.getLayout().isIdentity() ||
95 isa<StridedLayoutAttr>(Val: type.getLayout());
96}
97
98namespace {
99static Value getTargetMemref(Operation *op) {
100 return llvm::TypeSwitch<Operation *, Value>(op)
101 .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
102 memref::AllocOp>(caseFn: [](auto op) { return op.getMemref(); })
103 .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
104 vector::MaskedStoreOp, vector::TransferReadOp,
105 vector::TransferWriteOp>(
106 caseFn: [](auto op) { return op.getBase(); })
107 .Default(defaultFn: [](auto) { return Value{}; });
108}
109
110template <typename T>
111static void castAllocResult(T oper, T newOper, Location loc,
112 PatternRewriter &rewriter) {
113 memref::ExtractStridedMetadataOp stridedMetadata =
114 rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
115 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
116 oper, cast<MemRefType>(oper.getType()), newOper,
117 /*offset=*/rewriter.getIndexAttr(value: 0),
118 stridedMetadata.getConstifiedMixedSizes(),
119 stridedMetadata.getConstifiedMixedStrides());
120}
121
122template <typename T>
123static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
124 Value offset) {
125 Location loc = op->getLoc();
126 llvm::TypeSwitch<Operation *>(op.getOperation())
127 .template Case<memref::AllocOp>([&](auto oper) {
128 auto newAlloc = rewriter.create<memref::AllocOp>(
129 loc, cast<MemRefType>(Val: flatMemref.getType()),
130 oper.getAlignmentAttr());
131 castAllocResult(oper, newAlloc, loc, rewriter);
132 })
133 .template Case<memref::AllocaOp>([&](auto oper) {
134 auto newAlloca = rewriter.create<memref::AllocaOp>(
135 loc, cast<MemRefType>(Val: flatMemref.getType()),
136 oper.getAlignmentAttr());
137 castAllocResult(oper, newAlloca, loc, rewriter);
138 })
139 .template Case<memref::LoadOp>([&](auto op) {
140 auto newLoad = rewriter.create<memref::LoadOp>(
141 loc, op->getResultTypes(), flatMemref, ValueRange{offset});
142 newLoad->setAttrs(op->getAttrs());
143 rewriter.replaceOp(op, newLoad.getResult());
144 })
145 .template Case<memref::StoreOp>([&](auto op) {
146 auto newStore = rewriter.create<memref::StoreOp>(
147 loc, op->getOperands().front(), flatMemref, ValueRange{offset});
148 newStore->setAttrs(op->getAttrs());
149 rewriter.replaceOp(op, newStore);
150 })
151 .template Case<vector::LoadOp>([&](auto op) {
152 auto newLoad = rewriter.create<vector::LoadOp>(
153 loc, op->getResultTypes(), flatMemref, ValueRange{offset});
154 newLoad->setAttrs(op->getAttrs());
155 rewriter.replaceOp(op, newLoad.getResult());
156 })
157 .template Case<vector::StoreOp>([&](auto op) {
158 auto newStore = rewriter.create<vector::StoreOp>(
159 loc, op->getOperands().front(), flatMemref, ValueRange{offset});
160 newStore->setAttrs(op->getAttrs());
161 rewriter.replaceOp(op, newStore);
162 })
163 .template Case<vector::MaskedLoadOp>([&](auto op) {
164 auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
165 loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
166 op.getPassThru());
167 newMaskedLoad->setAttrs(op->getAttrs());
168 rewriter.replaceOp(op, newMaskedLoad.getResult());
169 })
170 .template Case<vector::MaskedStoreOp>([&](auto op) {
171 auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
172 loc, flatMemref, ValueRange{offset}, op.getMask(),
173 op.getValueToStore());
174 newMaskedStore->setAttrs(op->getAttrs());
175 rewriter.replaceOp(op, newMaskedStore);
176 })
177 .template Case<vector::TransferReadOp>([&](auto op) {
178 auto newTransferRead = rewriter.create<vector::TransferReadOp>(
179 loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
180 rewriter.replaceOp(op, newTransferRead.getResult());
181 })
182 .template Case<vector::TransferWriteOp>([&](auto op) {
183 auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
184 loc, op.getVector(), flatMemref, ValueRange{offset});
185 rewriter.replaceOp(op, newTransferWrite);
186 })
187 .Default([&](auto op) {
188 op->emitOpError("unimplemented: do not know how to replace op.");
189 });
190}
191
192template <typename T>
193static ValueRange getIndices(T op) {
194 if constexpr (std::is_same_v<T, memref::AllocaOp> ||
195 std::is_same_v<T, memref::AllocOp>) {
196 return ValueRange{};
197 } else {
198 return op.getIndices();
199 }
200}
201
202template <typename T>
203static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
204 return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
205 .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
206 [&](auto oper) {
207 // For vector.transfer_read/write, must make sure:
208 // 1. all accesses are inbound, and
209 // 2. has an identity or minor identity permutation map.
210 auto permutationMap = oper.getPermutationMap();
211 if (!permutationMap.isIdentity() &&
212 !permutationMap.isMinorIdentity()) {
213 return rewriter.notifyMatchFailure(
214 oper, "only identity permutation map is supported");
215 }
216 mlir::ArrayAttr inbounds = oper.getInBounds();
217 if (llvm::any_of(inbounds, [](Attribute attr) {
218 return !cast<BoolAttr>(Val&: attr).getValue();
219 })) {
220 return rewriter.notifyMatchFailure(oper,
221 "only inbounds are supported");
222 }
223 return success();
224 })
225 .Default([&](auto op) { return success(); });
226}
227
228template <typename T>
229struct MemRefRewritePattern : public OpRewritePattern<T> {
230 using OpRewritePattern<T>::OpRewritePattern;
231 LogicalResult matchAndRewrite(T op,
232 PatternRewriter &rewriter) const override {
233 LogicalResult canFlatten = canBeFlattened(op, rewriter);
234 if (failed(Result: canFlatten)) {
235 return canFlatten;
236 }
237
238 Value memref = getTargetMemref(op);
239 if (!needFlattening(val: memref) || !checkLayout(val: memref))
240 return failure();
241 auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
242 rewriter, op->getLoc(), memref, getIndices<T>(op));
243 replaceOp<T>(op, rewriter, flatMemref, offset);
244 return success();
245 }
246};
247
248struct FlattenMemrefsPass
249 : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
250 using Base::Base;
251
252 void getDependentDialects(DialectRegistry &registry) const override {
253 registry.insert<affine::AffineDialect, arith::ArithDialect,
254 memref::MemRefDialect, vector::VectorDialect>();
255 }
256
257 void runOnOperation() override {
258 RewritePatternSet patterns(&getContext());
259
260 memref::populateFlattenMemrefsPatterns(patterns);
261
262 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns))))
263 return signalPassFailure();
264 }
265};
266
267} // namespace
268
269void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
270 patterns.insert<MemRefRewritePattern<memref::LoadOp>,
271 MemRefRewritePattern<memref::StoreOp>,
272 MemRefRewritePattern<memref::AllocOp>,
273 MemRefRewritePattern<memref::AllocaOp>,
274 MemRefRewritePattern<vector::LoadOp>,
275 MemRefRewritePattern<vector::StoreOp>,
276 MemRefRewritePattern<vector::TransferReadOp>,
277 MemRefRewritePattern<vector::TransferWriteOp>,
278 MemRefRewritePattern<vector::MaskedLoadOp>,
279 MemRefRewritePattern<vector::MaskedStoreOp>>(
280 arg: patterns.getContext());
281}
282

source code of mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp