| 1 | //===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/AMDGPU/Transforms/Passes.h" |
| 10 | |
| 11 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 15 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
| 16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 17 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 18 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| 19 | #include "mlir/IR/BuiltinTypes.h" |
| 20 | #include "mlir/IR/OpDefinition.h" |
| 21 | #include "mlir/IR/PatternMatch.h" |
| 22 | #include "mlir/IR/TypeUtilities.h" |
| 23 | #include "mlir/Pass/Pass.h" |
| 24 | #include "mlir/Support/LogicalResult.h" |
| 25 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 26 | #include "llvm/Support/MathExtras.h" |
| 27 | |
| 28 | namespace mlir::amdgpu { |
| 29 | #define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS |
| 30 | #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc" |
| 31 | } // namespace mlir::amdgpu |
| 32 | |
| 33 | using namespace mlir; |
| 34 | using namespace mlir::amdgpu; |
| 35 | |
| 36 | /// This pattern supports lowering of: `vector.maskedload` to `vector.load` |
| 37 | /// and `arith.select` if the memref is in buffer address space. |
| 38 | static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, |
| 39 | vector::MaskedLoadOp maskedOp) { |
| 40 | auto memRefType = dyn_cast<MemRefType>(Val: maskedOp.getBase().getType()); |
| 41 | if (!memRefType) |
| 42 | return rewriter.notifyMatchFailure(arg&: maskedOp, msg: "not a memref source" ); |
| 43 | |
| 44 | Attribute addrSpace = memRefType.getMemorySpace(); |
| 45 | if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(Val: addrSpace)) |
| 46 | return rewriter.notifyMatchFailure(arg&: maskedOp, msg: "no address space" ); |
| 47 | |
| 48 | if (dyn_cast<amdgpu::AddressSpaceAttr>(Val&: addrSpace).getValue() != |
| 49 | amdgpu::AddressSpace::FatRawBuffer) |
| 50 | return rewriter.notifyMatchFailure(arg&: maskedOp, msg: "not in buffer address space" ); |
| 51 | |
| 52 | return success(); |
| 53 | } |
| 54 | |
| 55 | static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, |
| 56 | vector::MaskedLoadOp maskedOp, |
| 57 | bool passthru) { |
| 58 | VectorType vectorType = maskedOp.getVectorType(); |
| 59 | Value load = builder.create<vector::LoadOp>( |
| 60 | location: loc, args&: vectorType, args: maskedOp.getBase(), args: maskedOp.getIndices()); |
| 61 | if (passthru) |
| 62 | load = builder.create<arith::SelectOp>(location: loc, args&: vectorType, args: maskedOp.getMask(), |
| 63 | args&: load, args: maskedOp.getPassThru()); |
| 64 | return load; |
| 65 | } |
| 66 | |
| 67 | /// Check if the given value comes from a broadcasted i1 condition. |
| 68 | static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) { |
| 69 | auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>(); |
| 70 | if (!broadcastOp) |
| 71 | return failure(); |
| 72 | if (isa<VectorType>(Val: broadcastOp.getSourceType())) |
| 73 | return failure(); |
| 74 | return broadcastOp.getSource(); |
| 75 | } |
| 76 | |
// Marker attribute placed on `vector.maskedload` ops that this pass clones
// into the out-of-bounds fallback branch; MaskedLoadLowering bails out on ops
// carrying it, preventing the pattern from recursing on its own clones.
static constexpr char kMaskedloadNeedsMask[] =
    "amdgpu.buffer_maskedload_needs_mask" ;
| 79 | |
| 80 | namespace { |
| 81 | |
| 82 | struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> { |
| 83 | using OpRewritePattern::OpRewritePattern; |
| 84 | |
| 85 | LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp, |
| 86 | PatternRewriter &rewriter) const override { |
| 87 | if (maskedOp->hasAttr(name: kMaskedloadNeedsMask)) |
| 88 | return failure(); |
| 89 | |
| 90 | if (failed(Result: baseInBufferAddrSpace(rewriter, maskedOp))) { |
| 91 | return failure(); |
| 92 | } |
| 93 | |
| 94 | // Check if this is either a full inbounds load or an empty, oob load. If |
| 95 | // so, take the fast path and don't generate an if condition, because we |
| 96 | // know doing the oob load is always safe. |
| 97 | if (succeeded(Result: matchFullMask(b&: rewriter, val: maskedOp.getMask()))) { |
| 98 | Value load = createVectorLoadForMaskedLoad(builder&: rewriter, loc: maskedOp.getLoc(), |
| 99 | maskedOp, /*passthru=*/true); |
| 100 | rewriter.replaceOp(op: maskedOp, newValues: load); |
| 101 | return success(); |
| 102 | } |
| 103 | |
| 104 | Location loc = maskedOp.getLoc(); |
| 105 | Value src = maskedOp.getBase(); |
| 106 | |
| 107 | VectorType vectorType = maskedOp.getVectorType(); |
| 108 | int64_t vectorSize = vectorType.getNumElements(); |
| 109 | int64_t elementBitWidth = vectorType.getElementTypeBitWidth(); |
| 110 | SmallVector<OpFoldResult> indices = maskedOp.getIndices(); |
| 111 | |
| 112 | auto stridedMetadata = |
| 113 | rewriter.create<memref::ExtractStridedMetadataOp>(location: loc, args&: src); |
| 114 | SmallVector<OpFoldResult> strides = |
| 115 | stridedMetadata.getConstifiedMixedStrides(); |
| 116 | SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes(); |
| 117 | OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset(); |
| 118 | memref::LinearizedMemRefInfo linearizedInfo; |
| 119 | OpFoldResult linearizedIndices; |
| 120 | std::tie(args&: linearizedInfo, args&: linearizedIndices) = |
| 121 | memref::getLinearizedMemRefOffsetAndSize(builder&: rewriter, loc, srcBits: elementBitWidth, |
| 122 | dstBits: elementBitWidth, offset, sizes, |
| 123 | strides, indices); |
| 124 | |
| 125 | // delta = bufferSize - linearizedOffset |
| 126 | Value vectorSizeOffset = |
| 127 | rewriter.create<arith::ConstantIndexOp>(location: loc, args&: vectorSize); |
| 128 | Value linearIndex = |
| 129 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: linearizedIndices); |
| 130 | Value totalSize = getValueOrCreateConstantIndexOp( |
| 131 | b&: rewriter, loc, ofr: linearizedInfo.linearizedSize); |
| 132 | Value delta = rewriter.create<arith::SubIOp>(location: loc, args&: totalSize, args&: linearIndex); |
| 133 | |
| 134 | // 1) check if delta < vectorSize |
| 135 | Value isOutofBounds = rewriter.create<arith::CmpIOp>( |
| 136 | location: loc, args: arith::CmpIPredicate::ult, args&: delta, args&: vectorSizeOffset); |
| 137 | |
| 138 | // 2) check if (detla % elements_per_word != 0) |
| 139 | Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>( |
| 140 | location: loc, args: llvm::divideCeil(Numerator: 32, Denominator: elementBitWidth)); |
| 141 | Value isNotWordAligned = rewriter.create<arith::CmpIOp>( |
| 142 | location: loc, args: arith::CmpIPredicate::ne, |
| 143 | args: rewriter.create<arith::RemUIOp>(location: loc, args&: delta, args&: elementsPerWord), |
| 144 | args: rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)); |
| 145 | |
| 146 | // We take the fallback of maskedload default lowering only it is both |
| 147 | // out-of-bounds and not word aligned. The fallback ensures correct results |
| 148 | // when loading at the boundary of the buffer since buffer load returns |
| 149 | // inconsistent zeros for the whole word when boundary is crossed. |
| 150 | Value ifCondition = |
| 151 | rewriter.create<arith::AndIOp>(location: loc, args&: isOutofBounds, args&: isNotWordAligned); |
| 152 | |
| 153 | auto thenBuilder = [&](OpBuilder &builder, Location loc) { |
| 154 | Operation *read = builder.clone(op&: *maskedOp.getOperation()); |
| 155 | read->setAttr(name: kMaskedloadNeedsMask, value: builder.getUnitAttr()); |
| 156 | Value readResult = read->getResult(idx: 0); |
| 157 | builder.create<scf::YieldOp>(location: loc, args&: readResult); |
| 158 | }; |
| 159 | |
| 160 | auto elseBuilder = [&](OpBuilder &builder, Location loc) { |
| 161 | Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp, |
| 162 | /*passthru=*/true); |
| 163 | rewriter.create<scf::YieldOp>(location: loc, args&: res); |
| 164 | }; |
| 165 | |
| 166 | auto ifOp = |
| 167 | rewriter.create<scf::IfOp>(location: loc, args&: ifCondition, args&: thenBuilder, args&: elseBuilder); |
| 168 | |
| 169 | rewriter.replaceOp(op: maskedOp, newOp: ifOp); |
| 170 | |
| 171 | return success(); |
| 172 | } |
| 173 | }; |
| 174 | |
| 175 | struct FullMaskedLoadToConditionalLoad |
| 176 | : OpRewritePattern<vector::MaskedLoadOp> { |
| 177 | using OpRewritePattern::OpRewritePattern; |
| 178 | |
| 179 | LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp, |
| 180 | PatternRewriter &rewriter) const override { |
| 181 | FailureOr<Value> maybeCond = matchFullMask(b&: rewriter, val: loadOp.getMask()); |
| 182 | if (failed(Result: maybeCond)) { |
| 183 | return failure(); |
| 184 | } |
| 185 | |
| 186 | Value cond = maybeCond.value(); |
| 187 | auto trueBuilder = [&](OpBuilder &builder, Location loc) { |
| 188 | Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp: loadOp, |
| 189 | /*passthru=*/false); |
| 190 | rewriter.create<scf::YieldOp>(location: loc, args&: res); |
| 191 | }; |
| 192 | auto falseBuilder = [&](OpBuilder &builder, Location loc) { |
| 193 | rewriter.create<scf::YieldOp>(location: loc, args: loadOp.getPassThru()); |
| 194 | }; |
| 195 | auto ifOp = rewriter.create<scf::IfOp>(location: loadOp.getLoc(), args&: cond, args&: trueBuilder, |
| 196 | args&: falseBuilder); |
| 197 | rewriter.replaceOp(op: loadOp, newOp: ifOp); |
| 198 | return success(); |
| 199 | } |
| 200 | }; |
| 201 | |
| 202 | struct FullMaskedStoreToConditionalStore |
| 203 | : OpRewritePattern<vector::MaskedStoreOp> { |
| 204 | using OpRewritePattern::OpRewritePattern; |
| 205 | |
| 206 | LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp, |
| 207 | PatternRewriter &rewriter) const override { |
| 208 | FailureOr<Value> maybeCond = matchFullMask(b&: rewriter, val: storeOp.getMask()); |
| 209 | if (failed(Result: maybeCond)) { |
| 210 | return failure(); |
| 211 | } |
| 212 | Value cond = maybeCond.value(); |
| 213 | |
| 214 | auto trueBuilder = [&](OpBuilder &builder, Location loc) { |
| 215 | rewriter.create<vector::StoreOp>(location: loc, args: storeOp.getValueToStore(), |
| 216 | args: storeOp.getBase(), args: storeOp.getIndices()); |
| 217 | rewriter.create<scf::YieldOp>(location: loc); |
| 218 | }; |
| 219 | auto ifOp = rewriter.create<scf::IfOp>(location: storeOp.getLoc(), args&: cond, args&: trueBuilder); |
| 220 | rewriter.replaceOp(op: storeOp, newOp: ifOp); |
| 221 | return success(); |
| 222 | } |
| 223 | }; |
| 224 | |
| 225 | } // namespace |
| 226 | |
| 227 | void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns( |
| 228 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 229 | patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad, |
| 230 | FullMaskedStoreToConditionalStore>(arg: patterns.getContext(), |
| 231 | args&: benefit); |
| 232 | } |
| 233 | |
| 234 | struct AmdgpuMaskedloadToLoadPass final |
| 235 | : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> { |
| 236 | void runOnOperation() override { |
| 237 | RewritePatternSet patterns(&getContext()); |
| 238 | populateAmdgpuMaskedloadToLoadPatterns(patterns); |
| 239 | if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) { |
| 240 | return signalPassFailure(); |
| 241 | } |
| 242 | } |
| 243 | }; |
| 244 | |