1//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
10#include "../SPIRVCommon/Pattern.h"
11#include "mlir/Dialect/Index/IR/IndexDialect.h"
12#include "mlir/Dialect/Index/IR/IndexOps.h"
13#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16
17using namespace mlir;
18using namespace index;
19
20namespace {
21
22//===----------------------------------------------------------------------===//
23// Trivial Conversions
24//===----------------------------------------------------------------------===//
25
26using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
27using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
28using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
29using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
30using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
31using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
32using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
33using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
34using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
35using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
36using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
37
38using ConvertIndexShl =
39 spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
40using ConvertIndexShrS =
41 spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
42using ConvertIndexShrU =
43 spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
44
45/// It is the case that when we convert bitwise operations to SPIR-V operations
46/// we must take into account the special pattern in SPIR-V that if the
47/// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
48/// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
49/// index.add is never a boolean operation so we can directly convert it to the
50/// Bitwise[And|Or]Op.
51using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
52using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
53using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
54
55//===----------------------------------------------------------------------===//
56// ConvertConstantBool
57//===----------------------------------------------------------------------===//
58
59// Converts index.bool.constant operation to spirv.Constant.
60struct ConvertIndexConstantBoolOpPattern final
61 : OpConversionPattern<BoolConstantOp> {
62 using OpConversionPattern::OpConversionPattern;
63
64 LogicalResult
65 matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter) const override {
67 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, args: op.getType(),
68 args: op.getValueAttr());
69 return success();
70 }
71};
72
73//===----------------------------------------------------------------------===//
74// ConvertConstant
75//===----------------------------------------------------------------------===//
76
77// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
78// when required.
79struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
80 using OpConversionPattern::OpConversionPattern;
81
82 LogicalResult
83 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter) const override {
85 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
86 Type indexType = typeConverter->getIndexType();
87
88 APInt value = op.getValue().trunc(width: typeConverter->getIndexTypeBitwidth());
89 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
90 op, args&: indexType, args: IntegerAttr::get(type: indexType, value));
91 return success();
92 }
93};
94
95//===----------------------------------------------------------------------===//
96// ConvertIndexCeilDivS
97//===----------------------------------------------------------------------===//
98
99/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
100/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
101/// conversion in IndexToLLVM.
102struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
103 using OpConversionPattern::OpConversionPattern;
104
105 LogicalResult
106 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
107 ConversionPatternRewriter &rewriter) const override {
108 Location loc = op.getLoc();
109 Value n = adaptor.getLhs();
110 Type n_type = n.getType();
111 Value m = adaptor.getRhs();
112
113 // Define the constants
114 Value zero = rewriter.create<spirv::ConstantOp>(
115 location: loc, args&: n_type, args: IntegerAttr::get(type: n_type, value: 0));
116 Value posOne = rewriter.create<spirv::ConstantOp>(
117 location: loc, args&: n_type, args: IntegerAttr::get(type: n_type, value: 1));
118 Value negOne = rewriter.create<spirv::ConstantOp>(
119 location: loc, args&: n_type, args: IntegerAttr::get(type: n_type, value: -1));
120
121 // Compute `x`.
122 Value mPos = rewriter.create<spirv::SGreaterThanOp>(location: loc, args&: m, args&: zero);
123 Value x = rewriter.create<spirv::SelectOp>(location: loc, args&: mPos, args&: negOne, args&: posOne);
124
125 // Compute the positive result.
126 Value nPlusX = rewriter.create<spirv::IAddOp>(location: loc, args&: n, args&: x);
127 Value nPlusXDivM = rewriter.create<spirv::SDivOp>(location: loc, args&: nPlusX, args&: m);
128 Value posRes = rewriter.create<spirv::IAddOp>(location: loc, args&: nPlusXDivM, args&: posOne);
129
130 // Compute the negative result.
131 Value negN = rewriter.create<spirv::ISubOp>(location: loc, args&: zero, args&: n);
132 Value negNDivM = rewriter.create<spirv::SDivOp>(location: loc, args&: negN, args&: m);
133 Value negRes = rewriter.create<spirv::ISubOp>(location: loc, args&: zero, args&: negNDivM);
134
135 // Pick the positive result if `n` and `m` have the same sign and `n` is
136 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
137 Value nPos = rewriter.create<spirv::SGreaterThanOp>(location: loc, args&: n, args&: zero);
138 Value sameSign = rewriter.create<spirv::LogicalEqualOp>(location: loc, args&: nPos, args&: mPos);
139 Value nNonZero = rewriter.create<spirv::INotEqualOp>(location: loc, args&: n, args&: zero);
140 Value cmp = rewriter.create<spirv::LogicalAndOp>(location: loc, args&: sameSign, args&: nNonZero);
141 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, args&: cmp, args&: posRes, args&: negRes);
142 return success();
143 }
144};
145
146//===----------------------------------------------------------------------===//
147// ConvertIndexCeilDivU
148//===----------------------------------------------------------------------===//
149
150/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
151/// from the equivalent conversion in IndexToLLVM.
152struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
153 using OpConversionPattern::OpConversionPattern;
154
155 LogicalResult
156 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 Location loc = op.getLoc();
159 Value n = adaptor.getLhs();
160 Type n_type = n.getType();
161 Value m = adaptor.getRhs();
162
163 // Define the constants
164 Value zero = rewriter.create<spirv::ConstantOp>(
165 location: loc, args&: n_type, args: IntegerAttr::get(type: n_type, value: 0));
166 Value one = rewriter.create<spirv::ConstantOp>(location: loc, args&: n_type,
167 args: IntegerAttr::get(type: n_type, value: 1));
168
169 // Compute the non-zero result.
170 Value minusOne = rewriter.create<spirv::ISubOp>(location: loc, args&: n, args&: one);
171 Value quotient = rewriter.create<spirv::UDivOp>(location: loc, args&: minusOne, args&: m);
172 Value plusOne = rewriter.create<spirv::IAddOp>(location: loc, args&: quotient, args&: one);
173
174 // Pick the result
175 Value cmp = rewriter.create<spirv::IEqualOp>(location: loc, args&: n, args&: zero);
176 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, args&: cmp, args&: zero, args&: plusOne);
177 return success();
178 }
179};
180
181//===----------------------------------------------------------------------===//
182// ConvertIndexFloorDivS
183//===----------------------------------------------------------------------===//
184
185/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
186/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
187/// in IndexToLLVM.
188struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
189 using OpConversionPattern::OpConversionPattern;
190
191 LogicalResult
192 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter) const override {
194 Location loc = op.getLoc();
195 Value n = adaptor.getLhs();
196 Type n_type = n.getType();
197 Value m = adaptor.getRhs();
198
199 // Define the constants
200 Value zero = rewriter.create<spirv::ConstantOp>(
201 location: loc, args&: n_type, args: IntegerAttr::get(type: n_type, value: 0));
202 Value posOne = rewriter.create<spirv::ConstantOp>(
203 location: loc, args&: n_type, args: IntegerAttr::get(type: n_type, value: 1));
204 Value negOne = rewriter.create<spirv::ConstantOp>(
205 location: loc, args&: n_type, args: IntegerAttr::get(type: n_type, value: -1));
206
207 // Compute `x`.
208 Value mNeg = rewriter.create<spirv::SLessThanOp>(location: loc, args&: m, args&: zero);
209 Value x = rewriter.create<spirv::SelectOp>(location: loc, args&: mNeg, args&: posOne, args&: negOne);
210
211 // Compute the negative result
212 Value xMinusN = rewriter.create<spirv::ISubOp>(location: loc, args&: x, args&: n);
213 Value xMinusNDivM = rewriter.create<spirv::SDivOp>(location: loc, args&: xMinusN, args&: m);
214 Value negRes = rewriter.create<spirv::ISubOp>(location: loc, args&: negOne, args&: xMinusNDivM);
215
216 // Compute the positive result.
217 Value posRes = rewriter.create<spirv::SDivOp>(location: loc, args&: n, args&: m);
218
219 // Pick the negative result if `n` and `m` have different signs and `n` is
220 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
221 Value nNeg = rewriter.create<spirv::SLessThanOp>(location: loc, args&: n, args&: zero);
222 Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(location: loc, args&: nNeg, args&: mNeg);
223 Value nNonZero = rewriter.create<spirv::INotEqualOp>(location: loc, args&: n, args&: zero);
224
225 Value cmp = rewriter.create<spirv::LogicalAndOp>(location: loc, args&: diffSign, args&: nNonZero);
226 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, args&: cmp, args&: posRes, args&: negRes);
227 return success();
228 }
229};
230
231//===----------------------------------------------------------------------===//
232// ConvertIndexCast
233//===----------------------------------------------------------------------===//
234
235/// Convert a cast op. If the materialized index type is the same as the other
236/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
237/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
238/// zero extend when the result bitwidth is larger.
239template <typename CastOp, typename ConvertOp>
240struct ConvertIndexCast final : OpConversionPattern<CastOp> {
241 using OpConversionPattern<CastOp>::OpConversionPattern;
242
243 LogicalResult
244 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
247 Type indexType = typeConverter->getIndexType();
248
249 Type srcType = adaptor.getInput().getType();
250 Type dstType = op.getType();
251 if (isa<IndexType>(Val: srcType)) {
252 srcType = indexType;
253 }
254 if (isa<IndexType>(Val: dstType)) {
255 dstType = indexType;
256 }
257
258 if (srcType == dstType) {
259 rewriter.replaceOp(op, adaptor.getInput());
260 } else {
261 rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
262 adaptor.getOperands());
263 }
264 return success();
265 }
266};
267
268using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
269using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
270
271//===----------------------------------------------------------------------===//
272// ConvertIndexCmp
273//===----------------------------------------------------------------------===//
274
275// Helper template to replace the operation
276template <typename ICmpOp>
277static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter) {
279 rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
280 return success();
281}
282
283struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
284 using OpConversionPattern::OpConversionPattern;
285
286 LogicalResult
287 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter) const override {
289 // We must convert the predicates to the corresponding int comparions.
290 switch (op.getPred()) {
291 case IndexCmpPredicate::EQ:
292 return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
293 case IndexCmpPredicate::NE:
294 return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
295 case IndexCmpPredicate::SGE:
296 return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
297 case IndexCmpPredicate::SGT:
298 return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
299 case IndexCmpPredicate::SLE:
300 return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
301 case IndexCmpPredicate::SLT:
302 return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
303 case IndexCmpPredicate::UGE:
304 return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
305 case IndexCmpPredicate::UGT:
306 return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
307 case IndexCmpPredicate::ULE:
308 return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
309 case IndexCmpPredicate::ULT:
310 return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
311 }
312 llvm_unreachable("Unknown predicate in ConvertIndexCmpPattern");
313 }
314};
315
316//===----------------------------------------------------------------------===//
317// ConvertIndexSizeOf
318//===----------------------------------------------------------------------===//
319
320/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
321struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
322 using OpConversionPattern::OpConversionPattern;
323
324 LogicalResult
325 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter) const override {
327 auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
328 Type indexType = typeConverter->getIndexType();
329 unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
330 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
331 op, args&: indexType, args: IntegerAttr::get(type: indexType, value: bitwidth));
332 return success();
333 }
334};
335} // namespace
336
337//===----------------------------------------------------------------------===//
338// Pattern Population
339//===----------------------------------------------------------------------===//
340
341void index::populateIndexToSPIRVPatterns(
342 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
343 patterns.add<
344 // clang-format off
345 ConvertIndexAdd,
346 ConvertIndexSub,
347 ConvertIndexMul,
348 ConvertIndexDivS,
349 ConvertIndexDivU,
350 ConvertIndexRemS,
351 ConvertIndexRemU,
352 ConvertIndexMaxS,
353 ConvertIndexMaxU,
354 ConvertIndexMinS,
355 ConvertIndexMinU,
356 ConvertIndexShl,
357 ConvertIndexShrS,
358 ConvertIndexShrU,
359 ConvertIndexAnd,
360 ConvertIndexOr,
361 ConvertIndexXor,
362 ConvertIndexConstantBoolOpPattern,
363 ConvertIndexConstantOpPattern,
364 ConvertIndexCeilDivSPattern,
365 ConvertIndexCeilDivUPattern,
366 ConvertIndexFloorDivSPattern,
367 ConvertIndexCastS,
368 ConvertIndexCastU,
369 ConvertIndexCmpPattern,
370 ConvertIndexSizeOf
371 >(arg: typeConverter, args: patterns.getContext());
372}
373
374//===----------------------------------------------------------------------===//
375// ODS-Generated Definitions
376//===----------------------------------------------------------------------===//
377
namespace mlir {
// Expand the TableGen-generated pass declarations with
// GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS defined; this provides the
// impl::ConvertIndexToSPIRVPassBase class that ConvertIndexToSPIRVPass (below)
// derives from.
#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
382
383//===----------------------------------------------------------------------===//
384// Pass Definition
385//===----------------------------------------------------------------------===//
386
387namespace {
388struct ConvertIndexToSPIRVPass
389 : public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
390 using Base::Base;
391
392 void runOnOperation() override {
393 Operation *op = getOperation();
394 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
395 std::unique_ptr<SPIRVConversionTarget> target =
396 SPIRVConversionTarget::get(targetAttr);
397
398 SPIRVConversionOptions options;
399 options.use64bitIndex = this->use64bitIndex;
400 SPIRVTypeConverter typeConverter(targetAttr, options);
401
402 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
403 // in patterns for other dialects.
404 target->addLegalOp<UnrealizedConversionCastOp>();
405
406 // Allow the spirv operations we are converting to
407 target->addLegalDialect<spirv::SPIRVDialect>();
408 // Fail hard when there are any remaining 'index' ops.
409 target->addIllegalDialect<index::IndexDialect>();
410
411 RewritePatternSet patterns(&getContext());
412 index::populateIndexToSPIRVPatterns(typeConverter, patterns);
413
414 if (failed(Result: applyPartialConversion(op, target: *target, patterns: std::move(patterns))))
415 signalPassFailure();
416 }
417};
418} // namespace
419

source code of mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp