1//===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/Dialect/LLVMIR/VCIXDialect.h"
12#include "mlir/Dialect/Math/IR/Math.h"
13#include "mlir/Dialect/Vector/IR/VectorOps.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Pass/PassManager.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19namespace mlir {
20namespace {
21
22/// Return number of extracts required to make input VectorType \vt legal and
23/// also return thatlegal vector type.
24/// For fixed vectors nothing special is needed. Scalable vectors are legalizes
25/// according to LLVM's encoding:
26/// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html
27static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
28 VectorType vt = cast<VectorType>(type);
29 // To simplify test pass, avoid multi-dimensional vectors.
30 if (!vt || vt.getRank() != 1)
31 return {0, nullptr};
32
33 if (!vt.isScalable())
34 return {1, vt};
35
36 Type eltTy = vt.getElementType();
37 unsigned sew = 0;
38 if (eltTy.isF32())
39 sew = 32;
40 else if (eltTy.isF64())
41 sew = 64;
42 else if (auto intTy = dyn_cast<IntegerType>(eltTy))
43 sew = intTy.getWidth();
44 else
45 return {0, nullptr};
46
47 unsigned eltCount = vt.getShape()[0];
48 const unsigned lmul = eltCount * sew / 64;
49
50 unsigned n = lmul > 8 ? llvm::Log2_32(Value: lmul) - 2 : 1;
51 return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})};
52}
53
54/// Replace math.cos(v) operation with vcix.v.iv(v).
55struct MathCosToVCIX final : OpRewritePattern<math::CosOp> {
56 using OpRewritePattern::OpRewritePattern;
57
58 LogicalResult matchAndRewrite(math::CosOp op,
59 PatternRewriter &rewriter) const override {
60 const Type opType = op.getOperand().getType();
61 auto [n, legalType] = legalizeVectorType(opType);
62 if (!legalType)
63 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
64 Location loc = op.getLoc();
65 Value vec = op.getOperand();
66 Attribute immAttr = rewriter.getI32IntegerAttr(0);
67 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
68 Value rvl = nullptr;
69 if (legalType.isScalable())
70 // Use arbitrary runtime vector length when vector type is scalable.
71 // Proper conversion pass should take it from the IR.
72 rvl = rewriter.create<arith::ConstantOp>(loc,
73 rewriter.getI64IntegerAttr(9));
74 Value res;
75 if (n == 1) {
76 res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec,
77 immAttr, rvl);
78 } else {
79 const unsigned eltCount = legalType.getShape()[0];
80 Type eltTy = legalType.getElementType();
81 Value zero = rewriter.create<arith::ConstantOp>(
82 loc, eltTy, rewriter.getZeroAttr(eltTy));
83 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
84 for (unsigned i = 0; i < n; ++i) {
85 Value extracted = rewriter.create<vector::ScalableExtractOp>(
86 loc, legalType, vec, i * eltCount);
87 Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr,
88 extracted, immAttr, rvl);
89 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
90 i * eltCount);
91 }
92 }
93 rewriter.replaceOp(op, res);
94 return success();
95 }
96};
97
98// Replace math.sin(v) operation with vcix.v.sv(v, v).
99struct MathSinToVCIX final : OpRewritePattern<math::SinOp> {
100 using OpRewritePattern::OpRewritePattern;
101
102 LogicalResult matchAndRewrite(math::SinOp op,
103 PatternRewriter &rewriter) const override {
104 const Type opType = op.getOperand().getType();
105 auto [n, legalType] = legalizeVectorType(opType);
106 if (!legalType)
107 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
108 Location loc = op.getLoc();
109 Value vec = op.getOperand();
110 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
111 Value rvl = nullptr;
112 if (legalType.isScalable())
113 // Use arbitrary runtime vector length when vector type is scalable.
114 // Proper conversion pass should take it from the IR.
115 rvl = rewriter.create<arith::ConstantOp>(loc,
116 rewriter.getI64IntegerAttr(9));
117 Value res;
118 if (n == 1) {
119 res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
120 vec, rvl);
121 } else {
122 const unsigned eltCount = legalType.getShape()[0];
123 Type eltTy = legalType.getElementType();
124 Value zero = rewriter.create<arith::ConstantOp>(
125 loc, eltTy, rewriter.getZeroAttr(eltTy));
126 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
127 for (unsigned i = 0; i < n; ++i) {
128 Value extracted = rewriter.create<vector::ScalableExtractOp>(
129 loc, legalType, vec, i * eltCount);
130 Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
131 extracted, extracted, rvl);
132 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
133 i * eltCount);
134 }
135 }
136 rewriter.replaceOp(op, res);
137 return success();
138 }
139};
140
141// Replace math.tan(v) operation with vcix.v.sv(v, 0.0f).
142struct MathTanToVCIX final : OpRewritePattern<math::TanOp> {
143 using OpRewritePattern::OpRewritePattern;
144
145 LogicalResult matchAndRewrite(math::TanOp op,
146 PatternRewriter &rewriter) const override {
147 const Type opType = op.getOperand().getType();
148 auto [n, legalType] = legalizeVectorType(opType);
149 Type eltTy = legalType.getElementType();
150 if (!legalType)
151 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
152 Location loc = op.getLoc();
153 Value vec = op.getOperand();
154 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
155 Value zero = rewriter.create<arith::ConstantOp>(
156 loc, eltTy, rewriter.getZeroAttr(eltTy));
157 Value rvl = nullptr;
158 if (legalType.isScalable())
159 // Use arbitrary runtime vector length when vector type is scalable.
160 // Proper conversion pass should take it from the IR.
161 rvl = rewriter.create<arith::ConstantOp>(loc,
162 rewriter.getI64IntegerAttr(9));
163 Value res;
164 if (n == 1) {
165 res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
166 zero, rvl);
167 } else {
168 const unsigned eltCount = legalType.getShape()[0];
169 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
170 for (unsigned i = 0; i < n; ++i) {
171 Value extracted = rewriter.create<vector::ScalableExtractOp>(
172 loc, legalType, vec, i * eltCount);
173 Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
174 extracted, zero, rvl);
175 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
176 i * eltCount);
177 }
178 }
179 rewriter.replaceOp(op, res);
180 return success();
181 }
182};
183
184// Replace math.log(v) operation with vcix.v.sv(v, 0).
185struct MathLogToVCIX final : OpRewritePattern<math::LogOp> {
186 using OpRewritePattern::OpRewritePattern;
187
188 LogicalResult matchAndRewrite(math::LogOp op,
189 PatternRewriter &rewriter) const override {
190 const Type opType = op.getOperand().getType();
191 auto [n, legalType] = legalizeVectorType(opType);
192 if (!legalType)
193 return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
194 Location loc = op.getLoc();
195 Value vec = op.getOperand();
196 Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
197 Value rvl = nullptr;
198 Value zeroInt = rewriter.create<arith::ConstantOp>(
199 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
200 if (legalType.isScalable())
201 // Use arbitrary runtime vector length when vector type is scalable.
202 // Proper conversion pass should take it from the IR.
203 rvl = rewriter.create<arith::ConstantOp>(loc,
204 rewriter.getI64IntegerAttr(9));
205 Value res;
206 if (n == 1) {
207 res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
208 zeroInt, rvl);
209 } else {
210 const unsigned eltCount = legalType.getShape()[0];
211 Type eltTy = legalType.getElementType();
212 Value zero = rewriter.create<arith::ConstantOp>(
213 loc, eltTy, rewriter.getZeroAttr(eltTy));
214 res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
215 for (unsigned i = 0; i < n; ++i) {
216 Value extracted = rewriter.create<vector::ScalableExtractOp>(
217 loc, legalType, vec, i * eltCount);
218 Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
219 extracted, zeroInt, rvl);
220 res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
221 i * eltCount);
222 }
223 }
224 rewriter.replaceOp(op, res);
225 return success();
226 }
227};
228
229struct TestMathToVCIX
230 : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> {
231 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX)
232
233 StringRef getArgument() const final { return "test-math-to-vcix"; }
234
235 StringRef getDescription() const final {
236 return "Test lowering patterns that converts some vector operations to "
237 "VCIX. Since DLA can implement VCIX instructions in completely "
238 "different way, conversions of that test pass only lives here.";
239 }
240
241 void getDependentDialects(DialectRegistry &registry) const override {
242 registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect,
243 vcix::VCIXDialect, vector::VectorDialect>();
244 }
245
246 void runOnOperation() override {
247 MLIRContext *ctx = &getContext();
248 RewritePatternSet patterns(ctx);
249 patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>(
250 arg&: ctx);
251 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
252 }
253};
254
255} // namespace
256
257namespace test {
258void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); }
259} // namespace test
260} // namespace mlir
261

source code of mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp