1 | //===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/Dialect/LLVMIR/VCIXDialect.h" |
12 | #include "mlir/Dialect/Math/IR/Math.h" |
13 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
14 | #include "mlir/IR/PatternMatch.h" |
15 | #include "mlir/Pass/Pass.h" |
16 | #include "mlir/Pass/PassManager.h" |
17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
18 | |
19 | namespace mlir { |
20 | namespace { |
21 | |
22 | /// Return number of extracts required to make input VectorType \vt legal and |
23 | /// also return thatlegal vector type. |
24 | /// For fixed vectors nothing special is needed. Scalable vectors are legalizes |
25 | /// according to LLVM's encoding: |
26 | /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html |
27 | static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) { |
28 | VectorType vt = cast<VectorType>(type); |
29 | // To simplify test pass, avoid multi-dimensional vectors. |
30 | if (!vt || vt.getRank() != 1) |
31 | return {0, nullptr}; |
32 | |
33 | if (!vt.isScalable()) |
34 | return {1, vt}; |
35 | |
36 | Type eltTy = vt.getElementType(); |
37 | unsigned sew = 0; |
38 | if (eltTy.isF32()) |
39 | sew = 32; |
40 | else if (eltTy.isF64()) |
41 | sew = 64; |
42 | else if (auto intTy = dyn_cast<IntegerType>(eltTy)) |
43 | sew = intTy.getWidth(); |
44 | else |
45 | return {0, nullptr}; |
46 | |
47 | unsigned eltCount = vt.getShape()[0]; |
48 | const unsigned lmul = eltCount * sew / 64; |
49 | |
50 | unsigned n = lmul > 8 ? llvm::Log2_32(Value: lmul) - 2 : 1; |
51 | return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})}; |
52 | } |
53 | |
54 | /// Replace math.cos(v) operation with vcix.v.iv(v). |
55 | struct MathCosToVCIX final : OpRewritePattern<math::CosOp> { |
56 | using OpRewritePattern::OpRewritePattern; |
57 | |
58 | LogicalResult matchAndRewrite(math::CosOp op, |
59 | PatternRewriter &rewriter) const override { |
60 | const Type opType = op.getOperand().getType(); |
61 | auto [n, legalType] = legalizeVectorType(opType); |
62 | if (!legalType) |
63 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
64 | Location loc = op.getLoc(); |
65 | Value vec = op.getOperand(); |
66 | Attribute immAttr = rewriter.getI32IntegerAttr(0); |
67 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
68 | Value rvl = nullptr; |
69 | if (legalType.isScalable()) |
70 | // Use arbitrary runtime vector length when vector type is scalable. |
71 | // Proper conversion pass should take it from the IR. |
72 | rvl = rewriter.create<arith::ConstantOp>(loc, |
73 | rewriter.getI64IntegerAttr(9)); |
74 | Value res; |
75 | if (n == 1) { |
76 | res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec, |
77 | immAttr, rvl); |
78 | } else { |
79 | const unsigned eltCount = legalType.getShape()[0]; |
80 | Type eltTy = legalType.getElementType(); |
81 | Value zero = rewriter.create<arith::ConstantOp>( |
82 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
83 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
84 | for (unsigned i = 0; i < n; ++i) { |
85 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
86 | loc, legalType, vec, i * eltCount); |
87 | Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, |
88 | extracted, immAttr, rvl); |
89 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
90 | i * eltCount); |
91 | } |
92 | } |
93 | rewriter.replaceOp(op, res); |
94 | return success(); |
95 | } |
96 | }; |
97 | |
98 | // Replace math.sin(v) operation with vcix.v.sv(v, v). |
99 | struct MathSinToVCIX final : OpRewritePattern<math::SinOp> { |
100 | using OpRewritePattern::OpRewritePattern; |
101 | |
102 | LogicalResult matchAndRewrite(math::SinOp op, |
103 | PatternRewriter &rewriter) const override { |
104 | const Type opType = op.getOperand().getType(); |
105 | auto [n, legalType] = legalizeVectorType(opType); |
106 | if (!legalType) |
107 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
108 | Location loc = op.getLoc(); |
109 | Value vec = op.getOperand(); |
110 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
111 | Value rvl = nullptr; |
112 | if (legalType.isScalable()) |
113 | // Use arbitrary runtime vector length when vector type is scalable. |
114 | // Proper conversion pass should take it from the IR. |
115 | rvl = rewriter.create<arith::ConstantOp>(loc, |
116 | rewriter.getI64IntegerAttr(9)); |
117 | Value res; |
118 | if (n == 1) { |
119 | res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
120 | vec, rvl); |
121 | } else { |
122 | const unsigned eltCount = legalType.getShape()[0]; |
123 | Type eltTy = legalType.getElementType(); |
124 | Value zero = rewriter.create<arith::ConstantOp>( |
125 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
126 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
127 | for (unsigned i = 0; i < n; ++i) { |
128 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
129 | loc, legalType, vec, i * eltCount); |
130 | Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
131 | extracted, extracted, rvl); |
132 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
133 | i * eltCount); |
134 | } |
135 | } |
136 | rewriter.replaceOp(op, res); |
137 | return success(); |
138 | } |
139 | }; |
140 | |
141 | // Replace math.tan(v) operation with vcix.v.sv(v, 0.0f). |
142 | struct MathTanToVCIX final : OpRewritePattern<math::TanOp> { |
143 | using OpRewritePattern::OpRewritePattern; |
144 | |
145 | LogicalResult matchAndRewrite(math::TanOp op, |
146 | PatternRewriter &rewriter) const override { |
147 | const Type opType = op.getOperand().getType(); |
148 | auto [n, legalType] = legalizeVectorType(opType); |
149 | Type eltTy = legalType.getElementType(); |
150 | if (!legalType) |
151 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
152 | Location loc = op.getLoc(); |
153 | Value vec = op.getOperand(); |
154 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
155 | Value zero = rewriter.create<arith::ConstantOp>( |
156 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
157 | Value rvl = nullptr; |
158 | if (legalType.isScalable()) |
159 | // Use arbitrary runtime vector length when vector type is scalable. |
160 | // Proper conversion pass should take it from the IR. |
161 | rvl = rewriter.create<arith::ConstantOp>(loc, |
162 | rewriter.getI64IntegerAttr(9)); |
163 | Value res; |
164 | if (n == 1) { |
165 | res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
166 | zero, rvl); |
167 | } else { |
168 | const unsigned eltCount = legalType.getShape()[0]; |
169 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
170 | for (unsigned i = 0; i < n; ++i) { |
171 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
172 | loc, legalType, vec, i * eltCount); |
173 | Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
174 | extracted, zero, rvl); |
175 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
176 | i * eltCount); |
177 | } |
178 | } |
179 | rewriter.replaceOp(op, res); |
180 | return success(); |
181 | } |
182 | }; |
183 | |
184 | // Replace math.log(v) operation with vcix.v.sv(v, 0). |
185 | struct MathLogToVCIX final : OpRewritePattern<math::LogOp> { |
186 | using OpRewritePattern::OpRewritePattern; |
187 | |
188 | LogicalResult matchAndRewrite(math::LogOp op, |
189 | PatternRewriter &rewriter) const override { |
190 | const Type opType = op.getOperand().getType(); |
191 | auto [n, legalType] = legalizeVectorType(opType); |
192 | if (!legalType) |
193 | return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV" ); |
194 | Location loc = op.getLoc(); |
195 | Value vec = op.getOperand(); |
196 | Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
197 | Value rvl = nullptr; |
198 | Value zeroInt = rewriter.create<arith::ConstantOp>( |
199 | loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); |
200 | if (legalType.isScalable()) |
201 | // Use arbitrary runtime vector length when vector type is scalable. |
202 | // Proper conversion pass should take it from the IR. |
203 | rvl = rewriter.create<arith::ConstantOp>(loc, |
204 | rewriter.getI64IntegerAttr(9)); |
205 | Value res; |
206 | if (n == 1) { |
207 | res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
208 | zeroInt, rvl); |
209 | } else { |
210 | const unsigned eltCount = legalType.getShape()[0]; |
211 | Type eltTy = legalType.getElementType(); |
212 | Value zero = rewriter.create<arith::ConstantOp>( |
213 | loc, eltTy, rewriter.getZeroAttr(eltTy)); |
214 | res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
215 | for (unsigned i = 0; i < n; ++i) { |
216 | Value extracted = rewriter.create<vector::ScalableExtractOp>( |
217 | loc, legalType, vec, i * eltCount); |
218 | Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
219 | extracted, zeroInt, rvl); |
220 | res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
221 | i * eltCount); |
222 | } |
223 | } |
224 | rewriter.replaceOp(op, res); |
225 | return success(); |
226 | } |
227 | }; |
228 | |
229 | struct TestMathToVCIX |
230 | : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> { |
231 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX) |
232 | |
233 | StringRef getArgument() const final { return "test-math-to-vcix" ; } |
234 | |
235 | StringRef getDescription() const final { |
236 | return "Test lowering patterns that converts some vector operations to " |
237 | "VCIX. Since DLA can implement VCIX instructions in completely " |
238 | "different way, conversions of that test pass only lives here." ; |
239 | } |
240 | |
241 | void getDependentDialects(DialectRegistry ®istry) const override { |
242 | registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect, |
243 | vcix::VCIXDialect, vector::VectorDialect>(); |
244 | } |
245 | |
246 | void runOnOperation() override { |
247 | MLIRContext *ctx = &getContext(); |
248 | RewritePatternSet patterns(ctx); |
249 | patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>( |
250 | arg&: ctx); |
251 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
252 | } |
253 | }; |
254 | |
255 | } // namespace |
256 | |
257 | namespace test { |
258 | void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); } |
259 | } // namespace test |
260 | } // namespace mlir |
261 | |