1//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10#include "mlir/Conversion/LLVMCommon/Pattern.h"
11#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
12#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/Dialect/Utils/IndexingUtils.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/IR/PatternMatch.h"
17
18using namespace mlir;
19using namespace mlir::arm_sve;
20
// Direct 1-1 lowerings: each ArmSVE dialect op is rewritten to the matching
// LLVM intrinsic wrapper op, keeping the same operands and result types.
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
// Masked arithmetic ops carry an explicit predicate operand that is forwarded
// unchanged to the corresponding masked intrinsic.
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;
55
56namespace {
57
58/// Unrolls a conversion to/from equivalent vector types, to allow using a
59/// conversion intrinsic that only supports 1-D vector types.
60///
61/// Example:
62/// ```
63/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
64/// ```
65/// is rewritten into:
66/// ```
67/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
68/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
69/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
70/// : (vector<[4]xi1>) -> vector<[16]xi1>
71/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
72/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
73/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
74/// : (vector<[4]xi1>) -> vector<[16]xi1>
75/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
76/// ```
77template <typename Op, typename IntrOp>
78struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
79 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
80
81 LogicalResult
82 matchAndRewrite(Op convertOp, typename Op::Adaptor,
83 ConversionPatternRewriter &rewriter) const override {
84 auto loc = convertOp.getLoc();
85
86 auto source = convertOp.getSource();
87 VectorType sourceType = source.getType();
88 VectorType resultType = convertOp.getResult().getType();
89
90 Value result = rewriter.create<arith::ConstantOp>(
91 loc, resultType, rewriter.getZeroAttr(type: resultType));
92
93 // We want to iterate over the input vector in steps of the trailing
94 // dimension. So this creates tile shape where all leading dimensions are 1,
95 // and the trailing dimension step is the size of the dimension.
96 SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
97 tileShape.back() = sourceType.getShape().back();
98
99 // Iterate over all scalable mask/predicate slices of the source vector.
100 for (SmallVector<int64_t> index :
101 StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
102 auto extractOrInsertPosition = ArrayRef(index).drop_back();
103 auto sourceVector = rewriter.create<vector::ExtractOp>(
104 loc, source, extractOrInsertPosition);
105 VectorType convertedType =
106 VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
107 .setDim(pos: 0, val: resultType.getShape().back());
108 auto convertedVector =
109 rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
110 result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
111 extractOrInsertPosition);
112 }
113
114 rewriter.replaceOp(convertOp, result);
115 return success();
116 }
117};
118
// Unrolled svbool conversions (see SvboolConversionOpLowering above).
using ConvertToSvboolOpLowering =
    SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;

using ConvertFromSvboolOpLowering =
    SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;

// Direct 1-1 lowerings for the multi-vector zip ops.
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
127
128/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
129/// but first input (P1) and result predicates need conversion to/from svbool.
130struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
131 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
132
133 LogicalResult
134 matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 auto svboolType = VectorType::get(shape: 16, elementType: rewriter.getI1Type(), scalableDims: true);
137 auto loc = pselOp.getLoc();
138 auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(location: loc, args&: svboolType,
139 args: adaptor.getP1());
140 auto indexI32 = rewriter.create<arith::IndexCastOp>(
141 location: loc, args: rewriter.getI32Type(), args: pselOp.getIndex());
142 auto pselIntr = rewriter.create<PselIntrOp>(location: loc, args&: svboolType, args&: svboolP1,
143 args: pselOp.getP2(), args&: indexI32);
144 rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
145 op: pselOp, args: adaptor.getP1().getType(), args&: pselIntr);
146 return success();
147 }
148};
149
150/// Converts `vector.create_mask` ops that match the size of an SVE predicate
151/// to the `whilelt` intrinsic. This produces more canonical codegen than the
152/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
153/// for more details. Note that we can't use (the more general) active.lane.mask
154/// as its semantics don't neatly map on to `vector.create_mask`, as it does an
155/// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
156/// `n` is zero (whereas `create_mask` just returns an all-false mask).
157struct CreateMaskOpLowering
158 : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
159 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
160
161 LogicalResult
162 matchAndRewrite(vector::CreateMaskOp createMaskOp,
163 vector::CreateMaskOp::Adaptor adaptor,
164 ConversionPatternRewriter &rewriter) const override {
165 auto maskType = createMaskOp.getVectorType();
166 if (maskType.getRank() != 1 || !maskType.isScalable())
167 return rewriter.notifyMatchFailure(arg&: createMaskOp, msg: "not 1-D and scalable");
168
169 // TODO: Support masks which are multiples of SVE predicates.
170 auto maskBaseSize = maskType.getDimSize(idx: 0);
171 if (maskBaseSize < 2 || maskBaseSize > 16 ||
172 !llvm::isPowerOf2_32(Value: uint32_t(maskBaseSize)))
173 return rewriter.notifyMatchFailure(arg&: createMaskOp,
174 msg: "not SVE predicate-sized");
175
176 auto loc = createMaskOp.getLoc();
177 auto zero = rewriter.create<LLVM::ZeroOp>(location: loc, args: rewriter.getI64Type());
178 rewriter.replaceOpWithNewOp<WhileLTIntrOp>(op: createMaskOp, args&: maskType, args&: zero,
179 args: adaptor.getOperands()[0]);
180 return success();
181 }
182};
183
184} // namespace
185
186/// Populate the given list with patterns that convert from ArmSVE to LLVM.
187void mlir::populateArmSVELegalizeForLLVMExportPatterns(
188 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
189 // Populate conversion patterns
190
191 // clang-format off
192 patterns.add<ConvertFromSvboolOpLowering,
193 ConvertToSvboolOpLowering,
194 DupQLaneLowering,
195 PselOpLowering,
196 ScalableMaskedAddFOpLowering,
197 ScalableMaskedAddIOpLowering,
198 ScalableMaskedDivFOpLowering,
199 ScalableMaskedMulFOpLowering,
200 ScalableMaskedMulIOpLowering,
201 ScalableMaskedSDivIOpLowering,
202 ScalableMaskedSubFOpLowering,
203 ScalableMaskedSubIOpLowering,
204 ScalableMaskedUDivIOpLowering,
205 SmmlaOpLowering,
206 UdotOpLowering,
207 UmmlaOpLowering,
208 UsmmlaOpLowering,
209 ZipX2OpLowering,
210 ZipX4OpLowering,
211 SdotOpLowering>(arg: converter);
212 // Add vector.create_mask conversion with a high benefit as it produces much
213 // nicer code than the generic lowering.
214 patterns.add<CreateMaskOpLowering>(arg: converter, /*benefit=*/args: 4096);
215 // clang-format on
216}
217
218void mlir::configureArmSVELegalizeForExportTarget(
219 LLVMConversionTarget &target) {
220 // clang-format off
221 target.addLegalOp<BfmmlaOp,
222 ConvertFromSvboolIntrOp,
223 ConvertToSvboolIntrOp,
224 DupQLaneIntrOp,
225 PselIntrOp,
226 ScalableMaskedAddFIntrOp,
227 ScalableMaskedAddIIntrOp,
228 ScalableMaskedDivFIntrOp,
229 ScalableMaskedMulFIntrOp,
230 ScalableMaskedMulIIntrOp,
231 ScalableMaskedSDivIIntrOp,
232 ScalableMaskedSubFIntrOp,
233 ScalableMaskedSubIIntrOp,
234 ScalableMaskedUDivIIntrOp,
235 SmmlaIntrOp,
236 UdotIntrOp,
237 UmmlaIntrOp,
238 UsmmlaIntrOp,
239 WhileLTIntrOp,
240 ZipX2IntrOp,
241 ZipX4IntrOp,
242 SdotIntrOp>();
243 target.addIllegalOp<ConvertFromSvboolOp,
244 ConvertToSvboolOp,
245 DupQLaneOp,
246 PselOp,
247 ScalableMaskedAddFOp,
248 ScalableMaskedAddIOp,
249 ScalableMaskedDivFOp,
250 ScalableMaskedMulFOp,
251 ScalableMaskedMulIOp,
252 ScalableMaskedSDivIOp,
253 ScalableMaskedSubFOp,
254 ScalableMaskedSubIOp,
255 ScalableMaskedUDivIOp,
256 SmmlaOp,
257 UdotOp,
258 UmmlaOp,
259 UsmmlaOp,
260 ZipX2Op,
261 ZipX4Op,
262 SdotOp>();
263 // clang-format on
264}
265

// Source: mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp