//===- TransferReadToLoad.cpp - Lowers masked transfer read to load -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

/// This pattern supports lowering of:
/// `vector.transfer_read` to a combination of `vector.load`, `arith.select`
/// and `vector.broadcast` if all of the following hold:
/// - The transfer op is masked.
/// - The memref is in the buffer address space.
/// - The stride of the most minor memref dimension is 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is
///   allowed).
/// Note: these conditions mostly come from the
/// TransferReadToVectorLoadLowering pass.
static LogicalResult transferPreconditions(
    PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
    bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
  if (!xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "Only support masked transfer");

  // Permutations are handled by VectorToSCF or
  // populateVectorTransferPermutationMapLoweringPatterns.
  // We let the 0-d corner case pass through as it is supported.
  SmallVector<unsigned> broadcastedDims;
  if (!xferOp.getPermutationMap().isMinorIdentityWithBroadcasting(
          &broadcastedDims))
    return rewriter.notifyMatchFailure(xferOp, "not minor identity + bcast");

  auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!memRefType)
    return rewriter.notifyMatchFailure(xferOp, "not a memref source");

  Attribute addrSpace = memRefType.getMemorySpace();
  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
    return rewriter.notifyMatchFailure(xferOp, "no address space");

  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
      amdgpu::AddressSpace::FatRawBuffer)
    return rewriter.notifyMatchFailure(xferOp, "not in buffer address space");

  // Non-unit strides are handled by VectorToSCF.
  if (!memRefType.isLastDimUnitStride())
    return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

  if (memRefType.getElementTypeBitWidth() < 8)
    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");

  // If there is broadcasting involved then we first load the unbroadcasted
  // vector, and then broadcast it with `vector.broadcast`.
  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
  SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
  for (unsigned i : broadcastedDims)
    unbroadcastedVectorShape[i] = 1;
  unbroadcastedVectorType = xferOp.getVectorType().cloneWith(
      unbroadcastedVectorShape, xferOp.getVectorType().getElementType());
  requiresBroadcasting = !broadcastedDims.empty();

  // `vector.load` supports vector types as memref's elements only when the
  // resulting vector type is the same as the element type.
  auto memrefElTy = memRefType.getElementType();
  if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
    return rewriter.notifyMatchFailure(xferOp, "incompatible element type");

  // Otherwise, element types of the memref and the vector must match.
  if (!isa<VectorType>(memrefElTy) &&
      memrefElTy != xferOp.getVectorType().getElementType())
    return rewriter.notifyMatchFailure(xferOp, "non-matching element type");

  // Out-of-bounds dims are handled by MaterializeTransferMask.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out-of-bounds needs mask");

  if (xferOp.getVectorType().getRank() != 1)
    // vector.maskedload operates on 1-D vectors.
    return rewriter.notifyMatchFailure(
        xferOp, "vector type is not rank 1, can't create masked load, needs "
                "VectorToSCF");

  return success();
}

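/// Build the unmasked emulation of the transfer: splat the padding value,
/// load the unbroadcasted vector with a plain `vector.load`, select between
/// the loaded value and the padding splat using the mask, and broadcast the
/// result back to the requested vector type if needed.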
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
                                           vector::TransferReadOp readOp,
                                           bool requiresBroadcasting,
                                           VectorType unbroadcastedVectorType) {
  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
                                               readOp.getPadding());
  Value load = builder.create<vector::LoadOp>(
      loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());
  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
                                              readOp.getMask(), load, fill);
  // Insert a broadcasting op if required.
  if (requiresBroadcasting) {
    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
  }
  return res;
}

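/// Attribute used to tag `vector.transfer_read` ops that must keep their mask
/// (the fallback clone emitted inside the `scf.if` below), so that this
/// pattern does not match them again.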
static constexpr char kTransferReadNeedsMask[] =
    "amdgpu.buffer_transfer_read_needs_mask";

namespace {

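/// Rewrites a masked `vector.transfer_read` on a fat raw buffer into an
/// `scf.if`: the "then" branch clones the original transfer_read (tagged with
/// `kTransferReadNeedsMask`) for the case where the read both runs past the
/// end of the buffer and is not word aligned, while the "else" branch uses the
/// cheaper unmasked load + select emulation built by
/// `createVectorLoadForMaskedLoad`. Schematically (illustrative IR only,
/// names abbreviated):
///
///   %cond = arith.andi %is_oob, %is_not_word_aligned
///   %res = scf.if %cond -> (vector<...>) {
///     %r = vector.transfer_read ... {amdgpu.buffer_transfer_read_needs_mask}
///     scf.yield %r
///   } else {
///     %fill = vector.splat %padding
///     %ld   = vector.load %base[%indices]
///     %sel  = arith.select %mask, %ld, %fill
///     scf.yield %sel
///   }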
struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp->hasAttr(kTransferReadNeedsMask))
      return failure();

    bool requiresBroadcasting = false;
    VectorType unbroadcastedVectorType;
    if (failed(transferPreconditions(rewriter, readOp, requiresBroadcasting,
                                     unbroadcastedVectorType))) {
      return failure();
    }

    Location loc = readOp.getLoc();
    Value src = readOp.getBase();

    VectorType vectorType = readOp.getVectorType();
    int64_t vectorSize = vectorType.getNumElements();
    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
    SmallVector<OpFoldResult> indices(readOp.getIndices().begin(),
                                      readOp.getIndices().end());

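    // Linearize the access: derive the flat element offset of the read and
    // the total linearized size of the buffer from the memref's strided
    // metadata.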
    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
    SmallVector<OpFoldResult> strides =
        stridedMetadata.getConstifiedMixedStrides();
    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
    memref::LinearizedMemRefInfo linearizedInfo;
    OpFoldResult linearizedIndices;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc,
                                                 elementBitWidth,
                                                 elementBitWidth, offset,
                                                 sizes, strides, indices);

    // delta = bufferSize - linearizedOffset
    Value vectorSizeOffset =
        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
    Value linearIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    Value totalSize = getValueOrCreateConstantIndexOp(
        rewriter, loc, linearizedInfo.linearizedSize);
    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);

    // 1) check if delta < vectorSize
    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);

    // 2) check if (delta % elementsPerWord != 0)
    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
        loc, llvm::divideCeil(32, elementBitWidth));
    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne,
        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));

    // We take the fallback of the transfer_read default lowering only if the
    // read is both out-of-bounds and not word aligned. The fallback ensures
    // correct results when loading at the boundary of the buffer, since a
    // buffer load returns inconsistent zeros for the whole word when the
    // boundary is crossed.
    Value ifCondition =
        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);

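    // Slow path: keep the original masked transfer_read, tagged so this
    // pattern does not match it again.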
    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
      Operation *read = builder.clone(*readOp.getOperation());
      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
      Value readResult = read->getResult(0);
      builder.create<scf::YieldOp>(loc, readResult);
    };

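    // Fast path: emulate the masked transfer with an unmasked vector.load
    // plus a select on the mask.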
    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
      Value res = createVectorLoadForMaskedLoad(
          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
      builder.create<scf::YieldOp>(loc, res);
    };

    auto ifOp =
        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);

    rewriter.replaceOp(readOp, ifOp);

    return success();
  }
};

} // namespace

void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering>(patterns.getContext());
}

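/// Pass that greedily applies the transfer_read-to-load rewrite pattern above
/// to its target operation.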
struct AmdgpuTransferReadToLoadPass final
    : amdgpu::impl::AmdgpuTransferReadToLoadPassBase<
          AmdgpuTransferReadToLoadPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateAmdgpuTransferReadToLoadPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};