1 | //===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/AMDGPU/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
15 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
17 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
18 | #include "mlir/IR/BuiltinTypes.h" |
19 | #include "mlir/IR/OpDefinition.h" |
20 | #include "mlir/IR/PatternMatch.h" |
21 | #include "mlir/IR/TypeUtilities.h" |
22 | #include "mlir/Pass/Pass.h" |
23 | #include "mlir/Support/LogicalResult.h" |
24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
25 | #include "llvm/Support/MathExtras.h" |
26 | |
27 | namespace mlir::amdgpu { |
28 | #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS |
29 | #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc" |
30 | } // namespace mlir::amdgpu |
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::amdgpu; |
34 | |
35 | /// This pattern supports lowering of: |
36 | /// `vector.transfer_read` to a combination of `vector.load`, `arith.select` and |
37 | /// `vector.broadcast` if all of the following hold: |
38 | /// - The transfer op is masked. |
39 | /// - The memref is in buffer address space. |
/// - The stride of the most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type, then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
/// Note: these conditions mostly come from the TransferReadToVectorLoadLowering
/// pass.
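///
/// A rough sketch of the rewrite this enables (illustrative types and values,
/// not taken from a test):
///
///   %r = vector.transfer_read %mem[%i], %pad, %mask {in_bounds = [true]}
///       : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
///
/// becomes, on the common (in-buffer or word-aligned) path:
///
///   %fill = vector.splat %pad : vector<4xf32>
///   %load = vector.load %mem[%i]
///       : memref<64xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
///   %r = arith.select %mask, %load, %fill : vector<4xi1>, vector<4xf32>
///
/// guarded by an scf.if whose other branch keeps the original transfer_read as
/// a fallback.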
47 | static LogicalResult transferPreconditions( |
48 | PatternRewriter &rewriter, VectorTransferOpInterface xferOp, |
49 | bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) { |
50 | if (!xferOp.getMask()) |
51 | return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer" ); |
52 | |
53 | // Permutations are handled by VectorToSCF or |
54 | // populateVectorTransferPermutationMapLoweringPatterns. |
  // We let the 0-d corner case pass through as it is supported.
56 | SmallVector<unsigned> broadcastedDims; |
57 | if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting( |
58 | &broadcastedDims)) |
59 | return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast" ); |
60 | |
61 | auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType()); |
62 | if (!memRefType) |
63 | return rewriter.notifyMatchFailure(xferOp, "not a memref source" ); |
64 | |
65 | Attribute addrSpace = memRefType.getMemorySpace(); |
66 | if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace)) |
67 | return rewriter.notifyMatchFailure(xferOp, "no address space" ); |
68 | |
69 | if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() != |
70 | amdgpu::AddressSpace::FatRawBuffer) |
71 | return rewriter.notifyMatchFailure(xferOp, "not in buffer address space" ); |
72 | |
73 | // Non-unit strides are handled by VectorToSCF. |
74 | if (!memRefType.isLastDimUnitStride()) |
75 | return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF" ); |
76 | |
77 | if (memRefType.getElementTypeBitWidth() < 8) |
78 | return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type" ); |
79 | |
80 | // If there is broadcasting involved then we first load the unbroadcasted |
81 | // vector, and then broadcast it with `vector.broadcast`. |
82 | ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape(); |
83 | SmallVector<int64_t> unbroadcastedVectorShape(vectorShape); |
84 | for (unsigned i : broadcastedDims) |
85 | unbroadcastedVectorShape[i] = 1; |
86 | unbroadcastedVectorType = xferOp.getVectorType().cloneWith( |
87 | unbroadcastedVectorShape, xferOp.getVectorType().getElementType()); |
88 | requiresBroadcasting = !broadcastedDims.empty(); |
89 | |
90 | // `vector.load` supports vector types as memref's elements only when the |
91 | // resulting vector type is the same as the element type. |
92 | auto memrefElTy = memRefType.getElementType(); |
93 | if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType) |
94 | return rewriter.notifyMatchFailure(xferOp, "incompatible element type" ); |
95 | |
96 | // Otherwise, element types of the memref and the vector must match. |
97 | if (!isa<VectorType>(memrefElTy) && |
98 | memrefElTy != xferOp.getVectorType().getElementType()) |
99 | return rewriter.notifyMatchFailure(xferOp, "non-matching element type" ); |
100 | |
101 | // Out-of-bounds dims are handled by MaterializeTransferMask. |
102 | if (xferOp.hasOutOfBoundsDim()) |
103 | return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask" ); |
104 | |
105 | if (xferOp.getVectorType().getRank() != 1) |
106 | // vector.maskedload operates on 1-D vectors. |
107 | return rewriter.notifyMatchFailure( |
108 | xferOp, "vector type is not rank 1, can't create masked load, needs " |
109 | "VectorToSCF" ); |
110 | |
111 | return success(); |
112 | } |
113 | |
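/// Build the unguarded lowering used on the common path: a plain `vector.load`
/// of the unbroadcasted vector, an `arith.select` against a splat of the
/// padding value for the masked-off lanes, and a `vector.broadcast` back to
/// the original vector type when the permutation map broadcasts.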
114 | static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, |
115 | vector::TransferReadOp readOp, |
116 | bool requiresBroadcasting, |
117 | VectorType unbroadcastedVectorType) { |
118 | Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType, |
119 | readOp.getPadding()); |
120 | Value load = builder.create<vector::LoadOp>( |
121 | loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices()); |
122 | Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType, |
123 | readOp.getMask(), load, fill); |
124 | // Insert a broadcasting op if required. |
125 | if (requiresBroadcasting) { |
126 | res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res); |
127 | } |
128 | return res; |
129 | } |
130 | |
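/// Marker attribute placed on the fallback `vector.transfer_read` cloned into
/// the scf.if below, so that this pattern does not match the clone again and
/// recurse indefinitely.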
131 | static constexpr char kTransferReadNeedsMask[] = |
132 | "amdgpu.buffer_transfer_read_needs_mask" ; |
133 | |
134 | namespace { |
135 | |
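/// Rewrites a qualifying masked `vector.transfer_read` on a fat raw buffer
/// into an `scf.if`: the else branch issues `vector.load` + `arith.select`,
/// while the then branch (taken when the read would cross the buffer boundary
/// at a non-word-aligned offset) clones the original `vector.transfer_read`
/// as a fallback.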
136 | struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> { |
137 | using OpRewritePattern::OpRewritePattern; |
138 | |
139 | LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
140 | PatternRewriter &rewriter) const override { |
141 | if (readOp->hasAttr(kTransferReadNeedsMask)) |
142 | return failure(); |
143 | |
144 | bool requiresBroadcasting = false; |
145 | VectorType unbroadcastedVectorType; |
146 | if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting, |
147 | unbroadcastedVectorType))) { |
148 | return failure(); |
149 | } |
150 | |
151 | Location loc = readOp.getLoc(); |
152 | Value src = readOp.getBase(); |
153 | |
154 | VectorType vectorType = readOp.getVectorType(); |
155 | int64_t vectorSize = vectorType.getNumElements(); |
156 | int64_t elementBitWidth = vectorType.getElementTypeBitWidth(); |
157 | SmallVector<OpFoldResult> indices = readOp.getIndices(); |
158 | |
159 | auto stridedMetadata = |
160 | rewriter.create<memref::ExtractStridedMetadataOp>(loc, src); |
161 | SmallVector<OpFoldResult> strides = |
162 | stridedMetadata.getConstifiedMixedStrides(); |
163 | SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes(); |
164 | OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset(); |
165 | memref::LinearizedMemRefInfo linearizedInfo; |
166 | OpFoldResult linearizedIndices; |
167 | std::tie(args&: linearizedInfo, args&: linearizedIndices) = |
168 | memref::getLinearizedMemRefOffsetAndSize(builder&: rewriter, loc, srcBits: elementBitWidth, |
169 | dstBits: elementBitWidth, offset, sizes, |
170 | strides, indices); |
171 | |
172 | // delta = bufferSize - linearizedOffset |
    Value vectorSizeOffset =
        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
    Value linearIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    Value totalSize = getValueOrCreateConstantIndexOp(
        rewriter, loc, linearizedInfo.linearizedSize);
179 | Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex); |
180 | |
181 | // 1) check if delta < vectorSize |
182 | Value isOutofBounds = rewriter.create<arith::CmpIOp>( |
183 | loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset); |
184 | |
    // 2) check if (delta % elements_per_word != 0)
    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::divideCeil(32, elementBitWidth));
188 | Value isNotWordAligned = rewriter.create<arith::CmpIOp>( |
189 | loc, arith::CmpIPredicate::ne, |
190 | rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord), |
191 | rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
192 | |
    // We fall back to the default transfer_read lowering only if the access is
    // both out-of-bounds and not word aligned. The fallback ensures correct
    // results when loading at the boundary of the buffer, since a buffer load
    // returns inconsistent zeros for the whole word when the boundary is
    // crossed.
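    // For example, with vector<4xf16> (vectorSize = 4, elementsPerWord = 2):
    // delta = 3 trips the fallback (3 < 4 and 3 % 2 != 0), while delta = 2
    // stays on the buffer-load path since the remaining tail is word aligned.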
197 | Value ifCondition = |
198 | rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned); |
199 | |
200 | auto thenBuilder = [&](OpBuilder &builder, Location loc) { |
201 | Operation *read = builder.clone(*readOp.getOperation()); |
202 | read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr()); |
      Value readResult = read->getResult(0);
204 | builder.create<scf::YieldOp>(loc, readResult); |
205 | }; |
206 | |
207 | auto elseBuilder = [&](OpBuilder &builder, Location loc) { |
208 | Value res = createVectorLoadForMaskedLoad( |
209 | builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType); |
210 | rewriter.create<scf::YieldOp>(loc, res); |
211 | }; |
212 | |
213 | auto ifOp = |
214 | rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder); |
215 | |
216 | rewriter.replaceOp(readOp, ifOp); |
217 | |
218 | return success(); |
219 | } |
220 | }; |
221 | |
222 | } // namespace |
223 | |
224 | void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns( |
225 | RewritePatternSet &patterns) { |
  patterns.add<TransferReadLowering>(patterns.getContext());
227 | } |
228 | |
229 | struct AmdgpuTransferReadToLoadPass final |
230 | : amdgpu::impl::AmdgpuTransferReadToLoadPassBase< |
231 | AmdgpuTransferReadToLoadPass> { |
232 | void runOnOperation() override { |
233 | RewritePatternSet patterns(&getContext()); |
234 | populateAmdgpuTransferReadToLoadPatterns(patterns); |
235 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
236 | return signalPassFailure(); |
237 | } |
238 | } |
239 | }; |
240 | |