1//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10#include "mlir/Conversion/LLVMCommon/Pattern.h"
11#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
12#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/Utils/IndexingUtils.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/PatternMatch.h"
19
20using namespace mlir;
21using namespace mlir::arm_sve;
22
23template <typename OpTy>
24class ForwardOperands : public OpConversionPattern<OpTy> {
25 using OpConversionPattern<OpTy>::OpConversionPattern;
26
27 LogicalResult
28 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
29 ConversionPatternRewriter &rewriter) const final {
30 if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
31 return rewriter.notifyMatchFailure(op, "operand types already match");
32
33 rewriter.modifyOpInPlace(op,
34 [&]() { op->setOperands(adaptor.getOperands()); });
35 return success();
36 }
37};
38
39using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
40using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
41using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
42using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
43using ScalableMaskedAddIOpLowering =
44 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
45 ScalableMaskedAddIIntrOp>;
46using ScalableMaskedAddFOpLowering =
47 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
48 ScalableMaskedAddFIntrOp>;
49using ScalableMaskedSubIOpLowering =
50 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
51 ScalableMaskedSubIIntrOp>;
52using ScalableMaskedSubFOpLowering =
53 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
54 ScalableMaskedSubFIntrOp>;
55using ScalableMaskedMulIOpLowering =
56 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
57 ScalableMaskedMulIIntrOp>;
58using ScalableMaskedMulFOpLowering =
59 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
60 ScalableMaskedMulFIntrOp>;
61using ScalableMaskedSDivIOpLowering =
62 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
63 ScalableMaskedSDivIIntrOp>;
64using ScalableMaskedUDivIOpLowering =
65 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
66 ScalableMaskedUDivIIntrOp>;
67using ScalableMaskedDivFOpLowering =
68 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
69 ScalableMaskedDivFIntrOp>;
70
71namespace {
72
73/// Unrolls a conversion to/from equivalent vector types, to allow using a
74/// conversion intrinsic that only supports 1-D vector types.
75///
76/// Example:
77/// ```
78/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
79/// ```
80/// is rewritten into:
81/// ```
82/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
83/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
84/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
85/// : (vector<[4]xi1>) -> vector<[16]xi1>
86/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
87/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
88/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
89/// : (vector<[4]xi1>) -> vector<[16]xi1>
90/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
91/// ```
92template <typename Op, typename IntrOp>
93struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
94 using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
95
96 LogicalResult
97 matchAndRewrite(Op convertOp, typename Op::Adaptor,
98 ConversionPatternRewriter &rewriter) const override {
99 auto loc = convertOp.getLoc();
100
101 auto source = convertOp.getSource();
102 VectorType sourceType = source.getType();
103 VectorType resultType = convertOp.getResult().getType();
104
105 Value result = rewriter.create<arith::ConstantOp>(
106 loc, resultType, rewriter.getZeroAttr(resultType));
107
108 // We want to iterate over the input vector in steps of the trailing
109 // dimension. So this creates tile shape where all leading dimensions are 1,
110 // and the trailing dimension step is the size of the dimension.
111 SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
112 tileShape.back() = sourceType.getShape().back();
113
114 // Iterate over all scalable mask/predicate slices of the source vector.
115 for (SmallVector<int64_t> index :
116 StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
117 auto extractOrInsertPosition = ArrayRef(index).drop_back();
118 auto sourceVector = rewriter.create<vector::ExtractOp>(
119 loc, source, extractOrInsertPosition);
120 VectorType convertedType =
121 VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
122 .setDim(0, resultType.getShape().back());
123 auto convertedVector =
124 rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
125 result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
126 extractOrInsertPosition);
127 }
128
129 rewriter.replaceOp(convertOp, result);
130 return success();
131 }
132};
133
134using ConvertToSvboolOpLowering =
135 SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
136
137using ConvertFromSvboolOpLowering =
138 SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
139
140using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
141using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
142
143} // namespace
144
145/// Populate the given list with patterns that convert from ArmSVE to LLVM.
146void mlir::populateArmSVELegalizeForLLVMExportPatterns(
147 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
148 // Populate conversion patterns
149
150 // clang-format off
151 patterns.add<ForwardOperands<func::CallOp>,
152 ForwardOperands<func::CallIndirectOp>,
153 ForwardOperands<func::ReturnOp>>(converter,
154 &converter.getContext());
155 patterns.add<SdotOpLowering,
156 SmmlaOpLowering,
157 UdotOpLowering,
158 UmmlaOpLowering,
159 ScalableMaskedAddIOpLowering,
160 ScalableMaskedAddFOpLowering,
161 ScalableMaskedSubIOpLowering,
162 ScalableMaskedSubFOpLowering,
163 ScalableMaskedMulIOpLowering,
164 ScalableMaskedMulFOpLowering,
165 ScalableMaskedSDivIOpLowering,
166 ScalableMaskedUDivIOpLowering,
167 ScalableMaskedDivFOpLowering,
168 ConvertToSvboolOpLowering,
169 ConvertFromSvboolOpLowering,
170 ZipX2OpLowering,
171 ZipX4OpLowering>(converter);
172 // clang-format on
173}
174
175void mlir::configureArmSVELegalizeForExportTarget(
176 LLVMConversionTarget &target) {
177 // clang-format off
178 target.addLegalOp<SdotIntrOp,
179 SmmlaIntrOp,
180 UdotIntrOp,
181 UmmlaIntrOp,
182 ScalableMaskedAddIIntrOp,
183 ScalableMaskedAddFIntrOp,
184 ScalableMaskedSubIIntrOp,
185 ScalableMaskedSubFIntrOp,
186 ScalableMaskedMulIIntrOp,
187 ScalableMaskedMulFIntrOp,
188 ScalableMaskedSDivIIntrOp,
189 ScalableMaskedUDivIIntrOp,
190 ScalableMaskedDivFIntrOp,
191 ConvertToSvboolIntrOp,
192 ConvertFromSvboolIntrOp,
193 ZipX2IntrOp,
194 ZipX4IntrOp>();
195 target.addIllegalOp<SdotOp,
196 SmmlaOp,
197 UdotOp,
198 UmmlaOp,
199 ScalableMaskedAddIOp,
200 ScalableMaskedAddFOp,
201 ScalableMaskedSubIOp,
202 ScalableMaskedSubFOp,
203 ScalableMaskedMulIOp,
204 ScalableMaskedMulFOp,
205 ScalableMaskedSDivIOp,
206 ScalableMaskedUDivIOp,
207 ScalableMaskedDivFOp,
208 ConvertToSvboolOp,
209 ConvertFromSvboolOp,
210 ZipX2Op,
211 ZipX4Op>();
212 // clang-format on
213}
214

// Source: mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp