1 | //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
10 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
11 | #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" |
12 | #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
17 | #include "mlir/IR/BuiltinOps.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::arm_sve; |
22 | |
23 | template <typename OpTy> |
24 | class ForwardOperands : public OpConversionPattern<OpTy> { |
25 | using OpConversionPattern<OpTy>::OpConversionPattern; |
26 | |
27 | LogicalResult |
28 | matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, |
29 | ConversionPatternRewriter &rewriter) const final { |
30 | if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) |
31 | return rewriter.notifyMatchFailure(op, "operand types already match" ); |
32 | |
33 | rewriter.modifyOpInPlace(op, |
34 | [&]() { op->setOperands(adaptor.getOperands()); }); |
35 | return success(); |
36 | } |
37 | }; |
38 | |
39 | using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; |
40 | using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; |
41 | using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; |
42 | using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; |
43 | using ScalableMaskedAddIOpLowering = |
44 | OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, |
45 | ScalableMaskedAddIIntrOp>; |
46 | using ScalableMaskedAddFOpLowering = |
47 | OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, |
48 | ScalableMaskedAddFIntrOp>; |
49 | using ScalableMaskedSubIOpLowering = |
50 | OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, |
51 | ScalableMaskedSubIIntrOp>; |
52 | using ScalableMaskedSubFOpLowering = |
53 | OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, |
54 | ScalableMaskedSubFIntrOp>; |
55 | using ScalableMaskedMulIOpLowering = |
56 | OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, |
57 | ScalableMaskedMulIIntrOp>; |
58 | using ScalableMaskedMulFOpLowering = |
59 | OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, |
60 | ScalableMaskedMulFIntrOp>; |
61 | using ScalableMaskedSDivIOpLowering = |
62 | OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, |
63 | ScalableMaskedSDivIIntrOp>; |
64 | using ScalableMaskedUDivIOpLowering = |
65 | OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, |
66 | ScalableMaskedUDivIIntrOp>; |
67 | using ScalableMaskedDivFOpLowering = |
68 | OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, |
69 | ScalableMaskedDivFIntrOp>; |
70 | |
71 | namespace { |
72 | |
73 | /// Unrolls a conversion to/from equivalent vector types, to allow using a |
74 | /// conversion intrinsic that only supports 1-D vector types. |
75 | /// |
76 | /// Example: |
77 | /// ``` |
78 | /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1> |
79 | /// ``` |
80 | /// is rewritten into: |
81 | /// ``` |
82 | /// %cst = arith.constant dense<false> : vector<2x[16]xi1> |
83 | /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1> |
84 | /// %2 = "arm_sve.intr.convert.to.svbool"(%1) |
85 | /// : (vector<[4]xi1>) -> vector<[16]xi1> |
86 | /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1> |
87 | /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1> |
88 | /// %5 = "arm_sve.intr.convert.to.svbool"(%4) |
89 | /// : (vector<[4]xi1>) -> vector<[16]xi1> |
90 | /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1> |
91 | /// ``` |
92 | template <typename Op, typename IntrOp> |
93 | struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { |
94 | using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; |
95 | |
96 | LogicalResult |
97 | matchAndRewrite(Op convertOp, typename Op::Adaptor, |
98 | ConversionPatternRewriter &rewriter) const override { |
99 | auto loc = convertOp.getLoc(); |
100 | |
101 | auto source = convertOp.getSource(); |
102 | VectorType sourceType = source.getType(); |
103 | VectorType resultType = convertOp.getResult().getType(); |
104 | |
105 | Value result = rewriter.create<arith::ConstantOp>( |
106 | loc, resultType, rewriter.getZeroAttr(resultType)); |
107 | |
108 | // We want to iterate over the input vector in steps of the trailing |
109 | // dimension. So this creates tile shape where all leading dimensions are 1, |
110 | // and the trailing dimension step is the size of the dimension. |
111 | SmallVector<int64_t> tileShape(sourceType.getRank(), 1); |
112 | tileShape.back() = sourceType.getShape().back(); |
113 | |
114 | // Iterate over all scalable mask/predicate slices of the source vector. |
115 | for (SmallVector<int64_t> index : |
116 | StaticTileOffsetRange(sourceType.getShape(), tileShape)) { |
117 | auto extractOrInsertPosition = ArrayRef(index).drop_back(); |
118 | auto sourceVector = rewriter.create<vector::ExtractOp>( |
119 | loc, source, extractOrInsertPosition); |
120 | VectorType convertedType = |
121 | VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) |
122 | .setDim(0, resultType.getShape().back()); |
123 | auto convertedVector = |
124 | rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); |
125 | result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, |
126 | extractOrInsertPosition); |
127 | } |
128 | |
129 | rewriter.replaceOp(convertOp, result); |
130 | return success(); |
131 | } |
132 | }; |
133 | |
134 | using ConvertToSvboolOpLowering = |
135 | SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>; |
136 | |
137 | using ConvertFromSvboolOpLowering = |
138 | SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>; |
139 | |
140 | using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>; |
141 | using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>; |
142 | |
143 | } // namespace |
144 | |
145 | /// Populate the given list with patterns that convert from ArmSVE to LLVM. |
146 | void mlir::populateArmSVELegalizeForLLVMExportPatterns( |
147 | LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
148 | // Populate conversion patterns |
149 | |
150 | // clang-format off |
151 | patterns.add<ForwardOperands<func::CallOp>, |
152 | ForwardOperands<func::CallIndirectOp>, |
153 | ForwardOperands<func::ReturnOp>>(converter, |
154 | &converter.getContext()); |
155 | patterns.add<SdotOpLowering, |
156 | SmmlaOpLowering, |
157 | UdotOpLowering, |
158 | UmmlaOpLowering, |
159 | ScalableMaskedAddIOpLowering, |
160 | ScalableMaskedAddFOpLowering, |
161 | ScalableMaskedSubIOpLowering, |
162 | ScalableMaskedSubFOpLowering, |
163 | ScalableMaskedMulIOpLowering, |
164 | ScalableMaskedMulFOpLowering, |
165 | ScalableMaskedSDivIOpLowering, |
166 | ScalableMaskedUDivIOpLowering, |
167 | ScalableMaskedDivFOpLowering, |
168 | ConvertToSvboolOpLowering, |
169 | ConvertFromSvboolOpLowering, |
170 | ZipX2OpLowering, |
171 | ZipX4OpLowering>(converter); |
172 | // clang-format on |
173 | } |
174 | |
175 | void mlir::configureArmSVELegalizeForExportTarget( |
176 | LLVMConversionTarget &target) { |
177 | // clang-format off |
178 | target.addLegalOp<SdotIntrOp, |
179 | SmmlaIntrOp, |
180 | UdotIntrOp, |
181 | UmmlaIntrOp, |
182 | ScalableMaskedAddIIntrOp, |
183 | ScalableMaskedAddFIntrOp, |
184 | ScalableMaskedSubIIntrOp, |
185 | ScalableMaskedSubFIntrOp, |
186 | ScalableMaskedMulIIntrOp, |
187 | ScalableMaskedMulFIntrOp, |
188 | ScalableMaskedSDivIIntrOp, |
189 | ScalableMaskedUDivIIntrOp, |
190 | ScalableMaskedDivFIntrOp, |
191 | ConvertToSvboolIntrOp, |
192 | ConvertFromSvboolIntrOp, |
193 | ZipX2IntrOp, |
194 | ZipX4IntrOp>(); |
195 | target.addIllegalOp<SdotOp, |
196 | SmmlaOp, |
197 | UdotOp, |
198 | UmmlaOp, |
199 | ScalableMaskedAddIOp, |
200 | ScalableMaskedAddFOp, |
201 | ScalableMaskedSubIOp, |
202 | ScalableMaskedSubFOp, |
203 | ScalableMaskedMulIOp, |
204 | ScalableMaskedMulFOp, |
205 | ScalableMaskedSDivIOp, |
206 | ScalableMaskedUDivIOp, |
207 | ScalableMaskedDivFOp, |
208 | ConvertToSvboolOp, |
209 | ConvertFromSvboolOp, |
210 | ZipX2Op, |
211 | ZipX4Op>(); |
212 | // clang-format on |
213 | } |
214 | |