//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU decompose-memrefs pass: it rewrites memref
// load/store/subview ops inside gpu.launch bodies in terms of a flat base
// pointer, offset, and strides obtained with memref.extract_strided_metadata,
// which is emitted next to the defining op of the memref and thus typically
// ends up outside the launch body.
//
//===----------------------------------------------------------------------===//

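// An illustrative sketch of the rewrite (hypothetical IR; %mem, %i, %j and a
// 2-D identity-layout memref are assumed for exposition, and the exact maps
// are produced and folded by affine::makeComposedFoldedAffineApply, so real
// output may differ):
//
//   gpu.launch ... {
//     %v = memref.load %mem[%i, %j] : memref<?x?xf32>
//   }
//
// becomes, with the metadata query placed next to the definition of %mem:
//
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %mem
//   gpu.launch ... {
//     %idx = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
//                ()[%i, %strides#0, %j]
//     %flat = memref.reinterpret_cast %base to offset: [%idx], sizes: [],
//                 strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//     %v = memref.load %flat[] : memref<f32, strided<[], offset: ?>>
//   }
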
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

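/// Infers the 0-d strided result type of the flattening reinterpret_cast:
/// same element type and memory space as `source`, empty shape, and `offset`
/// baked into the strided layout (dynamic when `offset` is a Value).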
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

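/// Positions `builder` right after the defining op of `val`, or at the start
/// of its owning block when `val` is a block argument.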
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

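/// Returns true if `op` is nested inside a gpu.launch body.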
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

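/// Computes the flat base buffer, the linearized offset, and (when
/// `subStrides` is non-empty) the combined strides for a view of `source`.
/// The memref.extract_strided_metadata query feeding the computation is
/// emitted next to the definition of `source`, so it can end up outside the
/// gpu.launch body.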
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = {}) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

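/// Builds a flattened 0-d view of `source` addressing the element at
/// `offsets` via a memref.reinterpret_cast of the extracted base buffer.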
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    ArrayRef<OpFoldResult>(),
                                                    ArrayRef<OpFoldResult>());
}

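/// 0-d memrefs are already flat; everything else needs flattening.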
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

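/// Only identity and explicitly strided layouts are supported.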
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
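/// Rewrites memref.load inside gpu.launch to load from a flattened 0-d view
/// of the source memref.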
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

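/// Rewrites memref.store inside gpu.launch to store into a flattened 0-d
/// view of the source memref.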
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

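/// Rewrites memref.subview inside gpu.launch into a memref.reinterpret_cast
/// with explicitly computed offset, sizes, and strides, dropping the
/// subview's rank-reduced dimensions from the result.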
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}
