| 1 | //===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 11 | #include "mlir/Dialect/LLVMIR/VCIXDialect.h" |
| 12 | #include "mlir/Dialect/Math/IR/Math.h" |
| 13 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 14 | #include "mlir/IR/PatternMatch.h" |
| 15 | #include "mlir/Pass/Pass.h" |
| 16 | #include "mlir/Pass/PassManager.h" |
| 17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 18 | |
| 19 | namespace mlir { |
| 20 | namespace { |
| 21 | |
| 22 | /// Return number of extracts required to make input VectorType \vt legal and |
| 23 | /// also return thatlegal vector type. |
| 24 | /// For fixed vectors nothing special is needed. Scalable vectors are legalizes |
| 25 | /// according to LLVM's encoding: |
| 26 | /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html |
| 27 | static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) { |
| 28 | VectorType vt = dyn_cast<VectorType>(type); |
| 29 | // To simplify test pass, avoid multi-dimensional vectors. |
| 30 | if (!vt || vt.getRank() != 1) |
| 31 | return {0, nullptr}; |
| 32 | |
| 33 | if (!vt.isScalable()) |
| 34 | return {1, vt}; |
| 35 | |
| 36 | Type eltTy = vt.getElementType(); |
| 37 | unsigned sew = 0; |
| 38 | if (eltTy.isF32()) |
| 39 | sew = 32; |
| 40 | else if (eltTy.isF64()) |
| 41 | sew = 64; |
| 42 | else if (auto intTy = dyn_cast<IntegerType>(eltTy)) |
| 43 | sew = intTy.getWidth(); |
| 44 | else |
| 45 | return {0, nullptr}; |
| 46 | |
| 47 | unsigned eltCount = vt.getShape()[0]; |
| 48 | const unsigned lmul = eltCount * sew / 64; |
| 49 | |
| 50 | unsigned n = lmul > 8 ? llvm::Log2_32(Value: lmul) - 2 : 1; |
| 51 | return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})}; |
| 52 | } |
| 53 | |
| 54 | /// Replace math.cos(v) operation with vcix.v.iv(v). |
| 55 | struct MathCosToVCIX final : OpRewritePattern<math::CosOp> { |
| 56 | using OpRewritePattern::OpRewritePattern; |
| 57 | |
| 58 | LogicalResult matchAndRewrite(math::CosOp op, |
| 59 | PatternRewriter &rewriter) const override { |
| 60 | const Type opType = op.getOperand().getType(); |
| 61 | auto [n, legalType] = legalizeVectorType(opType); |
| 62 | if (!legalType) |
| 63 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
| 64 | Location loc = op.getLoc(); |
| 65 | Value vec = op.getOperand(); |
| 66 | Attribute immAttr = rewriter.getI32IntegerAttr(0); |
| 67 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| 68 | Value rvl = nullptr; |
| 69 | if (legalType.isScalable()) |
| 70 | // Use arbitrary runtime vector length when vector type is scalable. |
| 71 | // Proper conversion pass should take it from the IR. |
| 72 | rvl = rewriter.create<arith::ConstantOp>(loc, |
| 73 | rewriter.getI64IntegerAttr(9)); |
| 74 | Value res; |
| 75 | if (n == 1) { |
| 76 | res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec, |
| 77 | immAttr, rvl); |
| 78 | } else { |
| 79 | const unsigned eltCount = legalType.getShape()[0]; |
| 80 | Type eltTy = legalType.getElementType(); |
| 81 | Value zero = rewriter.create<arith::ConstantOp>( |
| 82 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| 83 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| 84 | for (unsigned i = 0; i < n; ++i) { |
| 85 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| 86 | loc, legalType, vec, i * eltCount); |
| 87 | Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, |
| 88 | extracted, immAttr, rvl); |
| 89 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| 90 | i * eltCount); |
| 91 | } |
| 92 | } |
| 93 | rewriter.replaceOp(op, res); |
| 94 | return success(); |
| 95 | } |
| 96 | }; |
| 97 | |
| 98 | // Replace math.sin(v) operation with vcix.v.sv(v, v). |
| 99 | struct MathSinToVCIX final : OpRewritePattern<math::SinOp> { |
| 100 | using OpRewritePattern::OpRewritePattern; |
| 101 | |
| 102 | LogicalResult matchAndRewrite(math::SinOp op, |
| 103 | PatternRewriter &rewriter) const override { |
| 104 | const Type opType = op.getOperand().getType(); |
| 105 | auto [n, legalType] = legalizeVectorType(opType); |
| 106 | if (!legalType) |
| 107 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
| 108 | Location loc = op.getLoc(); |
| 109 | Value vec = op.getOperand(); |
| 110 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| 111 | Value rvl = nullptr; |
| 112 | if (legalType.isScalable()) |
| 113 | // Use arbitrary runtime vector length when vector type is scalable. |
| 114 | // Proper conversion pass should take it from the IR. |
| 115 | rvl = rewriter.create<arith::ConstantOp>(loc, |
| 116 | rewriter.getI64IntegerAttr(9)); |
| 117 | Value res; |
| 118 | if (n == 1) { |
| 119 | res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
| 120 | vec, rvl); |
| 121 | } else { |
| 122 | const unsigned eltCount = legalType.getShape()[0]; |
| 123 | Type eltTy = legalType.getElementType(); |
| 124 | Value zero = rewriter.create<arith::ConstantOp>( |
| 125 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| 126 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| 127 | for (unsigned i = 0; i < n; ++i) { |
| 128 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| 129 | loc, legalType, vec, i * eltCount); |
| 130 | Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
| 131 | extracted, extracted, rvl); |
| 132 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| 133 | i * eltCount); |
| 134 | } |
| 135 | } |
| 136 | rewriter.replaceOp(op, res); |
| 137 | return success(); |
| 138 | } |
| 139 | }; |
| 140 | |
| 141 | // Replace math.tan(v) operation with vcix.v.sv(v, 0.0f). |
| 142 | struct MathTanToVCIX final : OpRewritePattern<math::TanOp> { |
| 143 | using OpRewritePattern::OpRewritePattern; |
| 144 | |
| 145 | LogicalResult matchAndRewrite(math::TanOp op, |
| 146 | PatternRewriter &rewriter) const override { |
| 147 | const Type opType = op.getOperand().getType(); |
| 148 | auto [n, legalType] = legalizeVectorType(opType); |
| 149 | Type eltTy = legalType.getElementType(); |
| 150 | if (!legalType) |
| 151 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
| 152 | Location loc = op.getLoc(); |
| 153 | Value vec = op.getOperand(); |
| 154 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| 155 | Value zero = rewriter.create<arith::ConstantOp>( |
| 156 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| 157 | Value rvl = nullptr; |
| 158 | if (legalType.isScalable()) |
| 159 | // Use arbitrary runtime vector length when vector type is scalable. |
| 160 | // Proper conversion pass should take it from the IR. |
| 161 | rvl = rewriter.create<arith::ConstantOp>(loc, |
| 162 | rewriter.getI64IntegerAttr(9)); |
| 163 | Value res; |
| 164 | if (n == 1) { |
| 165 | res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
| 166 | zero, rvl); |
| 167 | } else { |
| 168 | const unsigned eltCount = legalType.getShape()[0]; |
| 169 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| 170 | for (unsigned i = 0; i < n; ++i) { |
| 171 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| 172 | loc, legalType, vec, i * eltCount); |
| 173 | Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
| 174 | extracted, zero, rvl); |
| 175 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| 176 | i * eltCount); |
| 177 | } |
| 178 | } |
| 179 | rewriter.replaceOp(op, res); |
| 180 | return success(); |
| 181 | } |
| 182 | }; |
| 183 | |
| 184 | // Replace math.log(v) operation with vcix.v.sv(v, 0). |
| 185 | struct MathLogToVCIX final : OpRewritePattern<math::LogOp> { |
| 186 | using OpRewritePattern::OpRewritePattern; |
| 187 | |
| 188 | LogicalResult matchAndRewrite(math::LogOp op, |
| 189 | PatternRewriter &rewriter) const override { |
| 190 | const Type opType = op.getOperand().getType(); |
| 191 | auto [n, legalType] = legalizeVectorType(opType); |
| 192 | if (!legalType) |
| 193 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
| 194 | Location loc = op.getLoc(); |
| 195 | Value vec = op.getOperand(); |
| 196 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| 197 | Value rvl = nullptr; |
| 198 | Value zeroInt = rewriter.create<arith::ConstantOp>( |
| 199 | loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); |
| 200 | if (legalType.isScalable()) |
| 201 | // Use arbitrary runtime vector length when vector type is scalable. |
| 202 | // Proper conversion pass should take it from the IR. |
| 203 | rvl = rewriter.create<arith::ConstantOp>(loc, |
| 204 | rewriter.getI64IntegerAttr(9)); |
| 205 | Value res; |
| 206 | if (n == 1) { |
| 207 | res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
| 208 | zeroInt, rvl); |
| 209 | } else { |
| 210 | const unsigned eltCount = legalType.getShape()[0]; |
| 211 | Type eltTy = legalType.getElementType(); |
| 212 | Value zero = rewriter.create<arith::ConstantOp>( |
| 213 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| 214 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| 215 | for (unsigned i = 0; i < n; ++i) { |
| 216 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| 217 | loc, legalType, vec, i * eltCount); |
| 218 | Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
| 219 | extracted, zeroInt, rvl); |
| 220 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| 221 | i * eltCount); |
| 222 | } |
| 223 | } |
| 224 | rewriter.replaceOp(op, res); |
| 225 | return success(); |
| 226 | } |
| 227 | }; |
| 228 | |
| 229 | struct TestMathToVCIX |
| 230 | : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> { |
| 231 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX) |
| 232 | |
| 233 | StringRef getArgument() const final { return "test-math-to-vcix" ; } |
| 234 | |
| 235 | StringRef getDescription() const final { |
| 236 | return "Test lowering patterns that converts some vector operations to " |
| 237 | "VCIX. Since DLA can implement VCIX instructions in completely " |
| 238 | "different way, conversions of that test pass only lives here." ; |
| 239 | } |
| 240 | |
| 241 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 242 | registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect, |
| 243 | vcix::VCIXDialect, vector::VectorDialect>(); |
| 244 | } |
| 245 | |
| 246 | void runOnOperation() override { |
| 247 | MLIRContext *ctx = &getContext(); |
| 248 | RewritePatternSet patterns(ctx); |
| 249 | patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>( |
| 250 | arg&: ctx); |
| 251 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
| 252 | } |
| 253 | }; |
| 254 | |
| 255 | } // namespace |
| 256 | |
| 257 | namespace test { |
| 258 | void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); } |
| 259 | } // namespace test |
| 260 | } // namespace mlir |
| 261 | |