//===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.mask' operation.
//
//===----------------------------------------------------------------------===//
| 13 | |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 17 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 18 | #include "mlir/Dialect/Vector/Transforms/Passes.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 21 | |
| 22 | #define DEBUG_TYPE "lower-vector-mask" |
| 23 | |
| 24 | namespace mlir { |
| 25 | namespace vector { |
| 26 | #define GEN_PASS_DEF_LOWERVECTORMASKPASS |
| 27 | #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
| 28 | } // namespace vector |
| 29 | } // namespace mlir |
| 30 | |
| 31 | using namespace mlir; |
| 32 | using namespace mlir::vector; |
| 33 | |
//===----------------------------------------------------------------------===//
// populateVectorMaskOpLoweringPatterns
//===----------------------------------------------------------------------===//
| 37 | |
| 38 | namespace { |
| 39 | /// Progressive lowering of CreateMaskOp. |
| 40 | /// One: |
| 41 | /// %x = vector.create_mask %a, ... : vector<dx...> |
| 42 | /// is replaced by: |
| 43 | /// %l = vector.create_mask ... : vector<...> ; one lower rank |
| 44 | /// %0 = arith.cmpi "slt", %ci, %a | |
| 45 | /// %1 = select %0, %l, %zeroes | |
| 46 | /// %r = vector.insert %1, %pr [i] | d-times |
| 47 | /// %x = .... |
| 48 | /// until a one-dimensional vector is reached. |
| 49 | class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { |
| 50 | public: |
| 51 | using OpRewritePattern::OpRewritePattern; |
| 52 | |
| 53 | LogicalResult matchAndRewrite(vector::CreateMaskOp op, |
| 54 | PatternRewriter &rewriter) const override { |
| 55 | auto dstType = cast<VectorType>(Val: op.getResult().getType()); |
| 56 | int64_t rank = dstType.getRank(); |
| 57 | if (rank <= 1) |
| 58 | return rewriter.notifyMatchFailure( |
| 59 | arg&: op, msg: "0-D and 1-D vectors are handled separately" ); |
| 60 | |
| 61 | if (dstType.getScalableDims().front()) |
| 62 | return rewriter.notifyMatchFailure( |
| 63 | arg&: op, msg: "Cannot unroll leading scalable dim in dstType" ); |
| 64 | |
| 65 | auto loc = op.getLoc(); |
| 66 | int64_t dim = dstType.getDimSize(idx: 0); |
| 67 | Value idx = op.getOperand(i: 0); |
| 68 | |
| 69 | VectorType lowType = VectorType::Builder(dstType).dropDim(pos: 0); |
| 70 | Value trueVal = rewriter.create<vector::CreateMaskOp>( |
| 71 | location: loc, args&: lowType, args: op.getOperands().drop_front()); |
| 72 | Value falseVal = rewriter.create<arith::ConstantOp>( |
| 73 | location: loc, args&: lowType, args: rewriter.getZeroAttr(type: lowType)); |
| 74 | Value result = rewriter.create<arith::ConstantOp>( |
| 75 | location: loc, args&: dstType, args: rewriter.getZeroAttr(type: dstType)); |
| 76 | for (int64_t d = 0; d < dim; d++) { |
| 77 | Value bnd = |
| 78 | rewriter.create<arith::ConstantOp>(location: loc, args: rewriter.getIndexAttr(value: d)); |
| 79 | Value val = rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::slt, |
| 80 | args&: bnd, args&: idx); |
| 81 | Value sel = rewriter.create<arith::SelectOp>(location: loc, args&: val, args&: trueVal, args&: falseVal); |
| 82 | result = rewriter.create<vector::InsertOp>(location: loc, args&: sel, args&: result, args&: d); |
| 83 | } |
| 84 | rewriter.replaceOp(op, newValues: result); |
| 85 | return success(); |
| 86 | } |
| 87 | }; |
| 88 | |
| 89 | /// Progressive lowering of ConstantMaskOp. |
| 90 | /// One: |
| 91 | /// %x = vector.constant_mask [a,b] |
| 92 | /// is replaced by: |
| 93 | /// %z = zero-result |
| 94 | /// %l = vector.constant_mask [b] |
| 95 | /// %4 = vector.insert %l, %z[0] |
| 96 | /// .. |
| 97 | /// %x = vector.insert %l, %..[a-1] |
| 98 | /// until a one-dimensional vector is reached. All these operations |
| 99 | /// will be folded at LLVM IR level. |
| 100 | class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { |
| 101 | public: |
| 102 | using OpRewritePattern::OpRewritePattern; |
| 103 | |
| 104 | LogicalResult matchAndRewrite(vector::ConstantMaskOp op, |
| 105 | PatternRewriter &rewriter) const override { |
| 106 | auto loc = op.getLoc(); |
| 107 | auto dstType = op.getType(); |
| 108 | auto dimSizes = op.getMaskDimSizes(); |
| 109 | int64_t rank = dstType.getRank(); |
| 110 | |
| 111 | if (rank == 0) { |
| 112 | assert(dimSizes.size() == 1 && |
| 113 | "Expected exactly one dim size for a 0-D vector" ); |
| 114 | bool value = dimSizes.front() == 1; |
| 115 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| 116 | op, args&: dstType, |
| 117 | args: DenseIntElementsAttr::get(type: VectorType::get(shape: {}, elementType: rewriter.getI1Type()), |
| 118 | arg&: value)); |
| 119 | return success(); |
| 120 | } |
| 121 | |
| 122 | int64_t trueDimSize = dimSizes.front(); |
| 123 | |
| 124 | if (rank == 1) { |
| 125 | if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(idx: 0)) { |
| 126 | // Use constant splat for 'all set' or 'none set' dims. |
| 127 | // This produces correct code for scalable dimensions (it will lower to |
| 128 | // a constant splat). |
| 129 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| 130 | op, args: DenseElementsAttr::get(type: dstType, value: trueDimSize != 0)); |
| 131 | } else { |
| 132 | // Express constant 1-D case in explicit vector form: |
| 133 | // [T,..,T,F,..,F]. |
| 134 | // Note: The verifier would reject this case for scalable vectors. |
| 135 | SmallVector<bool> values(dstType.getDimSize(idx: 0), false); |
| 136 | for (int64_t d = 0; d < trueDimSize; d++) |
| 137 | values[d] = true; |
| 138 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| 139 | op, args&: dstType, args: rewriter.getBoolVectorAttr(values)); |
| 140 | } |
| 141 | return success(); |
| 142 | } |
| 143 | |
| 144 | if (dstType.getScalableDims().front()) |
| 145 | return rewriter.notifyMatchFailure( |
| 146 | arg&: op, msg: "Cannot unroll leading scalable dim in dstType" ); |
| 147 | |
| 148 | VectorType lowType = VectorType::Builder(dstType).dropDim(pos: 0); |
| 149 | Value trueVal = rewriter.create<vector::ConstantMaskOp>( |
| 150 | location: loc, args&: lowType, args: dimSizes.drop_front()); |
| 151 | Value result = rewriter.create<arith::ConstantOp>( |
| 152 | location: loc, args&: dstType, args: rewriter.getZeroAttr(type: dstType)); |
| 153 | for (int64_t d = 0; d < trueDimSize; d++) |
| 154 | result = rewriter.create<vector::InsertOp>(location: loc, args&: trueVal, args&: result, args&: d); |
| 155 | |
| 156 | rewriter.replaceOp(op, newValues: result); |
| 157 | return success(); |
| 158 | } |
| 159 | }; |
| 160 | } // namespace |
| 161 | |
| 162 | void mlir::vector::populateVectorMaskOpLoweringPatterns( |
| 163 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 164 | patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( |
| 165 | arg: patterns.getContext(), args&: benefit); |
| 166 | } |
| 167 | |
//===----------------------------------------------------------------------===//
// populateVectorMaskLoweringPatternsForSideEffectingOps
//===----------------------------------------------------------------------===//
| 171 | |
| 172 | namespace { |
| 173 | |
| 174 | /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold |
| 175 | /// matching: |
| 176 | /// 1. It matches a `vector.mask` operation. |
| 177 | /// 2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested |
| 178 | /// in the matched `vector.mask` operation. |
| 179 | /// |
| 180 | /// It is required that the replacement op in the pattern replaces the |
| 181 | /// `vector.mask` operation and not the nested `MaskableOpInterface`. This |
| 182 | /// approach allows having patterns that "stop" at every `vector.mask` operation |
| 183 | /// and actually match the traits of its the nested `MaskableOpInterface`. |
| 184 | template <class SourceOp> |
| 185 | struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { |
| 186 | using OpRewritePattern<MaskOp>::OpRewritePattern; |
| 187 | |
| 188 | private: |
| 189 | LogicalResult matchAndRewrite(MaskOp maskOp, |
| 190 | PatternRewriter &rewriter) const final { |
| 191 | auto maskableOp = cast_or_null<MaskableOpInterface>(Val: maskOp.getMaskableOp()); |
| 192 | if (!maskableOp) |
| 193 | return failure(); |
| 194 | SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation()); |
| 195 | if (!sourceOp) |
| 196 | return failure(); |
| 197 | |
| 198 | return matchAndRewriteMaskableOp(sourceOp, maskingOp: maskOp, rewriter); |
| 199 | } |
| 200 | |
| 201 | protected: |
| 202 | virtual LogicalResult |
| 203 | matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, |
| 204 | PatternRewriter &rewriter) const = 0; |
| 205 | }; |
| 206 | |
| 207 | /// Lowers a masked `vector.transfer_read` operation. |
| 208 | struct MaskedTransferReadOpPattern |
| 209 | : public MaskOpRewritePattern<TransferReadOp> { |
| 210 | public: |
| 211 | using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern; |
| 212 | |
| 213 | LogicalResult |
| 214 | matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp, |
| 215 | PatternRewriter &rewriter) const override { |
| 216 | // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read' |
| 217 | // expects a scalar. We could only lower one to the other for cases where |
| 218 | // the passthru is a broadcast of a scalar. |
| 219 | if (maskingOp.hasPassthru()) |
| 220 | return rewriter.notifyMatchFailure( |
| 221 | arg&: maskingOp, msg: "Can't lower passthru to vector.transfer_read" ); |
| 222 | |
| 223 | // Replace the `vector.mask` operation. |
| 224 | rewriter.replaceOpWithNewOp<TransferReadOp>( |
| 225 | op: maskingOp.getOperation(), args: readOp.getVectorType(), args: readOp.getBase(), |
| 226 | args: readOp.getIndices(), args: readOp.getPermutationMap(), args: readOp.getPadding(), |
| 227 | args: maskingOp.getMask(), args: readOp.getInBounds()); |
| 228 | return success(); |
| 229 | } |
| 230 | }; |
| 231 | |
| 232 | /// Lowers a masked `vector.transfer_write` operation. |
| 233 | struct MaskedTransferWriteOpPattern |
| 234 | : public MaskOpRewritePattern<TransferWriteOp> { |
| 235 | public: |
| 236 | using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern; |
| 237 | |
| 238 | LogicalResult |
| 239 | matchAndRewriteMaskableOp(TransferWriteOp writeOp, |
| 240 | MaskingOpInterface maskingOp, |
| 241 | PatternRewriter &rewriter) const override { |
| 242 | Type resultType = |
| 243 | writeOp.getResult() ? writeOp.getResult().getType() : Type(); |
| 244 | |
| 245 | // Replace the `vector.mask` operation. |
| 246 | rewriter.replaceOpWithNewOp<TransferWriteOp>( |
| 247 | op: maskingOp.getOperation(), args&: resultType, args: writeOp.getVector(), |
| 248 | args: writeOp.getBase(), args: writeOp.getIndices(), args: writeOp.getPermutationMap(), |
| 249 | args: maskingOp.getMask(), args: writeOp.getInBounds()); |
| 250 | return success(); |
| 251 | } |
| 252 | }; |
| 253 | |
| 254 | /// Lowers a masked `vector.gather` operation. |
| 255 | struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> { |
| 256 | public: |
| 257 | using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern; |
| 258 | |
| 259 | LogicalResult |
| 260 | matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp, |
| 261 | PatternRewriter &rewriter) const override { |
| 262 | Value passthru = maskingOp.hasPassthru() |
| 263 | ? maskingOp.getPassthru() |
| 264 | : rewriter.create<arith::ConstantOp>( |
| 265 | location: gatherOp.getLoc(), |
| 266 | args: rewriter.getZeroAttr(type: gatherOp.getVectorType())); |
| 267 | |
| 268 | // Replace the `vector.mask` operation. |
| 269 | rewriter.replaceOpWithNewOp<GatherOp>( |
| 270 | op: maskingOp.getOperation(), args: gatherOp.getVectorType(), args: gatherOp.getBase(), |
| 271 | args: gatherOp.getIndices(), args: gatherOp.getIndexVec(), args: maskingOp.getMask(), |
| 272 | args&: passthru); |
| 273 | return success(); |
| 274 | } |
| 275 | }; |
| 276 | |
| 277 | struct LowerVectorMaskPass |
| 278 | : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> { |
| 279 | using Base::Base; |
| 280 | |
| 281 | void runOnOperation() override { |
| 282 | Operation *op = getOperation(); |
| 283 | MLIRContext *context = op->getContext(); |
| 284 | |
| 285 | RewritePatternSet loweringPatterns(context); |
| 286 | populateVectorMaskLoweringPatternsForSideEffectingOps(patterns&: loweringPatterns); |
| 287 | MaskOp::getCanonicalizationPatterns(results&: loweringPatterns, context); |
| 288 | |
| 289 | if (failed(Result: applyPatternsGreedily(op, patterns: std::move(loweringPatterns)))) |
| 290 | signalPassFailure(); |
| 291 | } |
| 292 | |
| 293 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 294 | registry.insert<vector::VectorDialect>(); |
| 295 | } |
| 296 | }; |
| 297 | |
| 298 | } // namespace |
| 299 | |
| 300 | /// Populates instances of `MaskOpRewritePattern` to lower masked operations |
| 301 | /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and |
| 302 | /// not its nested `MaskableOpInterface`. |
| 303 | void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( |
| 304 | RewritePatternSet &patterns) { |
| 305 | patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern, |
| 306 | MaskedGatherOpPattern>(arg: patterns.getContext()); |
| 307 | } |
| 308 | |
| 309 | std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() { |
| 310 | return std::make_unique<LowerVectorMaskPass>(); |
| 311 | } |
| 312 | |