1 | //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
10 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
11 | #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" |
12 | #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
17 | #include "mlir/IR/BuiltinOps.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | |
using namespace mlir;
using namespace mlir::arm_sve;

// One-to-one lowerings: each of these ArmSVE dialect ops maps directly onto
// its "arm_sve.intr.*" LLVM intrinsic counterpart, with operands passed
// through unchanged.
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
// One-to-one lowerings for the masked (predicated) arithmetic ops.
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;
57 | |
58 | namespace { |
59 | |
60 | /// Unrolls a conversion to/from equivalent vector types, to allow using a |
61 | /// conversion intrinsic that only supports 1-D vector types. |
62 | /// |
63 | /// Example: |
64 | /// ``` |
65 | /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> |
66 | /// ``` |
67 | /// is rewritten into: |
68 | /// ``` |
69 | /// %cst = arith.constant dense<false> : vector<2x[16]xi1> |
70 | /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> |
71 | /// %2 = "arm_sve.intr.convert.to.svbool"(%1) |
72 | /// : (vector<[4]xi1>) -> vector<[16]xi1> |
73 | /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> |
74 | /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> |
75 | /// %5 = "arm_sve.intr.convert.to.svbool"(%4) |
76 | /// : (vector<[4]xi1>) -> vector<[16]xi1> |
77 | /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> |
78 | /// ``` |
79 | template <typename Op, typename IntrOp> |
80 | struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { |
81 | using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; |
82 | |
83 | LogicalResult |
84 | matchAndRewrite(Op convertOp, typename Op::Adaptor, |
85 | ConversionPatternRewriter &rewriter) const override { |
86 | auto loc = convertOp.getLoc(); |
87 | |
88 | auto source = convertOp.getSource(); |
89 | VectorType sourceType = source.getType(); |
90 | VectorType resultType = convertOp.getResult().getType(); |
91 | |
92 | Value result = rewriter.create<arith::ConstantOp>( |
93 | loc, resultType, rewriter.getZeroAttr(resultType)); |
94 | |
95 | // We want to iterate over the input vector in steps of the trailing |
96 | // dimension. So this creates tile shape where all leading dimensions are 1, |
97 | // and the trailing dimension step is the size of the dimension. |
98 | SmallVector<int64_t> tileShape(sourceType.getRank(), 1); |
99 | tileShape.back() = sourceType.getShape().back(); |
100 | |
101 | // Iterate over all scalable mask/predicate slices of the source vector. |
102 | for (SmallVector<int64_t> index : |
103 | StaticTileOffsetRange(sourceType.getShape(), tileShape)) { |
104 | auto extractOrInsertPosition = ArrayRef(index).drop_back(); |
105 | auto sourceVector = rewriter.create<vector::ExtractOp>( |
106 | loc, source, extractOrInsertPosition); |
107 | VectorType convertedType = |
108 | VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) |
109 | .setDim(0, resultType.getShape().back()); |
110 | auto convertedVector = |
111 | rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); |
112 | result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, |
113 | extractOrInsertPosition); |
114 | } |
115 | |
116 | rewriter.replaceOp(convertOp, result); |
117 | return success(); |
118 | } |
119 | }; |
120 | |
// Instantiations of the unrolled svbool conversion pattern, one for each
// direction of the cast.
using ConvertToSvboolOpLowering =
    SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;

using ConvertFromSvboolOpLowering =
    SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;

// Multi-vector zips lower one-to-one to their intrinsics.
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
129 | |
130 | /// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion |
131 | /// but first input (P1) and result predicates need conversion to/from svbool. |
132 | struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> { |
133 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
134 | |
135 | LogicalResult |
136 | matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor, |
137 | ConversionPatternRewriter &rewriter) const override { |
138 | auto svboolType = VectorType::get(16, rewriter.getI1Type(), true); |
139 | auto loc = pselOp.getLoc(); |
140 | auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType, |
141 | adaptor.getP1()); |
142 | auto indexI32 = rewriter.create<arith::IndexCastOp>( |
143 | loc, rewriter.getI32Type(), pselOp.getIndex()); |
144 | auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1, |
145 | pselOp.getP2(), indexI32); |
146 | rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>( |
147 | pselOp, adaptor.getP1().getType(), pselIntr); |
148 | return success(); |
149 | } |
150 | }; |
151 | |
152 | /// Converts `vector.create_mask` ops that match the size of an SVE predicate |
153 | /// to the `whilelt` intrinsic. This produces more canonical codegen than the |
154 | /// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840 |
155 | /// for more details. Note that we can't use (the more general) active.lane.mask |
156 | /// as its semantics don't neatly map on to `vector.create_mask`, as it does an |
157 | /// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if |
158 | /// `n` is zero (whereas `create_mask` just returns an all-false mask). |
159 | struct CreateMaskOpLowering |
160 | : public ConvertOpToLLVMPattern<vector::CreateMaskOp> { |
161 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
162 | |
163 | LogicalResult |
164 | matchAndRewrite(vector::CreateMaskOp createMaskOp, |
165 | vector::CreateMaskOp::Adaptor adaptor, |
166 | ConversionPatternRewriter &rewriter) const override { |
167 | auto maskType = createMaskOp.getVectorType(); |
168 | if (maskType.getRank() != 1 || !maskType.isScalable()) |
169 | return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable" ); |
170 | |
171 | // TODO: Support masks which are multiples of SVE predicates. |
172 | auto maskBaseSize = maskType.getDimSize(0); |
173 | if (maskBaseSize < 2 || maskBaseSize > 16 || |
174 | !llvm::isPowerOf2_32(Value: uint32_t(maskBaseSize))) |
175 | return rewriter.notifyMatchFailure(createMaskOp, |
176 | "not SVE predicate-sized" ); |
177 | |
178 | auto loc = createMaskOp.getLoc(); |
179 | auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type()); |
180 | rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero, |
181 | adaptor.getOperands()[0]); |
182 | return success(); |
183 | } |
184 | }; |
185 | |
186 | } // namespace |
187 | |
188 | /// Populate the given list with patterns that convert from ArmSVE to LLVM. |
189 | void mlir::populateArmSVELegalizeForLLVMExportPatterns( |
190 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
191 | // Populate conversion patterns |
192 | |
193 | // clang-format off |
194 | patterns.add<ConvertFromSvboolOpLowering, |
195 | ConvertToSvboolOpLowering, |
196 | DupQLaneLowering, |
197 | PselOpLowering, |
198 | ScalableMaskedAddFOpLowering, |
199 | ScalableMaskedAddIOpLowering, |
200 | ScalableMaskedDivFOpLowering, |
201 | ScalableMaskedMulFOpLowering, |
202 | ScalableMaskedMulIOpLowering, |
203 | ScalableMaskedSDivIOpLowering, |
204 | ScalableMaskedSubFOpLowering, |
205 | ScalableMaskedSubIOpLowering, |
206 | ScalableMaskedUDivIOpLowering, |
207 | SmmlaOpLowering, |
208 | UdotOpLowering, |
209 | UmmlaOpLowering, |
210 | UsmmlaOpLowering, |
211 | ZipX2OpLowering, |
212 | ZipX4OpLowering, |
213 | SdotOpLowering>(converter); |
214 | // Add vector.create_mask conversion with a high benefit as it produces much |
215 | // nicer code than the generic lowering. |
216 | patterns.add<CreateMaskOpLowering>(arg: converter, /*benefit=*/args: 4096); |
217 | // clang-format on |
218 | } |
219 | |
220 | void mlir::configureArmSVELegalizeForExportTarget( |
221 | LLVMConversionTarget &target) { |
222 | // clang-format off |
223 | target.addLegalOp<ConvertFromSvboolIntrOp, |
224 | ConvertToSvboolIntrOp, |
225 | DupQLaneIntrOp, |
226 | PselIntrOp, |
227 | ScalableMaskedAddFIntrOp, |
228 | ScalableMaskedAddIIntrOp, |
229 | ScalableMaskedDivFIntrOp, |
230 | ScalableMaskedMulFIntrOp, |
231 | ScalableMaskedMulIIntrOp, |
232 | ScalableMaskedSDivIIntrOp, |
233 | ScalableMaskedSubFIntrOp, |
234 | ScalableMaskedSubIIntrOp, |
235 | ScalableMaskedUDivIIntrOp, |
236 | SmmlaIntrOp, |
237 | UdotIntrOp, |
238 | UmmlaIntrOp, |
239 | UsmmlaIntrOp, |
240 | WhileLTIntrOp, |
241 | ZipX2IntrOp, |
242 | ZipX4IntrOp, |
243 | SdotIntrOp>(); |
244 | target.addIllegalOp<ConvertFromSvboolOp, |
245 | ConvertToSvboolOp, |
246 | DupQLaneOp, |
247 | PselOp, |
248 | ScalableMaskedAddFOp, |
249 | ScalableMaskedAddIOp, |
250 | ScalableMaskedDivFOp, |
251 | ScalableMaskedMulFOp, |
252 | ScalableMaskedMulIOp, |
253 | ScalableMaskedSDivIOp, |
254 | ScalableMaskedSubFOp, |
255 | ScalableMaskedSubIOp, |
256 | ScalableMaskedUDivIOp, |
257 | SmmlaOp, |
258 | UdotOp, |
259 | UmmlaOp, |
260 | UsmmlaOp, |
261 | ZipX2Op, |
262 | ZipX4Op, |
263 | SdotOp>(); |
264 | // clang-format on |
265 | } |
266 | |