| 1 | //===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements target-independent rewrites and utilities to lower the |
| 10 | // 'vector.mask' operation. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 17 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 18 | #include "mlir/Dialect/Vector/Transforms/Passes.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 21 | |
| 22 | #define DEBUG_TYPE "lower-vector-mask" |
| 23 | |
| 24 | namespace mlir { |
| 25 | namespace vector { |
| 26 | #define GEN_PASS_DEF_LOWERVECTORMASKPASS |
| 27 | #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
| 28 | } // namespace vector |
| 29 | } // namespace mlir |
| 30 | |
| 31 | using namespace mlir; |
| 32 | using namespace mlir::vector; |
| 33 | |
| 34 | //===----------------------------------------------------------------------===// |
| 35 | // populateVectorMaskOpLoweringPatterns |
| 36 | //===----------------------------------------------------------------------===// |
| 37 | |
| 38 | namespace { |
| 39 | /// Progressive lowering of CreateMaskOp. |
| 40 | /// One: |
| 41 | /// %x = vector.create_mask %a, ... : vector<dx...> |
| 42 | /// is replaced by: |
| 43 | /// %l = vector.create_mask ... : vector<...> ; one lower rank |
| 44 | /// %0 = arith.cmpi "slt", %ci, %a | |
| 45 | /// %1 = select %0, %l, %zeroes | |
| 46 | /// %r = vector.insert %1, %pr [i] | d-times |
| 47 | /// %x = .... |
| 48 | /// until a one-dimensional vector is reached. |
| 49 | class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { |
| 50 | public: |
| 51 | using OpRewritePattern::OpRewritePattern; |
| 52 | |
| 53 | LogicalResult matchAndRewrite(vector::CreateMaskOp op, |
| 54 | PatternRewriter &rewriter) const override { |
| 55 | auto dstType = cast<VectorType>(op.getResult().getType()); |
| 56 | int64_t rank = dstType.getRank(); |
| 57 | if (rank <= 1) |
| 58 | return rewriter.notifyMatchFailure( |
| 59 | op, "0-D and 1-D vectors are handled separately" ); |
| 60 | |
| 61 | if (dstType.getScalableDims().front()) |
| 62 | return rewriter.notifyMatchFailure( |
| 63 | op, "Cannot unroll leading scalable dim in dstType" ); |
| 64 | |
| 65 | auto loc = op.getLoc(); |
| 66 | int64_t dim = dstType.getDimSize(0); |
| 67 | Value idx = op.getOperand(0); |
| 68 | |
| 69 | VectorType lowType = VectorType::Builder(dstType).dropDim(0); |
| 70 | Value trueVal = rewriter.create<vector::CreateMaskOp>( |
| 71 | loc, lowType, op.getOperands().drop_front()); |
| 72 | Value falseVal = rewriter.create<arith::ConstantOp>( |
| 73 | loc, lowType, rewriter.getZeroAttr(lowType)); |
| 74 | Value result = rewriter.create<arith::ConstantOp>( |
| 75 | loc, dstType, rewriter.getZeroAttr(dstType)); |
| 76 | for (int64_t d = 0; d < dim; d++) { |
| 77 | Value bnd = |
| 78 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d)); |
| 79 | Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, |
| 80 | bnd, idx); |
| 81 | Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal); |
| 82 | result = rewriter.create<vector::InsertOp>(loc, sel, result, d); |
| 83 | } |
| 84 | rewriter.replaceOp(op, result); |
| 85 | return success(); |
| 86 | } |
| 87 | }; |
| 88 | |
| 89 | /// Progressive lowering of ConstantMaskOp. |
| 90 | /// One: |
| 91 | /// %x = vector.constant_mask [a,b] |
| 92 | /// is replaced by: |
| 93 | /// %z = zero-result |
| 94 | /// %l = vector.constant_mask [b] |
| 95 | /// %4 = vector.insert %l, %z[0] |
| 96 | /// .. |
| 97 | /// %x = vector.insert %l, %..[a-1] |
| 98 | /// until a one-dimensional vector is reached. All these operations |
| 99 | /// will be folded at LLVM IR level. |
| 100 | class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { |
| 101 | public: |
| 102 | using OpRewritePattern::OpRewritePattern; |
| 103 | |
| 104 | LogicalResult matchAndRewrite(vector::ConstantMaskOp op, |
| 105 | PatternRewriter &rewriter) const override { |
| 106 | auto loc = op.getLoc(); |
| 107 | auto dstType = op.getType(); |
| 108 | auto dimSizes = op.getMaskDimSizes(); |
| 109 | int64_t rank = dstType.getRank(); |
| 110 | |
| 111 | if (rank == 0) { |
| 112 | assert(dimSizes.size() == 1 && |
| 113 | "Expected exactly one dim size for a 0-D vector" ); |
| 114 | bool value = dimSizes.front() == 1; |
| 115 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| 116 | op, dstType, |
| 117 | DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()), |
| 118 | value)); |
| 119 | return success(); |
| 120 | } |
| 121 | |
| 122 | int64_t trueDimSize = dimSizes.front(); |
| 123 | |
| 124 | if (rank == 1) { |
| 125 | if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) { |
| 126 | // Use constant splat for 'all set' or 'none set' dims. |
| 127 | // This produces correct code for scalable dimensions (it will lower to |
| 128 | // a constant splat). |
| 129 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| 130 | op, DenseElementsAttr::get(dstType, trueDimSize != 0)); |
| 131 | } else { |
| 132 | // Express constant 1-D case in explicit vector form: |
| 133 | // [T,..,T,F,..,F]. |
| 134 | // Note: The verifier would reject this case for scalable vectors. |
| 135 | SmallVector<bool> values(dstType.getDimSize(0), false); |
| 136 | for (int64_t d = 0; d < trueDimSize; d++) |
| 137 | values[d] = true; |
| 138 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
| 139 | op, dstType, rewriter.getBoolVectorAttr(values)); |
| 140 | } |
| 141 | return success(); |
| 142 | } |
| 143 | |
| 144 | if (dstType.getScalableDims().front()) |
| 145 | return rewriter.notifyMatchFailure( |
| 146 | op, "Cannot unroll leading scalable dim in dstType" ); |
| 147 | |
| 148 | VectorType lowType = VectorType::Builder(dstType).dropDim(0); |
| 149 | Value trueVal = rewriter.create<vector::ConstantMaskOp>( |
| 150 | loc, lowType, dimSizes.drop_front()); |
| 151 | Value result = rewriter.create<arith::ConstantOp>( |
| 152 | loc, dstType, rewriter.getZeroAttr(dstType)); |
| 153 | for (int64_t d = 0; d < trueDimSize; d++) |
| 154 | result = rewriter.create<vector::InsertOp>(loc, trueVal, result, d); |
| 155 | |
| 156 | rewriter.replaceOp(op, result); |
| 157 | return success(); |
| 158 | } |
| 159 | }; |
| 160 | } // namespace |
| 161 | |
| 162 | void mlir::vector::populateVectorMaskOpLoweringPatterns( |
| 163 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 164 | patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( |
| 165 | arg: patterns.getContext(), args&: benefit); |
| 166 | } |
| 167 | |
| 168 | //===----------------------------------------------------------------------===// |
| 169 | // populateVectorMaskLoweringPatternsForSideEffectingOps |
| 170 | //===----------------------------------------------------------------------===// |
| 171 | |
| 172 | namespace { |
| 173 | |
/// The `MaskOpRewritePattern` implements a pattern that follows a two-fold
/// matching:
///   1. It matches a `vector.mask` operation.
///   2. It invokes `matchAndRewriteMaskableOp` on `MaskableOpInterface` nested
///      in the matched `vector.mask` operation.
///
/// It is required that the replacement op in the pattern replaces the
/// `vector.mask` operation and not the nested `MaskableOpInterface`. This
/// approach allows having patterns that "stop" at every `vector.mask` operation
/// and actually match the traits of its the nested `MaskableOpInterface`.
template <class SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
  using OpRewritePattern<MaskOp>::OpRewritePattern;

private:
  // `final`: subclasses customize behavior only through the protected hook
  // below, never by overriding the match entry point itself.
  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const final {
    // The mask region may be empty or hold an op that does not implement
    // MaskableOpInterface; bail out in either case.
    auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
    if (!maskableOp)
      return failure();
    // Only fire when the nested op is the SourceOp this instantiation targets.
    SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
    if (!sourceOp)
      return failure();

    return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
  }

protected:
  // Hook implemented by subclasses. Must replace `maskingOp` (the enclosing
  // `vector.mask`), not `sourceOp`.
  virtual LogicalResult
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};
| 206 | |
| 207 | /// Lowers a masked `vector.transfer_read` operation. |
| 208 | struct MaskedTransferReadOpPattern |
| 209 | : public MaskOpRewritePattern<TransferReadOp> { |
| 210 | public: |
| 211 | using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern; |
| 212 | |
| 213 | LogicalResult |
| 214 | matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp, |
| 215 | PatternRewriter &rewriter) const override { |
| 216 | // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read' |
| 217 | // expects a scalar. We could only lower one to the other for cases where |
| 218 | // the passthru is a broadcast of a scalar. |
| 219 | if (maskingOp.hasPassthru()) |
| 220 | return rewriter.notifyMatchFailure( |
| 221 | maskingOp, "Can't lower passthru to vector.transfer_read" ); |
| 222 | |
| 223 | // Replace the `vector.mask` operation. |
| 224 | rewriter.replaceOpWithNewOp<TransferReadOp>( |
| 225 | maskingOp.getOperation(), readOp.getVectorType(), readOp.getBase(), |
| 226 | readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), |
| 227 | maskingOp.getMask(), readOp.getInBounds()); |
| 228 | return success(); |
| 229 | } |
| 230 | }; |
| 231 | |
| 232 | /// Lowers a masked `vector.transfer_write` operation. |
| 233 | struct MaskedTransferWriteOpPattern |
| 234 | : public MaskOpRewritePattern<TransferWriteOp> { |
| 235 | public: |
| 236 | using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern; |
| 237 | |
| 238 | LogicalResult |
| 239 | matchAndRewriteMaskableOp(TransferWriteOp writeOp, |
| 240 | MaskingOpInterface maskingOp, |
| 241 | PatternRewriter &rewriter) const override { |
| 242 | Type resultType = |
| 243 | writeOp.getResult() ? writeOp.getResult().getType() : Type(); |
| 244 | |
| 245 | // Replace the `vector.mask` operation. |
| 246 | rewriter.replaceOpWithNewOp<TransferWriteOp>( |
| 247 | maskingOp.getOperation(), resultType, writeOp.getVector(), |
| 248 | writeOp.getBase(), writeOp.getIndices(), writeOp.getPermutationMap(), |
| 249 | maskingOp.getMask(), writeOp.getInBounds()); |
| 250 | return success(); |
| 251 | } |
| 252 | }; |
| 253 | |
| 254 | /// Lowers a masked `vector.gather` operation. |
| 255 | struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> { |
| 256 | public: |
| 257 | using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern; |
| 258 | |
| 259 | LogicalResult |
| 260 | matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp, |
| 261 | PatternRewriter &rewriter) const override { |
| 262 | Value passthru = maskingOp.hasPassthru() |
| 263 | ? maskingOp.getPassthru() |
| 264 | : rewriter.create<arith::ConstantOp>( |
| 265 | gatherOp.getLoc(), |
| 266 | rewriter.getZeroAttr(gatherOp.getVectorType())); |
| 267 | |
| 268 | // Replace the `vector.mask` operation. |
| 269 | rewriter.replaceOpWithNewOp<GatherOp>( |
| 270 | maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), |
| 271 | gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), |
| 272 | passthru); |
| 273 | return success(); |
| 274 | } |
| 275 | }; |
| 276 | |
| 277 | struct LowerVectorMaskPass |
| 278 | : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> { |
| 279 | using Base::Base; |
| 280 | |
| 281 | void runOnOperation() override { |
| 282 | Operation *op = getOperation(); |
| 283 | MLIRContext *context = op->getContext(); |
| 284 | |
| 285 | RewritePatternSet loweringPatterns(context); |
| 286 | populateVectorMaskLoweringPatternsForSideEffectingOps(patterns&: loweringPatterns); |
| 287 | MaskOp::getCanonicalizationPatterns(loweringPatterns, context); |
| 288 | |
| 289 | if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) |
| 290 | signalPassFailure(); |
| 291 | } |
| 292 | |
| 293 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 294 | registry.insert<vector::VectorDialect>(); |
| 295 | } |
| 296 | }; |
| 297 | |
| 298 | } // namespace |
| 299 | |
| 300 | /// Populates instances of `MaskOpRewritePattern` to lower masked operations |
| 301 | /// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and |
| 302 | /// not its nested `MaskableOpInterface`. |
| 303 | void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( |
| 304 | RewritePatternSet &patterns) { |
| 305 | patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern, |
| 306 | MaskedGatherOpPattern>(arg: patterns.getContext()); |
| 307 | } |
| 308 | |
/// Factory returning an owning pointer to a fresh LowerVectorMaskPass
/// instance (default pass options).
std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
  return std::make_unique<LowerVectorMaskPass>();
}
| 312 | |