1//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10#include "mlir/Conversion/LLVMCommon/Pattern.h"
11#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
12#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Utils/IndexingUtils.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/PatternMatch.h"
19
20using namespace mlir;
21using namespace mlir::arm_sve;
22
// Straightforward 1-to-1 lowerings: each alias instantiates the generic
// OneToOneConvertToLLVMPattern for a (ArmSVE dialect op, LLVM intrinsic op)
// pair, so the dialect op is replaced by its intrinsic counterpart with the
// same operands and results.
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
using DupQLaneLowering =
    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;
57
58namespace {
59
60/// Unrolls a conversion to/from equivalent vector types, to allow using a
61/// conversion intrinsic that only supports 1-D vector types.
62///
63/// Example:
64/// ```
65/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
66/// ```
67/// is rewritten into:
68/// ```
69/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
70/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
71/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
72/// : (vector<[4]xi1>) -> vector<[16]xi1>
73/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
74/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
75/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
76/// : (vector<[4]xi1>) -> vector<[16]xi1>
77/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
78/// ```
79template <typename Op, typename IntrOp>
80struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
81 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
82
83 LogicalResult
84 matchAndRewrite(Op convertOp, typename Op::Adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 auto loc = convertOp.getLoc();
87
88 auto source = convertOp.getSource();
89 VectorType sourceType = source.getType();
90 VectorType resultType = convertOp.getResult().getType();
91
92 Value result = rewriter.create<arith::ConstantOp>(
93 loc, resultType, rewriter.getZeroAttr(resultType));
94
95 // We want to iterate over the input vector in steps of the trailing
96 // dimension. So this creates tile shape where all leading dimensions are 1,
97 // and the trailing dimension step is the size of the dimension.
98 SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
99 tileShape.back() = sourceType.getShape().back();
100
101 // Iterate over all scalable mask/predicate slices of the source vector.
102 for (SmallVector<int64_t> index :
103 StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
104 auto extractOrInsertPosition = ArrayRef(index).drop_back();
105 auto sourceVector = rewriter.create<vector::ExtractOp>(
106 loc, source, extractOrInsertPosition);
107 VectorType convertedType =
108 VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
109 .setDim(0, resultType.getShape().back());
110 auto convertedVector =
111 rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
112 result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
113 extractOrInsertPosition);
114 }
115
116 rewriter.replaceOp(convertOp, result);
117 return success();
118 }
119};
120
121using ConvertToSvboolOpLowering =
122 SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
123
124using ConvertFromSvboolOpLowering =
125 SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
126
127using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
128using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
129
130/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
131/// but first input (P1) and result predicates need conversion to/from svbool.
132struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
133 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
134
135 LogicalResult
136 matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
137 ConversionPatternRewriter &rewriter) const override {
138 auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
139 auto loc = pselOp.getLoc();
140 auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
141 adaptor.getP1());
142 auto indexI32 = rewriter.create<arith::IndexCastOp>(
143 loc, rewriter.getI32Type(), pselOp.getIndex());
144 auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
145 pselOp.getP2(), indexI32);
146 rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
147 pselOp, adaptor.getP1().getType(), pselIntr);
148 return success();
149 }
150};
151
152/// Converts `vector.create_mask` ops that match the size of an SVE predicate
153/// to the `whilelt` intrinsic. This produces more canonical codegen than the
154/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
155/// for more details. Note that we can't use (the more general) active.lane.mask
156/// as its semantics don't neatly map on to `vector.create_mask`, as it does an
157/// unsigned comparison (whereas `create_mask` is signed), and is UB/posion if
158/// `n` is zero (whereas `create_mask` just returns an all-false mask).
159struct CreateMaskOpLowering
160 : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
161 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
162
163 LogicalResult
164 matchAndRewrite(vector::CreateMaskOp createMaskOp,
165 vector::CreateMaskOp::Adaptor adaptor,
166 ConversionPatternRewriter &rewriter) const override {
167 auto maskType = createMaskOp.getVectorType();
168 if (maskType.getRank() != 1 || !maskType.isScalable())
169 return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");
170
171 // TODO: Support masks which are multiples of SVE predicates.
172 auto maskBaseSize = maskType.getDimSize(0);
173 if (maskBaseSize < 2 || maskBaseSize > 16 ||
174 !llvm::isPowerOf2_32(Value: uint32_t(maskBaseSize)))
175 return rewriter.notifyMatchFailure(createMaskOp,
176 "not SVE predicate-sized");
177
178 auto loc = createMaskOp.getLoc();
179 auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
180 rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
181 adaptor.getOperands()[0]);
182 return success();
183 }
184};
185
186} // namespace
187
188/// Populate the given list with patterns that convert from ArmSVE to LLVM.
189void mlir::populateArmSVELegalizeForLLVMExportPatterns(
190 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
191 // Populate conversion patterns
192
193 // clang-format off
194 patterns.add<ConvertFromSvboolOpLowering,
195 ConvertToSvboolOpLowering,
196 DupQLaneLowering,
197 PselOpLowering,
198 ScalableMaskedAddFOpLowering,
199 ScalableMaskedAddIOpLowering,
200 ScalableMaskedDivFOpLowering,
201 ScalableMaskedMulFOpLowering,
202 ScalableMaskedMulIOpLowering,
203 ScalableMaskedSDivIOpLowering,
204 ScalableMaskedSubFOpLowering,
205 ScalableMaskedSubIOpLowering,
206 ScalableMaskedUDivIOpLowering,
207 SmmlaOpLowering,
208 UdotOpLowering,
209 UmmlaOpLowering,
210 UsmmlaOpLowering,
211 ZipX2OpLowering,
212 ZipX4OpLowering,
213 SdotOpLowering>(converter);
214 // Add vector.create_mask conversion with a high benefit as it produces much
215 // nicer code than the generic lowering.
216 patterns.add<CreateMaskOpLowering>(arg: converter, /*benefit=*/args: 4096);
217 // clang-format on
218}
219
220void mlir::configureArmSVELegalizeForExportTarget(
221 LLVMConversionTarget &target) {
222 // clang-format off
223 target.addLegalOp<ConvertFromSvboolIntrOp,
224 ConvertToSvboolIntrOp,
225 DupQLaneIntrOp,
226 PselIntrOp,
227 ScalableMaskedAddFIntrOp,
228 ScalableMaskedAddIIntrOp,
229 ScalableMaskedDivFIntrOp,
230 ScalableMaskedMulFIntrOp,
231 ScalableMaskedMulIIntrOp,
232 ScalableMaskedSDivIIntrOp,
233 ScalableMaskedSubFIntrOp,
234 ScalableMaskedSubIIntrOp,
235 ScalableMaskedUDivIIntrOp,
236 SmmlaIntrOp,
237 UdotIntrOp,
238 UmmlaIntrOp,
239 UsmmlaIntrOp,
240 WhileLTIntrOp,
241 ZipX2IntrOp,
242 ZipX4IntrOp,
243 SdotIntrOp>();
244 target.addIllegalOp<ConvertFromSvboolOp,
245 ConvertToSvboolOp,
246 DupQLaneOp,
247 PselOp,
248 ScalableMaskedAddFOp,
249 ScalableMaskedAddIOp,
250 ScalableMaskedDivFOp,
251 ScalableMaskedMulFOp,
252 ScalableMaskedMulIOp,
253 ScalableMaskedSDivIOp,
254 ScalableMaskedSubFOp,
255 ScalableMaskedSubIOp,
256 ScalableMaskedUDivIOp,
257 SmmlaOp,
258 UdotOp,
259 UmmlaOp,
260 UsmmlaOp,
261 ZipX2Op,
262 ZipX4Op,
263 SdotOp>();
264 // clang-format on
265}
266

// Source: mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp