//===- DecomposeMemrefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the decompose-memrefs pass: inside gpu.launch bodies
// it rewrites memref.load, memref.store, and memref.subview on non-0-d
// memrefs into flat accesses whose offset/stride arithmetic is expressed
// explicitly with affine ops and memref.reinterpret_cast.
//
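// For example (illustrative IR, not the exact printed form), a load such as
//
//   %v = memref.load %mem[%i, %j] : memref<?x?xf32>
//
// inside a gpu.launch becomes roughly
//
//   %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %mem
//   %flatOff = affine.apply <%off + %i * %strides#0 + %j * %strides#1>
//   %flat = memref.reinterpret_cast %base to offset: [%flatOff], ...
//   %v = memref.load %flat[] : memref<f32>
//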
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

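/// Position `builder` right after the op defining `val`, or at the start of
/// the parent block when `val` is a block argument, so that newly created ops
/// that use `val` come immediately after its definition.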
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

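/// Return true if `op` is nested inside a gpu.launch region.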
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

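/// Compute the flattened view of `source`: materialize its strided metadata
/// and fold `subOffsets` (and, if provided, `subStrides`) into it. Returns
/// the base buffer, the linearized offset, and the combined strides.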
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

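  // Materialize the runtime metadata (base buffer, offset, sizes, strides)
  // next to the definition of `source` so it dominates every use of it.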
  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

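  // Prefer the statically known offset/strides encoded in the memref type;
  // fall back to the SSA values produced by extract_strided_metadata only
  // for dynamic entries.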
  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
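  // The effective stride along each dimension is subStride * origStride;
  // makeComposedFoldedAffineApply constant-folds when both are static.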
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

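  // Linearize: finalOffset = origOffset + sum_i(subOffsets[i] * origStrides[i]).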
  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

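/// Return a rank-0 view of `source` positioned at the offset obtained by
/// linearizing `offsets`, built with memref.reinterpret_cast.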
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  auto retType = cast<MemRefType>(base.getType());
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    std::nullopt, std::nullopt);
}

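/// Rank-0 memrefs are already flat; everything else needs flattening.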
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

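/// Only identity and explicitly strided layouts are supported: for these the
/// offset/stride computation above is well-defined.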
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
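/// Rewrite a memref.load inside gpu.launch into a load from a flattened,
/// rank-0 view of its source, making the index arithmetic explicit.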
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

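/// Same as FlattenLoad, but for memref.store.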
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

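/// Rewrite memref.subview inside gpu.launch into memref.reinterpret_cast with
/// explicitly computed offset, sizes, and strides.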
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

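    // Keep sizes/strides only for dimensions that survive in the result
    // type; rank-reducing subviews drop the rest.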
    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

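/// Pass driver: greedily applies the flattening patterns above. The
/// corresponding mlir-opt flag comes from the tablegen pass definition in
/// Passes.td (gpu-decompose-memrefs).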
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
  return std::make_unique<GpuDecomposeMemrefsPass>();
}