1//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains patterns for flattening multi-rank memref-related ops
// into 1-D memref ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/MemRef/Transforms/Passes.h"
18#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
19#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
20#include "mlir/Dialect/Utils/IndexingUtils.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/IR/AffineExpr.h"
24#include "mlir/IR/Attributes.h"
25#include "mlir/IR/Builders.h"
26#include "mlir/IR/BuiltinTypes.h"
27#include "mlir/IR/OpDefinition.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34#include <numeric>
35
36namespace mlir {
37namespace memref {
38#define GEN_PASS_DEF_FLATTENMEMREFSPASS
39#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
40} // namespace memref
41} // namespace mlir
42
43using namespace mlir;
44
45static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
46 OpFoldResult in) {
47 if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
48 return rewriter.create<arith::ConstantIndexOp>(
49 loc, cast<IntegerAttr>(offsetAttr).getInt());
50 }
51 return cast<Value>(in);
52}
53
54/// Returns a collapsed memref and the linearized index to access the element
55/// at the specified indices.
56static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
57 Location loc,
58 Value source,
59 ValueRange indices) {
60 int64_t sourceOffset;
61 SmallVector<int64_t, 4> sourceStrides;
62 auto sourceType = cast<MemRefType>(source.getType());
63 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
64 assert(false);
65 }
66
67 memref::ExtractStridedMetadataOp stridedMetadata =
68 rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
69
70 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
71 OpFoldResult linearizedIndices;
72 memref::LinearizedMemRefInfo linearizedInfo;
73 std::tie(args&: linearizedInfo, args&: linearizedIndices) =
74 memref::getLinearizedMemRefOffsetAndSize(
75 rewriter, loc, typeBit, typeBit,
76 stridedMetadata.getConstifiedMixedOffset(),
77 stridedMetadata.getConstifiedMixedSizes(),
78 stridedMetadata.getConstifiedMixedStrides(),
79 getAsOpFoldResult(values: indices));
80
81 return std::make_pair(
82 rewriter.create<memref::ReinterpretCastOp>(
83 loc, source,
84 /* offset = */ linearizedInfo.linearizedOffset,
85 /* shapes = */
86 ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
87 /* strides = */
88 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
89 getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
90}
91
92static bool needFlattening(Value val) {
93 auto type = cast<MemRefType>(val.getType());
94 return type.getRank() > 1;
95}
96
97static bool checkLayout(Value val) {
98 auto type = cast<MemRefType>(val.getType());
99 return type.getLayout().isIdentity() ||
100 isa<StridedLayoutAttr>(type.getLayout());
101}
102
103namespace {
104static Value getTargetMemref(Operation *op) {
105 return llvm::TypeSwitch<Operation *, Value>(op)
106 .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
107 memref::AllocOp>([](auto op) { return op.getMemref(); })
108 .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
109 vector::MaskedStoreOp, vector::TransferReadOp,
110 vector::TransferWriteOp>(
111 [](auto op) { return op.getBase(); })
112 .Default([](auto) { return Value{}; });
113}
114
115template <typename T>
116static void castAllocResult(T oper, T newOper, Location loc,
117 PatternRewriter &rewriter) {
118 memref::ExtractStridedMetadataOp stridedMetadata =
119 rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
120 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
121 oper, cast<MemRefType>(oper.getType()), newOper,
122 /*offset=*/rewriter.getIndexAttr(0),
123 stridedMetadata.getConstifiedMixedSizes(),
124 stridedMetadata.getConstifiedMixedStrides());
125}
126
127template <typename T>
128static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
129 Value offset) {
130 Location loc = op->getLoc();
131 llvm::TypeSwitch<Operation *>(op.getOperation())
132 .template Case<memref::AllocOp>([&](auto oper) {
133 auto newAlloc = rewriter.create<memref::AllocOp>(
134 loc, cast<MemRefType>(flatMemref.getType()),
135 oper.getAlignmentAttr());
136 castAllocResult(oper, newAlloc, loc, rewriter);
137 })
138 .template Case<memref::AllocaOp>([&](auto oper) {
139 auto newAlloca = rewriter.create<memref::AllocaOp>(
140 loc, cast<MemRefType>(flatMemref.getType()),
141 oper.getAlignmentAttr());
142 castAllocResult(oper, newAlloca, loc, rewriter);
143 })
144 .template Case<memref::LoadOp>([&](auto op) {
145 auto newLoad = rewriter.create<memref::LoadOp>(
146 loc, op->getResultTypes(), flatMemref, ValueRange{offset});
147 newLoad->setAttrs(op->getAttrs());
148 rewriter.replaceOp(op, newLoad.getResult());
149 })
150 .template Case<memref::StoreOp>([&](auto op) {
151 auto newStore = rewriter.create<memref::StoreOp>(
152 loc, op->getOperands().front(), flatMemref, ValueRange{offset});
153 newStore->setAttrs(op->getAttrs());
154 rewriter.replaceOp(op, newStore);
155 })
156 .template Case<vector::LoadOp>([&](auto op) {
157 auto newLoad = rewriter.create<vector::LoadOp>(
158 loc, op->getResultTypes(), flatMemref, ValueRange{offset});
159 newLoad->setAttrs(op->getAttrs());
160 rewriter.replaceOp(op, newLoad.getResult());
161 })
162 .template Case<vector::StoreOp>([&](auto op) {
163 auto newStore = rewriter.create<vector::StoreOp>(
164 loc, op->getOperands().front(), flatMemref, ValueRange{offset});
165 newStore->setAttrs(op->getAttrs());
166 rewriter.replaceOp(op, newStore);
167 })
168 .template Case<vector::MaskedLoadOp>([&](auto op) {
169 auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
170 loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
171 op.getPassThru());
172 newMaskedLoad->setAttrs(op->getAttrs());
173 rewriter.replaceOp(op, newMaskedLoad.getResult());
174 })
175 .template Case<vector::MaskedStoreOp>([&](auto op) {
176 auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
177 loc, flatMemref, ValueRange{offset}, op.getMask(),
178 op.getValueToStore());
179 newMaskedStore->setAttrs(op->getAttrs());
180 rewriter.replaceOp(op, newMaskedStore);
181 })
182 .template Case<vector::TransferReadOp>([&](auto op) {
183 auto newTransferRead = rewriter.create<vector::TransferReadOp>(
184 loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
185 rewriter.replaceOp(op, newTransferRead.getResult());
186 })
187 .template Case<vector::TransferWriteOp>([&](auto op) {
188 auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
189 loc, op.getVector(), flatMemref, ValueRange{offset});
190 rewriter.replaceOp(op, newTransferWrite);
191 })
192 .Default([&](auto op) {
193 op->emitOpError("unimplemented: do not know how to replace op.");
194 });
195}
196
197template <typename T>
198static ValueRange getIndices(T op) {
199 if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200 std::is_same_v<T, memref::AllocOp>) {
201 return ValueRange{};
202 } else {
203 return op.getIndices();
204 }
205}
206
207template <typename T>
208static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
209 return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
210 .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
211 [&](auto oper) {
212 // For vector.transfer_read/write, must make sure:
213 // 1. all accesses are inbound, and
214 // 2. has an identity or minor identity permutation map.
215 auto permutationMap = oper.getPermutationMap();
216 if (!permutationMap.isIdentity() &&
217 !permutationMap.isMinorIdentity()) {
218 return rewriter.notifyMatchFailure(
219 oper, "only identity permutation map is supported");
220 }
221 mlir::ArrayAttr inbounds = oper.getInBounds();
222 if (llvm::any_of(inbounds, [](Attribute attr) {
223 return !cast<BoolAttr>(Val&: attr).getValue();
224 })) {
225 return rewriter.notifyMatchFailure(oper,
226 "only inbounds are supported");
227 }
228 return success();
229 })
230 .Default([&](auto op) { return success(); });
231}
232
233template <typename T>
234struct MemRefRewritePattern : public OpRewritePattern<T> {
235 using OpRewritePattern<T>::OpRewritePattern;
236 LogicalResult matchAndRewrite(T op,
237 PatternRewriter &rewriter) const override {
238 LogicalResult canFlatten = canBeFlattened(op, rewriter);
239 if (failed(Result: canFlatten)) {
240 return canFlatten;
241 }
242
243 Value memref = getTargetMemref(op);
244 if (!needFlattening(val: memref) || !checkLayout(val: memref))
245 return failure();
246 auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
247 rewriter, op->getLoc(), memref, getIndices<T>(op));
248 replaceOp<T>(op, rewriter, flatMemref, offset);
249 return success();
250 }
251};
252
253struct FlattenMemrefsPass
254 : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
255 using Base::Base;
256
257 void getDependentDialects(DialectRegistry &registry) const override {
258 registry.insert<affine::AffineDialect, arith::ArithDialect,
259 memref::MemRefDialect, vector::VectorDialect>();
260 }
261
262 void runOnOperation() override {
263 RewritePatternSet patterns(&getContext());
264
265 memref::populateFlattenMemrefsPatterns(patterns);
266
267 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
268 return signalPassFailure();
269 }
270};
271
272} // namespace
273
274void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
275 patterns.insert<MemRefRewritePattern<memref::LoadOp>,
276 MemRefRewritePattern<memref::StoreOp>,
277 MemRefRewritePattern<memref::AllocOp>,
278 MemRefRewritePattern<memref::AllocaOp>,
279 MemRefRewritePattern<vector::LoadOp>,
280 MemRefRewritePattern<vector::StoreOp>,
281 MemRefRewritePattern<vector::TransferReadOp>,
282 MemRefRewritePattern<vector::TransferWriteOp>,
283 MemRefRewritePattern<vector::MaskedLoadOp>,
284 MemRefRewritePattern<vector::MaskedStoreOp>>(
285 patterns.getContext());
286}
287

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp