1 | //===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
12 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
13 | #include "mlir/Dialect/Index/IR/IndexAttrs.h" |
14 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
15 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
17 | #include "mlir/Pass/Pass.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace index; |
21 | |
22 | namespace { |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // ConvertIndexCeilDivS |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then |
29 | /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. |
30 | struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> { |
31 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
32 | |
33 | LogicalResult |
34 | matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor, |
35 | ConversionPatternRewriter &rewriter) const override { |
36 | Location loc = op.getLoc(); |
37 | Value n = adaptor.getLhs(); |
38 | Value m = adaptor.getRhs(); |
39 | Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); |
40 | Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); |
41 | Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1); |
42 | |
43 | // Compute `x`. |
44 | Value mPos = |
45 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero); |
46 | Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne); |
47 | |
48 | // Compute the positive result. |
49 | Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x); |
50 | Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m); |
51 | Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne); |
52 | |
53 | // Compute the negative result. |
54 | Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n); |
55 | Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m); |
56 | Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM); |
57 | |
58 | // Pick the positive result if `n` and `m` have the same sign and `n` is |
59 | // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. |
60 | Value nPos = |
61 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero); |
62 | Value sameSign = |
63 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos); |
64 | Value nNonZero = |
65 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero); |
66 | Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero); |
67 | rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes); |
68 | return success(); |
69 | } |
70 | }; |
71 | |
72 | //===----------------------------------------------------------------------===// |
73 | // ConvertIndexCeilDivU |
74 | //===----------------------------------------------------------------------===// |
75 | |
76 | /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. |
77 | struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> { |
78 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
79 | |
80 | LogicalResult |
81 | matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor, |
82 | ConversionPatternRewriter &rewriter) const override { |
83 | Location loc = op.getLoc(); |
84 | Value n = adaptor.getLhs(); |
85 | Value m = adaptor.getRhs(); |
86 | Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); |
87 | Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); |
88 | |
89 | // Compute the non-zero result. |
90 | Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one); |
91 | Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m); |
92 | Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one); |
93 | |
94 | // Pick the result. |
95 | Value cmp = |
96 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero); |
97 | rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne); |
98 | return success(); |
99 | } |
100 | }; |
101 | |
102 | //===----------------------------------------------------------------------===// |
103 | // ConvertIndexFloorDivS |
104 | //===----------------------------------------------------------------------===// |
105 | |
106 | /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then |
107 | /// `n*m < 0 ? -1 - (x-n)/m : n/m`. |
108 | struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> { |
109 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
110 | |
111 | LogicalResult |
112 | matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor, |
113 | ConversionPatternRewriter &rewriter) const override { |
114 | Location loc = op.getLoc(); |
115 | Value n = adaptor.getLhs(); |
116 | Value m = adaptor.getRhs(); |
117 | Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); |
118 | Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); |
119 | Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1); |
120 | |
121 | // Compute `x`. |
122 | Value mNeg = |
123 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero); |
124 | Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne); |
125 | |
126 | // Compute the negative result. |
127 | Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n); |
128 | Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m); |
129 | Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM); |
130 | |
131 | // Compute the positive result. |
132 | Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m); |
133 | |
134 | // Pick the negative result if `n` and `m` have different signs and `n` is |
135 | // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. |
136 | Value nNeg = |
137 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero); |
138 | Value diffSign = |
139 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg); |
140 | Value nNonZero = |
141 | rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero); |
142 | Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero); |
143 | rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes); |
144 | return success(); |
145 | } |
146 | }; |
147 | |
148 | //===----------------------------------------------------------------------===// |
// ConvertIndexCast
150 | //===----------------------------------------------------------------------===// |
151 | |
152 | /// Convert a cast op. If the materialized index type is the same as the other |
153 | /// type, fold away the op. Otherwise, truncate or extend the op as appropriate. |
154 | /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts |
155 | /// zero extend when the result bitwidth is larger. |
156 | template <typename CastOp, typename ExtOp> |
157 | struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> { |
158 | using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern; |
159 | |
160 | LogicalResult |
161 | matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, |
162 | ConversionPatternRewriter &rewriter) const override { |
163 | Type in = adaptor.getInput().getType(); |
164 | Type out = this->getTypeConverter()->convertType(op.getType()); |
165 | if (in == out) |
166 | rewriter.replaceOp(op, adaptor.getInput()); |
167 | else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth()) |
168 | rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput()); |
169 | else |
170 | rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput()); |
171 | return success(); |
172 | } |
173 | }; |
174 | |
175 | using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>; |
176 | using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>; |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // ConvertIndexCmp |
180 | //===----------------------------------------------------------------------===// |
181 | |
182 | /// Assert that the LLVM comparison enum lines up with index's enum. |
183 | static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs, |
184 | IndexCmpPredicate rhs) { |
185 | return static_cast<int>(lhs) == static_cast<int>(rhs); |
186 | } |
187 | |
188 | static_assert( |
189 | LLVM::getMaxEnumValForICmpPredicate() == |
190 | getMaxEnumValForIndexCmpPredicate() && |
191 | checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) && |
192 | checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) && |
193 | checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) && |
194 | checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) && |
195 | checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) && |
196 | checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) && |
197 | checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) && |
198 | checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) && |
199 | checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) && |
200 | checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT), |
201 | "LLVM ICmpPredicate mismatches IndexCmpPredicate" ); |
202 | |
203 | struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> { |
204 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
205 | |
206 | LogicalResult |
207 | matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor, |
208 | ConversionPatternRewriter &rewriter) const override { |
209 | // The LLVM enum has the same values as the index predicate enums. |
210 | rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( |
211 | op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())), |
212 | adaptor.getLhs(), adaptor.getRhs()); |
213 | return success(); |
214 | } |
215 | }; |
216 | |
217 | //===----------------------------------------------------------------------===// |
218 | // ConvertIndexSizeOf |
219 | //===----------------------------------------------------------------------===// |
220 | |
221 | /// Lower `index.sizeof` to a constant with the value of the index bitwidth. |
222 | struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> { |
223 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
224 | |
225 | LogicalResult |
226 | matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor, |
227 | ConversionPatternRewriter &rewriter) const override { |
228 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
229 | op, getTypeConverter()->getIndexType(), |
230 | getTypeConverter()->getIndexTypeBitwidth()); |
231 | return success(); |
232 | } |
233 | }; |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // ConvertIndexConstant |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | /// Convert an index constant. Truncate the value as appropriate. |
240 | struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> { |
241 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
242 | |
243 | LogicalResult |
244 | matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor, |
245 | ConversionPatternRewriter &rewriter) const override { |
246 | Type type = getTypeConverter()->getIndexType(); |
247 | APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth()); |
248 | rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( |
249 | op, type, IntegerAttr::get(type, value)); |
250 | return success(); |
251 | } |
252 | }; |
253 | |
254 | //===----------------------------------------------------------------------===// |
255 | // Trivial Conversions |
256 | //===----------------------------------------------------------------------===// |
257 | |
258 | using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern<AddOp, LLVM::AddOp>; |
259 | using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern<SubOp, LLVM::SubOp>; |
260 | using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern<MulOp, LLVM::MulOp>; |
261 | using ConvertIndexDivS = |
262 | mlir::OneToOneConvertToLLVMPattern<DivSOp, LLVM::SDivOp>; |
263 | using ConvertIndexDivU = |
264 | mlir::OneToOneConvertToLLVMPattern<DivUOp, LLVM::UDivOp>; |
265 | using ConvertIndexRemS = |
266 | mlir::OneToOneConvertToLLVMPattern<RemSOp, LLVM::SRemOp>; |
267 | using ConvertIndexRemU = |
268 | mlir::OneToOneConvertToLLVMPattern<RemUOp, LLVM::URemOp>; |
269 | using ConvertIndexMaxS = |
270 | mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>; |
271 | using ConvertIndexMaxU = |
272 | mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>; |
273 | using ConvertIndexMinS = |
274 | mlir::OneToOneConvertToLLVMPattern<MinSOp, LLVM::SMinOp>; |
275 | using ConvertIndexMinU = |
276 | mlir::OneToOneConvertToLLVMPattern<MinUOp, LLVM::UMinOp>; |
277 | using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern<ShlOp, LLVM::ShlOp>; |
278 | using ConvertIndexShrS = |
279 | mlir::OneToOneConvertToLLVMPattern<ShrSOp, LLVM::AShrOp>; |
280 | using ConvertIndexShrU = |
281 | mlir::OneToOneConvertToLLVMPattern<ShrUOp, LLVM::LShrOp>; |
282 | using ConvertIndexAnd = mlir::OneToOneConvertToLLVMPattern<AndOp, LLVM::AndOp>; |
283 | using ConvertIndexOr = mlir::OneToOneConvertToLLVMPattern<OrOp, LLVM::OrOp>; |
284 | using ConvertIndexXor = mlir::OneToOneConvertToLLVMPattern<XOrOp, LLVM::XOrOp>; |
285 | using ConvertIndexBoolConstant = |
286 | mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>; |
287 | |
288 | } // namespace |
289 | |
290 | //===----------------------------------------------------------------------===// |
291 | // Pattern Population |
292 | //===----------------------------------------------------------------------===// |
293 | |
294 | void index::populateIndexToLLVMConversionPatterns( |
295 | LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { |
296 | patterns.insert< |
297 | // clang-format off |
298 | ConvertIndexAdd, |
299 | ConvertIndexSub, |
300 | ConvertIndexMul, |
301 | ConvertIndexDivS, |
302 | ConvertIndexDivU, |
303 | ConvertIndexRemS, |
304 | ConvertIndexRemU, |
305 | ConvertIndexMaxS, |
306 | ConvertIndexMaxU, |
307 | ConvertIndexMinS, |
308 | ConvertIndexMinU, |
309 | ConvertIndexShl, |
310 | ConvertIndexShrS, |
311 | ConvertIndexShrU, |
312 | ConvertIndexAnd, |
313 | ConvertIndexOr, |
314 | ConvertIndexXor, |
315 | ConvertIndexCeilDivS, |
316 | ConvertIndexCeilDivU, |
317 | ConvertIndexFloorDivS, |
318 | ConvertIndexCastS, |
319 | ConvertIndexCastU, |
320 | ConvertIndexCmp, |
321 | ConvertIndexSizeOf, |
322 | ConvertIndexConstant, |
323 | ConvertIndexBoolConstant |
324 | // clang-format on |
325 | >(typeConverter); |
326 | } |
327 | |
328 | //===----------------------------------------------------------------------===// |
329 | // ODS-Generated Definitions |
330 | //===----------------------------------------------------------------------===// |
331 | |
// Expand the tablegen-generated definitions for this pass from Passes.h.inc.
// The pass definition below derives from the generated
// `impl::ConvertIndexToLLVMPassBase`.
namespace mlir {
#define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
336 | |
337 | //===----------------------------------------------------------------------===// |
338 | // Pass Definition |
339 | //===----------------------------------------------------------------------===// |
340 | |
341 | namespace { |
342 | struct ConvertIndexToLLVMPass |
343 | : public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> { |
344 | using Base::Base; |
345 | |
346 | void runOnOperation() override; |
347 | }; |
348 | } // namespace |
349 | |
350 | void ConvertIndexToLLVMPass::runOnOperation() { |
351 | // Configure dialect conversion. |
352 | ConversionTarget target(getContext()); |
353 | target.addIllegalDialect<IndexDialect>(); |
354 | target.addLegalDialect<LLVM::LLVMDialect>(); |
355 | |
356 | // Set LLVM lowering options. |
357 | LowerToLLVMOptions options(&getContext()); |
358 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
359 | options.overrideIndexBitwidth(indexBitwidth); |
360 | LLVMTypeConverter typeConverter(&getContext(), options); |
361 | |
362 | // Populate patterns and run the conversion. |
363 | RewritePatternSet patterns(&getContext()); |
364 | populateIndexToLLVMConversionPatterns(typeConverter, patterns); |
365 | |
366 | if (failed( |
367 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
368 | return signalPassFailure(); |
369 | } |
370 | |
371 | //===----------------------------------------------------------------------===// |
372 | // ConvertToLLVMPatternInterface implementation |
373 | //===----------------------------------------------------------------------===// |
374 | |
375 | namespace { |
376 | /// Implement the interface to convert Index to LLVM. |
377 | struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
378 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
379 | void loadDependentDialects(MLIRContext *context) const final { |
380 | context->loadDialect<LLVM::LLVMDialect>(); |
381 | } |
382 | |
383 | /// Hook for derived dialect interface to provide conversion patterns |
384 | /// and mark dialect legal for the conversion target. |
385 | void populateConvertToLLVMConversionPatterns( |
386 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
387 | RewritePatternSet &patterns) const final { |
388 | populateIndexToLLVMConversionPatterns(typeConverter, patterns); |
389 | } |
390 | }; |
391 | } // namespace |
392 | |
393 | void mlir::index::registerConvertIndexToLLVMInterface( |
394 | DialectRegistry ®istry) { |
395 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, index::IndexDialect *dialect) { |
396 | dialect->addInterfaces<IndexToLLVMDialectInterface>(); |
397 | }); |
398 | } |
399 | |