//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the decompose memrefs pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

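// Builds the 0-d result type for the flattening reinterpret_cast: same
// element type and memory space as `source`, with a strided layout that
// carries only the computed offset.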
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

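// Positions `builder` right after the defining op of `val`, or at the start
// of the block when `val` is a block argument.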
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

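// Returns true if `op` is nested inside a gpu.launch region.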
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

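// Computes the base buffer, the linearized offset, and the combined strides
// for an access into `source` described by `subOffsets` (and, for subviews,
// `subStrides`). The metadata comes from a memref.extract_strided_metadata op
// inserted as early as possible so the rewritten IR can reuse it.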
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
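  // When sub-strides are provided (the subview case), combine them with the
  // source strides: the effective stride of a dimension is
  // subStride * origStride.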
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

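  // Fold the source offset and the per-dimension indices into a single
  // linearized affine expression, then materialize it (it may fold to a
  // constant).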
  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

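// Rewrites `source` into a 0-d memref whose strided layout encodes the
// linearized offset computed from `offsets`.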
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    std::nullopt, std::nullopt);
}

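// Rank-0 memrefs are already flat and need no rewriting.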
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

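// Only memrefs with an identity or strided layout can be linearized here.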
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
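// Rewrites a ranked memref.load inside gpu.launch into a load from a
// flattened 0-d memref. Roughly:
//
//   %0 = memref.load %arg[%i, %j] : memref<?x?xf32>
//
// becomes
//
//   %base, %off, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %arg : memref<?x?xf32>
//       -> memref<f32>, index, index, index, index, index
//   %flat_off = affine.apply <linearization map>()[%off, %i, %strides#0, ...]
//   %flat = memref.reinterpret_cast %base to offset: [%flat_off],
//       sizes: [], strides: []
//       : memref<f32> to memref<f32, strided<[], offset: ?>>
//   %0 = memref.load %flat[] : memref<f32, strided<[], offset: ?>>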
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

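// Same as FlattenLoad, but for memref.store: the value is written through the
// flattened 0-d memref.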
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

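// Rewrites memref.subview inside gpu.launch into a memref.reinterpret_cast
// from the base buffer, with the offset linearized and the strides combined;
// dimensions dropped by rank reduction are skipped.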
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

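    // Keep only the sizes and strides of dimensions that survive rank
    // reduction.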
    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

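// Pass wrapper that greedily applies the flattening patterns to the whole
// operation.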
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}