1//===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/LLVMCommon/Pattern.h"
13#include "mlir/Dialect/Index/IR/IndexAttrs.h"
14#include "mlir/Dialect/Index/IR/IndexDialect.h"
15#include "mlir/Dialect/Index/IR/IndexOps.h"
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17#include "mlir/Pass/Pass.h"
18
19using namespace mlir;
20using namespace index;
21
22namespace {
23
24//===----------------------------------------------------------------------===//
25// ConvertIndexCeilDivS
26//===----------------------------------------------------------------------===//
27
28/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
29/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
30struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
31 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
32
33 LogicalResult
34 matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 Location loc = op.getLoc();
37 Value n = adaptor.getLhs();
38 Value m = adaptor.getRhs();
39 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
40 Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
41 Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
42
43 // Compute `x`.
44 Value mPos =
45 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
46 Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
47
48 // Compute the positive result.
49 Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
50 Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
51 Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
52
53 // Compute the negative result.
54 Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
55 Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
56 Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
57
58 // Pick the positive result if `n` and `m` have the same sign and `n` is
59 // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
60 Value nPos =
61 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
62 Value sameSign =
63 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
64 Value nNonZero =
65 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
66 Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
67 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
68 return success();
69 }
70};
71
72//===----------------------------------------------------------------------===//
73// ConvertIndexCeilDivU
74//===----------------------------------------------------------------------===//
75
76/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
77struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
78 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
79
80 LogicalResult
81 matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
82 ConversionPatternRewriter &rewriter) const override {
83 Location loc = op.getLoc();
84 Value n = adaptor.getLhs();
85 Value m = adaptor.getRhs();
86 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
87 Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
88
89 // Compute the non-zero result.
90 Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
91 Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
92 Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
93
94 // Pick the result.
95 Value cmp =
96 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
97 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
98 return success();
99 }
100};
101
102//===----------------------------------------------------------------------===//
103// ConvertIndexFloorDivS
104//===----------------------------------------------------------------------===//
105
106/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
107/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
108struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
109 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
110
111 LogicalResult
112 matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
113 ConversionPatternRewriter &rewriter) const override {
114 Location loc = op.getLoc();
115 Value n = adaptor.getLhs();
116 Value m = adaptor.getRhs();
117 Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
118 Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
119 Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
120
121 // Compute `x`.
122 Value mNeg =
123 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
124 Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
125
126 // Compute the negative result.
127 Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
128 Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
129 Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
130
131 // Compute the positive result.
132 Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
133
134 // Pick the negative result if `n` and `m` have different signs and `n` is
135 // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
136 Value nNeg =
137 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
138 Value diffSign =
139 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
140 Value nNonZero =
141 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
142 Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
143 rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
144 return success();
145 }
146};
147
148//===----------------------------------------------------------------------===//
// ConvertIndexCast
150//===----------------------------------------------------------------------===//
151
152/// Convert a cast op. If the materialized index type is the same as the other
153/// type, fold away the op. Otherwise, truncate or extend the op as appropriate.
154/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
155/// zero extend when the result bitwidth is larger.
156template <typename CastOp, typename ExtOp>
157struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> {
158 using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern;
159
160 LogicalResult
161 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
162 ConversionPatternRewriter &rewriter) const override {
163 Type in = adaptor.getInput().getType();
164 Type out = this->getTypeConverter()->convertType(op.getType());
165 if (in == out)
166 rewriter.replaceOp(op, adaptor.getInput());
167 else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth())
168 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput());
169 else
170 rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput());
171 return success();
172 }
173};
174
175using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
176using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;
177
178//===----------------------------------------------------------------------===//
179// ConvertIndexCmp
180//===----------------------------------------------------------------------===//
181
182/// Assert that the LLVM comparison enum lines up with index's enum.
183static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs,
184 IndexCmpPredicate rhs) {
185 return static_cast<int>(lhs) == static_cast<int>(rhs);
186}
187
188static_assert(
189 LLVM::getMaxEnumValForICmpPredicate() ==
190 getMaxEnumValForIndexCmpPredicate() &&
191 checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
192 checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
193 checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
194 checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
195 checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
196 checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
197 checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
198 checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
199 checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
200 checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
201 "LLVM ICmpPredicate mismatches IndexCmpPredicate");
202
203struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> {
204 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
205
206 LogicalResult
207 matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override {
209 // The LLVM enum has the same values as the index predicate enums.
210 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
211 op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())),
212 adaptor.getLhs(), adaptor.getRhs());
213 return success();
214 }
215};
216
217//===----------------------------------------------------------------------===//
218// ConvertIndexSizeOf
219//===----------------------------------------------------------------------===//
220
221/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
222struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> {
223 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
224
225 LogicalResult
226 matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
227 ConversionPatternRewriter &rewriter) const override {
228 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
229 op, getTypeConverter()->getIndexType(),
230 getTypeConverter()->getIndexTypeBitwidth());
231 return success();
232 }
233};
234
235//===----------------------------------------------------------------------===//
236// ConvertIndexConstant
237//===----------------------------------------------------------------------===//
238
239/// Convert an index constant. Truncate the value as appropriate.
240struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> {
241 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
242
243 LogicalResult
244 matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 Type type = getTypeConverter()->getIndexType();
247 APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth());
248 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
249 op, type, IntegerAttr::get(type, value));
250 return success();
251 }
252};
253
254//===----------------------------------------------------------------------===//
255// Trivial Conversions
256//===----------------------------------------------------------------------===//
257
258using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern<AddOp, LLVM::AddOp>;
259using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern<SubOp, LLVM::SubOp>;
260using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern<MulOp, LLVM::MulOp>;
261using ConvertIndexDivS =
262 mlir::OneToOneConvertToLLVMPattern<DivSOp, LLVM::SDivOp>;
263using ConvertIndexDivU =
264 mlir::OneToOneConvertToLLVMPattern<DivUOp, LLVM::UDivOp>;
265using ConvertIndexRemS =
266 mlir::OneToOneConvertToLLVMPattern<RemSOp, LLVM::SRemOp>;
267using ConvertIndexRemU =
268 mlir::OneToOneConvertToLLVMPattern<RemUOp, LLVM::URemOp>;
269using ConvertIndexMaxS =
270 mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>;
271using ConvertIndexMaxU =
272 mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>;
273using ConvertIndexMinS =
274 mlir::OneToOneConvertToLLVMPattern<MinSOp, LLVM::SMinOp>;
275using ConvertIndexMinU =
276 mlir::OneToOneConvertToLLVMPattern<MinUOp, LLVM::UMinOp>;
277using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern<ShlOp, LLVM::ShlOp>;
278using ConvertIndexShrS =
279 mlir::OneToOneConvertToLLVMPattern<ShrSOp, LLVM::AShrOp>;
280using ConvertIndexShrU =
281 mlir::OneToOneConvertToLLVMPattern<ShrUOp, LLVM::LShrOp>;
282using ConvertIndexAnd = mlir::OneToOneConvertToLLVMPattern<AndOp, LLVM::AndOp>;
283using ConvertIndexOr = mlir::OneToOneConvertToLLVMPattern<OrOp, LLVM::OrOp>;
284using ConvertIndexXor = mlir::OneToOneConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
285using ConvertIndexBoolConstant =
286 mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>;
287
288} // namespace
289
290//===----------------------------------------------------------------------===//
291// Pattern Population
292//===----------------------------------------------------------------------===//
293
294void index::populateIndexToLLVMConversionPatterns(
295 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
296 patterns.insert<
297 // clang-format off
298 ConvertIndexAdd,
299 ConvertIndexSub,
300 ConvertIndexMul,
301 ConvertIndexDivS,
302 ConvertIndexDivU,
303 ConvertIndexRemS,
304 ConvertIndexRemU,
305 ConvertIndexMaxS,
306 ConvertIndexMaxU,
307 ConvertIndexMinS,
308 ConvertIndexMinU,
309 ConvertIndexShl,
310 ConvertIndexShrS,
311 ConvertIndexShrU,
312 ConvertIndexAnd,
313 ConvertIndexOr,
314 ConvertIndexXor,
315 ConvertIndexCeilDivS,
316 ConvertIndexCeilDivU,
317 ConvertIndexFloorDivS,
318 ConvertIndexCastS,
319 ConvertIndexCastU,
320 ConvertIndexCmp,
321 ConvertIndexSizeOf,
322 ConvertIndexConstant,
323 ConvertIndexBoolConstant
324 // clang-format on
325 >(typeConverter);
326}
327
328//===----------------------------------------------------------------------===//
329// ODS-Generated Definitions
330//===----------------------------------------------------------------------===//
331
// Expand the TableGen-generated definition of the pass base class
// (`impl::ConvertIndexToLLVMPassBase`) that `ConvertIndexToLLVMPass` derives
// from. It is placed inside `namespace mlir` here.
namespace mlir {
#define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
336
337//===----------------------------------------------------------------------===//
338// Pass Definition
339//===----------------------------------------------------------------------===//
340
341namespace {
342struct ConvertIndexToLLVMPass
343 : public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> {
344 using Base::Base;
345
346 void runOnOperation() override;
347};
348} // namespace
349
350void ConvertIndexToLLVMPass::runOnOperation() {
351 // Configure dialect conversion.
352 ConversionTarget target(getContext());
353 target.addIllegalDialect<IndexDialect>();
354 target.addLegalDialect<LLVM::LLVMDialect>();
355
356 // Set LLVM lowering options.
357 LowerToLLVMOptions options(&getContext());
358 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
359 options.overrideIndexBitwidth(indexBitwidth);
360 LLVMTypeConverter typeConverter(&getContext(), options);
361
362 // Populate patterns and run the conversion.
363 RewritePatternSet patterns(&getContext());
364 populateIndexToLLVMConversionPatterns(typeConverter, patterns);
365
366 if (failed(
367 applyPartialConversion(getOperation(), target, std::move(patterns))))
368 return signalPassFailure();
369}
370
371//===----------------------------------------------------------------------===//
372// ConvertToLLVMPatternInterface implementation
373//===----------------------------------------------------------------------===//
374
375namespace {
376/// Implement the interface to convert Index to LLVM.
377struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
378 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
379 void loadDependentDialects(MLIRContext *context) const final {
380 context->loadDialect<LLVM::LLVMDialect>();
381 }
382
383 /// Hook for derived dialect interface to provide conversion patterns
384 /// and mark dialect legal for the conversion target.
385 void populateConvertToLLVMConversionPatterns(
386 ConversionTarget &target, LLVMTypeConverter &typeConverter,
387 RewritePatternSet &patterns) const final {
388 populateIndexToLLVMConversionPatterns(typeConverter, patterns);
389 }
390};
391} // namespace
392
393void mlir::index::registerConvertIndexToLLVMInterface(
394 DialectRegistry &registry) {
395 registry.addExtension(extensionFn: +[](MLIRContext *ctx, index::IndexDialect *dialect) {
396 dialect->addInterfaces<IndexToLLVMDialectInterface>();
397 });
398}
399

// Source code of mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp