1 | //===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" |
10 | #include "../SPIRVCommon/Pattern.h" |
11 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
12 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
14 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
15 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
16 | #include "mlir/Pass/Pass.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace index; |
20 | |
21 | namespace { |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // Trivial Conversions |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>; |
28 | using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>; |
29 | using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>; |
30 | using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>; |
31 | using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>; |
32 | using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>; |
33 | using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>; |
34 | using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>; |
35 | using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>; |
36 | using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>; |
37 | using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>; |
38 | |
39 | using ConvertIndexShl = |
40 | spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>; |
41 | using ConvertIndexShrS = |
42 | spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>; |
43 | using ConvertIndexShrU = |
44 | spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>; |
45 | |
46 | /// It is the case that when we convert bitwise operations to SPIR-V operations |
47 | /// we must take into account the special pattern in SPIR-V that if the |
48 | /// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise, |
49 | /// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However, |
50 | /// index.add is never a boolean operation so we can directly convert it to the |
51 | /// Bitwise[And|Or]Op. |
52 | using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>; |
53 | using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>; |
54 | using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>; |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // ConvertConstantBool |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | // Converts index.bool.constant operation to spirv.Constant. |
61 | struct ConvertIndexConstantBoolOpPattern final |
62 | : OpConversionPattern<BoolConstantOp> { |
63 | using OpConversionPattern::OpConversionPattern; |
64 | |
65 | LogicalResult |
66 | matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor, |
67 | ConversionPatternRewriter &rewriter) const override { |
68 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(), |
69 | op.getValueAttr()); |
70 | return success(); |
71 | } |
72 | }; |
73 | |
74 | //===----------------------------------------------------------------------===// |
75 | // ConvertConstant |
76 | //===----------------------------------------------------------------------===// |
77 | |
78 | // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32 |
79 | // when required. |
80 | struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> { |
81 | using OpConversionPattern::OpConversionPattern; |
82 | |
83 | LogicalResult |
84 | matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor, |
85 | ConversionPatternRewriter &rewriter) const override { |
86 | auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); |
87 | Type indexType = typeConverter->getIndexType(); |
88 | |
89 | APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth()); |
90 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>( |
91 | op, indexType, IntegerAttr::get(indexType, value)); |
92 | return success(); |
93 | } |
94 | }; |
95 | |
96 | //===----------------------------------------------------------------------===// |
97 | // ConvertIndexCeilDivS |
98 | //===----------------------------------------------------------------------===// |
99 | |
100 | /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then |
101 | /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent |
102 | /// conversion in IndexToLLVM. |
103 | struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> { |
104 | using OpConversionPattern::OpConversionPattern; |
105 | |
106 | LogicalResult |
107 | matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor, |
108 | ConversionPatternRewriter &rewriter) const override { |
109 | Location loc = op.getLoc(); |
110 | Value n = adaptor.getLhs(); |
111 | Type n_type = n.getType(); |
112 | Value m = adaptor.getRhs(); |
113 | |
114 | // Define the constants |
115 | Value zero = rewriter.create<spirv::ConstantOp>( |
116 | loc, n_type, IntegerAttr::get(n_type, 0)); |
117 | Value posOne = rewriter.create<spirv::ConstantOp>( |
118 | loc, n_type, IntegerAttr::get(n_type, 1)); |
119 | Value negOne = rewriter.create<spirv::ConstantOp>( |
120 | loc, n_type, IntegerAttr::get(n_type, -1)); |
121 | |
122 | // Compute `x`. |
123 | Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero); |
124 | Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne); |
125 | |
126 | // Compute the positive result. |
127 | Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x); |
128 | Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m); |
129 | Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne); |
130 | |
131 | // Compute the negative result. |
132 | Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n); |
133 | Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m); |
134 | Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM); |
135 | |
136 | // Pick the positive result if `n` and `m` have the same sign and `n` is |
137 | // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. |
138 | Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero); |
139 | Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos); |
140 | Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero); |
141 | Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero); |
142 | rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes); |
143 | return success(); |
144 | } |
145 | }; |
146 | |
147 | //===----------------------------------------------------------------------===// |
148 | // ConvertIndexCeilDivU |
149 | //===----------------------------------------------------------------------===// |
150 | |
151 | /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken |
152 | /// from the equivalent conversion in IndexToLLVM. |
153 | struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> { |
154 | using OpConversionPattern::OpConversionPattern; |
155 | |
156 | LogicalResult |
157 | matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor, |
158 | ConversionPatternRewriter &rewriter) const override { |
159 | Location loc = op.getLoc(); |
160 | Value n = adaptor.getLhs(); |
161 | Type n_type = n.getType(); |
162 | Value m = adaptor.getRhs(); |
163 | |
164 | // Define the constants |
165 | Value zero = rewriter.create<spirv::ConstantOp>( |
166 | loc, n_type, IntegerAttr::get(n_type, 0)); |
167 | Value one = rewriter.create<spirv::ConstantOp>(loc, n_type, |
168 | IntegerAttr::get(n_type, 1)); |
169 | |
170 | // Compute the non-zero result. |
171 | Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one); |
172 | Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m); |
173 | Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one); |
174 | |
175 | // Pick the result |
176 | Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero); |
177 | rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne); |
178 | return success(); |
179 | } |
180 | }; |
181 | |
182 | //===----------------------------------------------------------------------===// |
183 | // ConvertIndexFloorDivS |
184 | //===----------------------------------------------------------------------===// |
185 | |
186 | /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then |
187 | /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion |
188 | /// in IndexToLLVM. |
189 | struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> { |
190 | using OpConversionPattern::OpConversionPattern; |
191 | |
192 | LogicalResult |
193 | matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor, |
194 | ConversionPatternRewriter &rewriter) const override { |
195 | Location loc = op.getLoc(); |
196 | Value n = adaptor.getLhs(); |
197 | Type n_type = n.getType(); |
198 | Value m = adaptor.getRhs(); |
199 | |
200 | // Define the constants |
201 | Value zero = rewriter.create<spirv::ConstantOp>( |
202 | loc, n_type, IntegerAttr::get(n_type, 0)); |
203 | Value posOne = rewriter.create<spirv::ConstantOp>( |
204 | loc, n_type, IntegerAttr::get(n_type, 1)); |
205 | Value negOne = rewriter.create<spirv::ConstantOp>( |
206 | loc, n_type, IntegerAttr::get(n_type, -1)); |
207 | |
208 | // Compute `x`. |
209 | Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero); |
210 | Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne); |
211 | |
212 | // Compute the negative result |
213 | Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n); |
214 | Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m); |
215 | Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM); |
216 | |
217 | // Compute the positive result. |
218 | Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m); |
219 | |
220 | // Pick the negative result if `n` and `m` have different signs and `n` is |
221 | // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. |
222 | Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero); |
223 | Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg); |
224 | Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero); |
225 | |
226 | Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero); |
227 | rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes); |
228 | return success(); |
229 | } |
230 | }; |
231 | |
232 | //===----------------------------------------------------------------------===// |
233 | // ConvertIndexCast |
234 | //===----------------------------------------------------------------------===// |
235 | |
236 | /// Convert a cast op. If the materialized index type is the same as the other |
237 | /// type, fold away the op. Otherwise, use the Convert SPIR-V operation. |
238 | /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts |
239 | /// zero extend when the result bitwidth is larger. |
240 | template <typename CastOp, typename ConvertOp> |
241 | struct ConvertIndexCast final : OpConversionPattern<CastOp> { |
242 | using OpConversionPattern<CastOp>::OpConversionPattern; |
243 | |
244 | LogicalResult |
245 | matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, |
246 | ConversionPatternRewriter &rewriter) const override { |
247 | auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); |
248 | Type indexType = typeConverter->getIndexType(); |
249 | |
250 | Type srcType = adaptor.getInput().getType(); |
251 | Type dstType = op.getType(); |
252 | if (isa<IndexType>(Val: srcType)) { |
253 | srcType = indexType; |
254 | } |
255 | if (isa<IndexType>(Val: dstType)) { |
256 | dstType = indexType; |
257 | } |
258 | |
259 | if (srcType == dstType) { |
260 | rewriter.replaceOp(op, adaptor.getInput()); |
261 | } else { |
262 | rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType, |
263 | adaptor.getOperands()); |
264 | } |
265 | return success(); |
266 | } |
267 | }; |
268 | |
269 | using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>; |
270 | using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>; |
271 | |
272 | //===----------------------------------------------------------------------===// |
273 | // ConvertIndexCmp |
274 | //===----------------------------------------------------------------------===// |
275 | |
276 | // Helper template to replace the operation |
277 | template <typename ICmpOp> |
278 | static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor, |
279 | ConversionPatternRewriter &rewriter) { |
280 | rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs()); |
281 | return success(); |
282 | } |
283 | |
284 | struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> { |
285 | using OpConversionPattern::OpConversionPattern; |
286 | |
287 | LogicalResult |
288 | matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor, |
289 | ConversionPatternRewriter &rewriter) const override { |
290 | // We must convert the predicates to the corresponding int comparions. |
291 | switch (op.getPred()) { |
292 | case IndexCmpPredicate::EQ: |
293 | return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter); |
294 | case IndexCmpPredicate::NE: |
295 | return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter); |
296 | case IndexCmpPredicate::SGE: |
297 | return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter); |
298 | case IndexCmpPredicate::SGT: |
299 | return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter); |
300 | case IndexCmpPredicate::SLE: |
301 | return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter); |
302 | case IndexCmpPredicate::SLT: |
303 | return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter); |
304 | case IndexCmpPredicate::UGE: |
305 | return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter); |
306 | case IndexCmpPredicate::UGT: |
307 | return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter); |
308 | case IndexCmpPredicate::ULE: |
309 | return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter); |
310 | case IndexCmpPredicate::ULT: |
311 | return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter); |
312 | } |
313 | } |
314 | }; |
315 | |
316 | //===----------------------------------------------------------------------===// |
317 | // ConvertIndexSizeOf |
318 | //===----------------------------------------------------------------------===// |
319 | |
320 | /// Lower `index.sizeof` to a constant with the value of the index bitwidth. |
321 | struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> { |
322 | using OpConversionPattern::OpConversionPattern; |
323 | |
324 | LogicalResult |
325 | matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor, |
326 | ConversionPatternRewriter &rewriter) const override { |
327 | auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>(); |
328 | Type indexType = typeConverter->getIndexType(); |
329 | unsigned bitwidth = typeConverter->getIndexTypeBitwidth(); |
330 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>( |
331 | op, indexType, IntegerAttr::get(indexType, bitwidth)); |
332 | return success(); |
333 | } |
334 | }; |
335 | } // namespace |
336 | |
337 | //===----------------------------------------------------------------------===// |
338 | // Pattern Population |
339 | //===----------------------------------------------------------------------===// |
340 | |
341 | void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter, |
342 | RewritePatternSet &patterns) { |
343 | patterns.add< |
344 | // clang-format off |
345 | ConvertIndexAdd, |
346 | ConvertIndexSub, |
347 | ConvertIndexMul, |
348 | ConvertIndexDivS, |
349 | ConvertIndexDivU, |
350 | ConvertIndexRemS, |
351 | ConvertIndexRemU, |
352 | ConvertIndexMaxS, |
353 | ConvertIndexMaxU, |
354 | ConvertIndexMinS, |
355 | ConvertIndexMinU, |
356 | ConvertIndexShl, |
357 | ConvertIndexShrS, |
358 | ConvertIndexShrU, |
359 | ConvertIndexAnd, |
360 | ConvertIndexOr, |
361 | ConvertIndexXor, |
362 | ConvertIndexConstantBoolOpPattern, |
363 | ConvertIndexConstantOpPattern, |
364 | ConvertIndexCeilDivSPattern, |
365 | ConvertIndexCeilDivUPattern, |
366 | ConvertIndexFloorDivSPattern, |
367 | ConvertIndexCastS, |
368 | ConvertIndexCastU, |
369 | ConvertIndexCmpPattern, |
370 | ConvertIndexSizeOf |
371 | >(typeConverter, patterns.getContext()); |
372 | } |
373 | |
374 | //===----------------------------------------------------------------------===// |
375 | // ODS-Generated Definitions |
376 | //===----------------------------------------------------------------------===// |
377 | |
378 | namespace mlir { |
379 | #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS |
380 | #include "mlir/Conversion/Passes.h.inc" |
381 | } // namespace mlir |
382 | |
383 | //===----------------------------------------------------------------------===// |
384 | // Pass Definition |
385 | //===----------------------------------------------------------------------===// |
386 | |
387 | namespace { |
388 | struct ConvertIndexToSPIRVPass |
389 | : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> { |
390 | using Base::Base; |
391 | |
392 | void runOnOperation() override { |
393 | Operation *op = getOperation(); |
394 | spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); |
395 | std::unique_ptr<SPIRVConversionTarget> target = |
396 | SPIRVConversionTarget::get(targetAttr); |
397 | |
398 | SPIRVConversionOptions options; |
399 | options.use64bitIndex = this->use64bitIndex; |
400 | SPIRVTypeConverter typeConverter(targetAttr, options); |
401 | |
402 | // Use UnrealizedConversionCast as the bridge so that we don't need to pull |
403 | // in patterns for other dialects. |
404 | target->addLegalOp<UnrealizedConversionCastOp>(); |
405 | |
406 | // Allow the spirv operations we are converting to |
407 | target->addLegalDialect<spirv::SPIRVDialect>(); |
408 | // Fail hard when there are any remaining 'index' ops. |
409 | target->addIllegalDialect<index::IndexDialect>(); |
410 | |
411 | RewritePatternSet patterns(&getContext()); |
412 | index::populateIndexToSPIRVPatterns(typeConverter, patterns); |
413 | |
414 | if (failed(applyPartialConversion(op, *target, std::move(patterns)))) |
415 | signalPassFailure(); |
416 | } |
417 | }; |
418 | } // namespace |
419 | |