1//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"
17
18using namespace mlir;
19using namespace index;
20
21namespace {
22
23//===----------------------------------------------------------------------===//
24// Trivial Conversions
25//===----------------------------------------------------------------------===//
26
27using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
28using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
29using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
30using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
31using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
32using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
33using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
34using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
35using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
36using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
37using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
38
39using ConvertIndexShl =
40 spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
41using ConvertIndexShrS =
42 spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
43using ConvertIndexShrU =
44 spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
45
46/// It is the case that when we convert bitwise operations to SPIR-V operations
47/// we must take into account the special pattern in SPIR-V that if the
48/// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
49/// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
50/// index.add is never a boolean operation so we can directly convert it to the
51/// Bitwise[And|Or]Op.
52using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
53using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
54using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
55
56//===----------------------------------------------------------------------===//
57// ConvertConstantBool
58//===----------------------------------------------------------------------===//
59
60// Converts index.bool.constant operation to spirv.Constant.
61struct ConvertIndexConstantBoolOpPattern final
62 : OpConversionPattern<BoolConstantOp> {
63 using OpConversionPattern::OpConversionPattern;
64
65 LogicalResult
66 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
67 ConversionPatternRewriter &rewriter) const override {
68 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
69 op.getValueAttr());
70 return success();
71 }
72};
73
74//===----------------------------------------------------------------------===//
75// ConvertConstant
76//===----------------------------------------------------------------------===//
77
78// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
79// when required.
80struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
81 using OpConversionPattern::OpConversionPattern;
82
83 LogicalResult
84 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
87 Type indexType = typeConverter->getIndexType();
88
89 APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
90 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
91 op, indexType, IntegerAttr::get(indexType, value));
92 return success();
93 }
94};
95
96//===----------------------------------------------------------------------===//
97// ConvertIndexCeilDivS
98//===----------------------------------------------------------------------===//
99
100/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
101/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
102/// conversion in IndexToLLVM.
103struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
104 using OpConversionPattern::OpConversionPattern;
105
106 LogicalResult
107 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter) const override {
109 Location loc = op.getLoc();
110 Value n = adaptor.getLhs();
111 Type n_type = n.getType();
112 Value m = adaptor.getRhs();
113
114 // Define the constants
115 Value zero = rewriter.create<spirv::ConstantOp>(
116 loc, n_type, IntegerAttr::get(n_type, 0));
117 Value posOne = rewriter.create<spirv::ConstantOp>(
118 loc, n_type, IntegerAttr::get(n_type, 1));
119 Value negOne = rewriter.create<spirv::ConstantOp>(
120 loc, n_type, IntegerAttr::get(n_type, -1));
121
122 // Compute `x`.
123 Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
124 Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
125
126 // Compute the positive result.
127 Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
128 Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
129 Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
130
131 // Compute the negative result.
132 Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
133 Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
134 Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
135
136 // Pick the positive result if `n` and `m` have the same sign and `n` is
137 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
138 Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
139 Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
140 Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
141 Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
142 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
143 return success();
144 }
145};
146
147//===----------------------------------------------------------------------===//
148// ConvertIndexCeilDivU
149//===----------------------------------------------------------------------===//
150
151/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
152/// from the equivalent conversion in IndexToLLVM.
153struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
154 using OpConversionPattern::OpConversionPattern;
155
156 LogicalResult
157 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter) const override {
159 Location loc = op.getLoc();
160 Value n = adaptor.getLhs();
161 Type n_type = n.getType();
162 Value m = adaptor.getRhs();
163
164 // Define the constants
165 Value zero = rewriter.create<spirv::ConstantOp>(
166 loc, n_type, IntegerAttr::get(n_type, 0));
167 Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
168 IntegerAttr::get(n_type, 1));
169
170 // Compute the non-zero result.
171 Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
172 Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
173 Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
174
175 // Pick the result
176 Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
177 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
178 return success();
179 }
180};
181
182//===----------------------------------------------------------------------===//
183// ConvertIndexFloorDivS
184//===----------------------------------------------------------------------===//
185
186/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
187/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
188/// in IndexToLLVM.
189struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
190 using OpConversionPattern::OpConversionPattern;
191
192 LogicalResult
193 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
194 ConversionPatternRewriter &rewriter) const override {
195 Location loc = op.getLoc();
196 Value n = adaptor.getLhs();
197 Type n_type = n.getType();
198 Value m = adaptor.getRhs();
199
200 // Define the constants
201 Value zero = rewriter.create<spirv::ConstantOp>(
202 loc, n_type, IntegerAttr::get(n_type, 0));
203 Value posOne = rewriter.create<spirv::ConstantOp>(
204 loc, n_type, IntegerAttr::get(n_type, 1));
205 Value negOne = rewriter.create<spirv::ConstantOp>(
206 loc, n_type, IntegerAttr::get(n_type, -1));
207
208 // Compute `x`.
209 Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
210 Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
211
212 // Compute the negative result
213 Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
214 Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
215 Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
216
217 // Compute the positive result.
218 Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
219
220 // Pick the negative result if `n` and `m` have different signs and `n` is
221 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
222 Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
223 Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
224 Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
225
226 Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
227 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
228 return success();
229 }
230};
231
232//===----------------------------------------------------------------------===//
233// ConvertIndexCast
234//===----------------------------------------------------------------------===//
235
236/// Convert a cast op. If the materialized index type is the same as the other
237/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
238/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
239/// zero extend when the result bitwidth is larger.
240template <typename CastOp, typename ConvertOp>
241struct ConvertIndexCast final : OpConversionPattern<CastOp> {
242 using OpConversionPattern<CastOp>::OpConversionPattern;
243
244 LogicalResult
245 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
246 ConversionPatternRewriter &rewriter) const override {
247 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
248 Type indexType = typeConverter->getIndexType();
249
250 Type srcType = adaptor.getInput().getType();
251 Type dstType = op.getType();
252 if (isa<IndexType>(Val: srcType)) {
253 srcType = indexType;
254 }
255 if (isa<IndexType>(Val: dstType)) {
256 dstType = indexType;
257 }
258
259 if (srcType == dstType) {
260 rewriter.replaceOp(op, adaptor.getInput());
261 } else {
262 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
263 adaptor.getOperands());
264 }
265 return success();
266 }
267};
268
269using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
270using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
271
272//===----------------------------------------------------------------------===//
273// ConvertIndexCmp
274//===----------------------------------------------------------------------===//
275
276// Helper template to replace the operation
277template <typename ICmpOp>
278static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter) {
280 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
281 return success();
282}
283
284struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
285 using OpConversionPattern::OpConversionPattern;
286
287 LogicalResult
288 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
289 ConversionPatternRewriter &rewriter) const override {
290 // We must convert the predicates to the corresponding int comparions.
291 switch (op.getPred()) {
292 case IndexCmpPredicate::EQ:
293 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
294 case IndexCmpPredicate::NE:
295 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
296 case IndexCmpPredicate::SGE:
297 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
298 case IndexCmpPredicate::SGT:
299 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
300 case IndexCmpPredicate::SLE:
301 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
302 case IndexCmpPredicate::SLT:
303 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
304 case IndexCmpPredicate::UGE:
305 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
306 case IndexCmpPredicate::UGT:
307 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
308 case IndexCmpPredicate::ULE:
309 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
310 case IndexCmpPredicate::ULT:
311 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
312 }
313 }
314};
315
316//===----------------------------------------------------------------------===//
317// ConvertIndexSizeOf
318//===----------------------------------------------------------------------===//
319
320/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
321struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
322 using OpConversionPattern::OpConversionPattern;
323
324 LogicalResult
325 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter) const override {
327 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
328 Type indexType = typeConverter->getIndexType();
329 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
330 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
331 op, indexType, IntegerAttr::get(indexType, bitwidth));
332 return success();
333 }
334};
335} // namespace
336
337//===----------------------------------------------------------------------===//
338// Pattern Population
339//===----------------------------------------------------------------------===//
340
341void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
342 RewritePatternSet &patterns) {
343 patterns.add<
344 // clang-format off
345 ConvertIndexAdd,
346 ConvertIndexSub,
347 ConvertIndexMul,
348 ConvertIndexDivS,
349 ConvertIndexDivU,
350 ConvertIndexRemS,
351 ConvertIndexRemU,
352 ConvertIndexMaxS,
353 ConvertIndexMaxU,
354 ConvertIndexMinS,
355 ConvertIndexMinU,
356 ConvertIndexShl,
357 ConvertIndexShrS,
358 ConvertIndexShrU,
359 ConvertIndexAnd,
360 ConvertIndexOr,
361 ConvertIndexXor,
362 ConvertIndexConstantBoolOpPattern,
363 ConvertIndexConstantOpPattern,
364 ConvertIndexCeilDivSPattern,
365 ConvertIndexCeilDivUPattern,
366 ConvertIndexFloorDivSPattern,
367 ConvertIndexCastS,
368 ConvertIndexCastU,
369 ConvertIndexCmpPattern,
370 ConvertIndexSizeOf
371 >(typeConverter, patterns.getContext());
372}
373
374//===----------------------------------------------------------------------===//
375// ODS-Generated Definitions
376//===----------------------------------------------------------------------===//
377
378namespace mlir {
379#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
380#include "mlir/Conversion/Passes.h.inc"
381} // namespace mlir
382
383//===----------------------------------------------------------------------===//
384// Pass Definition
385//===----------------------------------------------------------------------===//
386
387namespace {
388struct ConvertIndexToSPIRVPass
389 : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
390 using Base::Base;
391
392 void runOnOperation() override {
393 Operation *op = getOperation();
394 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
395 std::unique_ptr<SPIRVConversionTarget> target =
396 SPIRVConversionTarget::get(targetAttr);
397
398 SPIRVConversionOptions options;
399 options.use64bitIndex = this->use64bitIndex;
400 SPIRVTypeConverter typeConverter(targetAttr, options);
401
402 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
403 // in patterns for other dialects.
404 target->addLegalOp<UnrealizedConversionCastOp>();
405
406 // Allow the spirv operations we are converting to
407 target->addLegalDialect<spirv::SPIRVDialect>();
408 // Fail hard when there are any remaining 'index' ops.
409 target->addIllegalDialect<index::IndexDialect>();
410
411 RewritePatternSet patterns(&getContext());
412 index::populateIndexToSPIRVPatterns(typeConverter, patterns);
413
414 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
415 signalPassFailure();
416 }
417};
418} // namespace
419

source code of mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp