1//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
10
11#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/Vector/IR/VectorOps.h"
18#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/OpDefinition.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/IR/TypeUtilities.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Support/LogicalResult.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/Support/MathExtras.h"
27
namespace mlir::amdgpu {
// Pulls in the generated pass base class
// (AmdgpuMaskedloadToLoadPassBase) from the tablegen'd Passes.h.inc.
#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu
32
33using namespace mlir;
34using namespace mlir::amdgpu;
35
36/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
37/// and `arith.select` if the memref is in buffer address space.
38static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
39 vector::MaskedLoadOp maskedOp) {
40 auto memRefType = dyn_cast<MemRefType>(Val: maskedOp.getBase().getType());
41 if (!memRefType)
42 return rewriter.notifyMatchFailure(arg&: maskedOp, msg: "not a memref source");
43
44 Attribute addrSpace = memRefType.getMemorySpace();
45 if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(Val: addrSpace))
46 return rewriter.notifyMatchFailure(arg&: maskedOp, msg: "no address space");
47
48 if (dyn_cast<amdgpu::AddressSpaceAttr>(Val&: addrSpace).getValue() !=
49 amdgpu::AddressSpace::FatRawBuffer)
50 return rewriter.notifyMatchFailure(arg&: maskedOp, msg: "not in buffer address space");
51
52 return success();
53}
54
55static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
56 vector::MaskedLoadOp maskedOp,
57 bool passthru) {
58 VectorType vectorType = maskedOp.getVectorType();
59 Value load = builder.create<vector::LoadOp>(
60 location: loc, args&: vectorType, args: maskedOp.getBase(), args: maskedOp.getIndices());
61 if (passthru)
62 load = builder.create<arith::SelectOp>(location: loc, args&: vectorType, args: maskedOp.getMask(),
63 args&: load, args: maskedOp.getPassThru());
64 return load;
65}
66
67/// Check if the given value comes from a broadcasted i1 condition.
68static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
69 auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
70 if (!broadcastOp)
71 return failure();
72 if (isa<VectorType>(Val: broadcastOp.getSourceType()))
73 return failure();
74 return broadcastOp.getSource();
75}
76
// Marker attribute placed on masked loads cloned into the fallback branch so
// MaskedLoadLowering does not match them again (prevents infinite recursion
// in the greedy pattern driver).
static constexpr char kMaskedloadNeedsMask[] =
    "amdgpu.buffer_maskedload_needs_mask";
79
80namespace {
81
/// Lowers `vector.maskedload` on fat-raw-buffer memrefs to an unmasked
/// `vector.load`, guarded by an `scf.if` that falls back to the original
/// masked load only when the access both runs past the end of the buffer and
/// is not word aligned (the one case where hardware buffer-OOB semantics can
/// return inconsistent results).
struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                PatternRewriter &rewriter) const override {
    // Skip ops we ourselves cloned into the fallback branch (see thenBuilder
    // below); matching them again would recurse forever.
    if (maskedOp->hasAttr(kMaskedloadNeedsMask))
      return failure();

    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
      return failure();
    }

    // Check if this is either a full inbounds load or an empty, oob load. If
    // so, take the fast path and don't generate an if condition, because we
    // know doing the oob load is always safe.
    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                 maskedOp, /*passthru=*/true);
      rewriter.replaceOp(maskedOp, load);
      return success();
    }

    Location loc = maskedOp.getLoc();
    Value src = maskedOp.getBase();

    VectorType vectorType = maskedOp.getVectorType();
    int64_t vectorSize = vectorType.getNumElements();
    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
    SmallVector<OpFoldResult> indices = maskedOp.getIndices();

    // Linearize the (possibly multi-dimensional) access so the distance from
    // the accessed element to the end of the buffer is a single index value.
    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
    SmallVector<OpFoldResult> strides =
        stridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
    memref::LinearizedMemRefInfo linearizedInfo;
    OpFoldResult linearizedIndices;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
                                                 elementBitWidth, offset, sizes,
                                                 strides, indices);

    // delta = bufferSize - linearizedOffset
    Value vectorSizeOffset =
        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
    Value linearIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    Value totalSize = getValueOrCreateConstantIndexOp(
        rewriter, loc, linearizedInfo.linearizedSize);
    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);

    // 1) check if delta < vectorSize
    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

    // 2) check if (delta % elements_per_word != 0)
    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::divideCeil(32, elementBitWidth));
    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne,
        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));

    // We take the fallback of maskedload default lowering only if it is both
    // out-of-bounds and not word aligned. The fallback ensures correct results
    // when loading at the boundary of the buffer since buffer load returns
    // inconsistent zeros for the whole word when boundary is crossed.
    Value ifCondition =
        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);

    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
      // Fallback: keep a clone of the original masked load, tagged so this
      // pattern never matches it again.
      Operation *read = builder.clone(*maskedOp.getOperation());
      read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
      Value readResult = read->getResult(0);
      builder.create<scf::YieldOp>(loc, readResult);
    };

    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
      // Fast path: plain vector.load, blended with the pass-through values.
      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
                                                /*passthru=*/true);
      rewriter.create<scf::YieldOp>(loc, res);
    };

    auto ifOp =
        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);

    rewriter.replaceOp(maskedOp, ifOp);

    return success();
  }
};
174
175struct FullMaskedLoadToConditionalLoad
176 : OpRewritePattern<vector::MaskedLoadOp> {
177 using OpRewritePattern::OpRewritePattern;
178
179 LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
180 PatternRewriter &rewriter) const override {
181 FailureOr<Value> maybeCond = matchFullMask(b&: rewriter, val: loadOp.getMask());
182 if (failed(Result: maybeCond)) {
183 return failure();
184 }
185
186 Value cond = maybeCond.value();
187 auto trueBuilder = [&](OpBuilder &builder, Location loc) {
188 Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp: loadOp,
189 /*passthru=*/false);
190 rewriter.create<scf::YieldOp>(location: loc, args&: res);
191 };
192 auto falseBuilder = [&](OpBuilder &builder, Location loc) {
193 rewriter.create<scf::YieldOp>(location: loc, args: loadOp.getPassThru());
194 };
195 auto ifOp = rewriter.create<scf::IfOp>(location: loadOp.getLoc(), args&: cond, args&: trueBuilder,
196 args&: falseBuilder);
197 rewriter.replaceOp(op: loadOp, newOp: ifOp);
198 return success();
199 }
200};
201
202struct FullMaskedStoreToConditionalStore
203 : OpRewritePattern<vector::MaskedStoreOp> {
204 using OpRewritePattern::OpRewritePattern;
205
206 LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
207 PatternRewriter &rewriter) const override {
208 FailureOr<Value> maybeCond = matchFullMask(b&: rewriter, val: storeOp.getMask());
209 if (failed(Result: maybeCond)) {
210 return failure();
211 }
212 Value cond = maybeCond.value();
213
214 auto trueBuilder = [&](OpBuilder &builder, Location loc) {
215 rewriter.create<vector::StoreOp>(location: loc, args: storeOp.getValueToStore(),
216 args: storeOp.getBase(), args: storeOp.getIndices());
217 rewriter.create<scf::YieldOp>(location: loc);
218 };
219 auto ifOp = rewriter.create<scf::IfOp>(location: storeOp.getLoc(), args&: cond, args&: trueBuilder);
220 rewriter.replaceOp(op: storeOp, newOp: ifOp);
221 return success();
222 }
223};
224
225} // namespace
226
/// Registers the three maskedload/maskedstore lowering patterns defined above
/// into `patterns`, all with the same `benefit`.
void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
               FullMaskedStoreToConditionalStore>(patterns.getContext(),
                                                  benefit);
}
233
234struct AmdgpuMaskedloadToLoadPass final
235 : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
236 void runOnOperation() override {
237 RewritePatternSet patterns(&getContext());
238 populateAmdgpuMaskedloadToLoadPatterns(patterns);
239 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) {
240 return signalPassFailure();
241 }
242 }
243};
244

// Source: mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp