//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Arith dialect to the EmitC
// dialect.
//
//===----------------------------------------------------------------------===//
13 | |
14 | #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" |
15 | |
16 | #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" |
17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
18 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
19 | #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" |
20 | #include "mlir/IR/BuiltinAttributes.h" |
21 | #include "mlir/IR/BuiltinTypes.h" |
22 | #include "mlir/Transforms/DialectConversion.h" |
23 | |
24 | using namespace mlir; |
25 | |
26 | namespace { |
27 | /// Implement the interface to convert Arith to EmitC. |
28 | struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface { |
29 | using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; |
30 | |
31 | /// Hook for derived dialect interface to provide conversion patterns |
32 | /// and mark dialect legal for the conversion target. |
33 | void populateConvertToEmitCConversionPatterns( |
34 | ConversionTarget &target, TypeConverter &typeConverter, |
35 | RewritePatternSet &patterns) const final { |
36 | populateArithToEmitCPatterns(typeConverter, patterns); |
37 | } |
38 | }; |
39 | } // namespace |
40 | |
41 | void mlir::registerConvertArithToEmitCInterface(DialectRegistry ®istry) { |
42 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) { |
43 | dialect->addInterfaces<ArithToEmitCDialectInterface>(); |
44 | }); |
45 | } |
46 | |
//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//
50 | |
51 | namespace { |
52 | class ArithConstantOpConversionPattern |
53 | : public OpConversionPattern<arith::ConstantOp> { |
54 | public: |
55 | using OpConversionPattern::OpConversionPattern; |
56 | |
57 | LogicalResult |
58 | matchAndRewrite(arith::ConstantOp arithConst, |
59 | arith::ConstantOp::Adaptor adaptor, |
60 | ConversionPatternRewriter &rewriter) const override { |
61 | Type newTy = this->getTypeConverter()->convertType(arithConst.getType()); |
62 | if (!newTy) |
63 | return rewriter.notifyMatchFailure(arithConst, "type conversion failed"); |
64 | rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy, |
65 | adaptor.getValue()); |
66 | return success(); |
67 | } |
68 | }; |
69 | |
70 | /// Get the signed or unsigned type corresponding to \p ty. |
71 | Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) { |
72 | if (isa<IntegerType>(Val: ty)) { |
73 | if (ty.isUnsignedInteger() != needsUnsigned) { |
74 | auto signedness = needsUnsigned |
75 | ? IntegerType::SignednessSemantics::Unsigned |
76 | : IntegerType::SignednessSemantics::Signed; |
77 | return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(), |
78 | signedness); |
79 | } |
80 | } else if (emitc::isPointerWideType(type: ty)) { |
81 | if (isa<emitc::SizeTType>(ty) != needsUnsigned) { |
82 | if (needsUnsigned) |
83 | return emitc::SizeTType::get(ty.getContext()); |
84 | return emitc::PtrDiffTType::get(ty.getContext()); |
85 | } |
86 | } |
87 | return ty; |
88 | } |
89 | |
90 | /// Insert a cast operation to type \p ty if \p val does not have this type. |
91 | Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) { |
92 | return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val); |
93 | } |
94 | |
95 | class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> { |
96 | public: |
97 | using OpConversionPattern::OpConversionPattern; |
98 | |
99 | LogicalResult |
100 | matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
101 | ConversionPatternRewriter &rewriter) const override { |
102 | |
103 | if (!isa<FloatType>(adaptor.getRhs().getType())) { |
104 | return rewriter.notifyMatchFailure(op.getLoc(), |
105 | "cmpf currently only supported on " |
106 | "floats, not tensors/vectors thereof"); |
107 | } |
108 | |
109 | bool unordered = false; |
110 | emitc::CmpPredicate predicate; |
111 | switch (op.getPredicate()) { |
112 | case arith::CmpFPredicate::AlwaysFalse: { |
113 | auto constant = rewriter.create<emitc::ConstantOp>( |
114 | op.getLoc(), rewriter.getI1Type(), |
115 | rewriter.getBoolAttr(/*value=*/false)); |
116 | rewriter.replaceOp(op, constant); |
117 | return success(); |
118 | } |
119 | case arith::CmpFPredicate::OEQ: |
120 | unordered = false; |
121 | predicate = emitc::CmpPredicate::eq; |
122 | break; |
123 | case arith::CmpFPredicate::OGT: |
124 | unordered = false; |
125 | predicate = emitc::CmpPredicate::gt; |
126 | break; |
127 | case arith::CmpFPredicate::OGE: |
128 | unordered = false; |
129 | predicate = emitc::CmpPredicate::ge; |
130 | break; |
131 | case arith::CmpFPredicate::OLT: |
132 | unordered = false; |
133 | predicate = emitc::CmpPredicate::lt; |
134 | break; |
135 | case arith::CmpFPredicate::OLE: |
136 | unordered = false; |
137 | predicate = emitc::CmpPredicate::le; |
138 | break; |
139 | case arith::CmpFPredicate::ONE: |
140 | unordered = false; |
141 | predicate = emitc::CmpPredicate::ne; |
142 | break; |
143 | case arith::CmpFPredicate::ORD: { |
144 | // ordered, i.e. none of the operands is NaN |
145 | auto cmp = createCheckIsOrdered(rewriter, loc: op.getLoc(), first: adaptor.getLhs(), |
146 | second: adaptor.getRhs()); |
147 | rewriter.replaceOp(op, cmp); |
148 | return success(); |
149 | } |
150 | case arith::CmpFPredicate::UEQ: |
151 | unordered = true; |
152 | predicate = emitc::CmpPredicate::eq; |
153 | break; |
154 | case arith::CmpFPredicate::UGT: |
155 | unordered = true; |
156 | predicate = emitc::CmpPredicate::gt; |
157 | break; |
158 | case arith::CmpFPredicate::UGE: |
159 | unordered = true; |
160 | predicate = emitc::CmpPredicate::ge; |
161 | break; |
162 | case arith::CmpFPredicate::ULT: |
163 | unordered = true; |
164 | predicate = emitc::CmpPredicate::lt; |
165 | break; |
166 | case arith::CmpFPredicate::ULE: |
167 | unordered = true; |
168 | predicate = emitc::CmpPredicate::le; |
169 | break; |
170 | case arith::CmpFPredicate::UNE: |
171 | unordered = true; |
172 | predicate = emitc::CmpPredicate::ne; |
173 | break; |
174 | case arith::CmpFPredicate::UNO: { |
175 | // unordered, i.e. either operand is nan |
176 | auto cmp = createCheckIsUnordered(rewriter, loc: op.getLoc(), first: adaptor.getLhs(), |
177 | second: adaptor.getRhs()); |
178 | rewriter.replaceOp(op, cmp); |
179 | return success(); |
180 | } |
181 | case arith::CmpFPredicate::AlwaysTrue: { |
182 | auto constant = rewriter.create<emitc::ConstantOp>( |
183 | op.getLoc(), rewriter.getI1Type(), |
184 | rewriter.getBoolAttr(/*value=*/true)); |
185 | rewriter.replaceOp(op, constant); |
186 | return success(); |
187 | } |
188 | } |
189 | |
190 | // Compare the values naively |
191 | auto cmpResult = |
192 | rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate, |
193 | adaptor.getLhs(), adaptor.getRhs()); |
194 | |
195 | // Adjust the results for unordered/ordered semantics |
196 | if (unordered) { |
197 | auto isUnordered = createCheckIsUnordered( |
198 | rewriter, loc: op.getLoc(), first: adaptor.getLhs(), second: adaptor.getRhs()); |
199 | rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(), |
200 | isUnordered, cmpResult); |
201 | return success(); |
202 | } |
203 | |
204 | auto isOrdered = createCheckIsOrdered(rewriter, loc: op.getLoc(), |
205 | first: adaptor.getLhs(), second: adaptor.getRhs()); |
206 | rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(), |
207 | isOrdered, cmpResult); |
208 | return success(); |
209 | } |
210 | |
211 | private: |
212 | /// Return a value that is true if \p operand is NaN. |
213 | Value isNaN(ConversionPatternRewriter &rewriter, Location loc, |
214 | Value operand) const { |
215 | // A value is NaN exactly when it compares unequal to itself. |
216 | return rewriter.create<emitc::CmpOp>( |
217 | loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); |
218 | } |
219 | |
220 | /// Return a value that is true if \p operand is not NaN. |
221 | Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, |
222 | Value operand) const { |
223 | // A value is not NaN exactly when it compares equal to itself. |
224 | return rewriter.create<emitc::CmpOp>( |
225 | loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand); |
226 | } |
227 | |
228 | /// Return a value that is true if the operands \p first and \p second are |
229 | /// unordered (i.e., at least one of them is NaN). |
230 | Value createCheckIsUnordered(ConversionPatternRewriter &rewriter, |
231 | Location loc, Value first, Value second) const { |
232 | auto firstIsNaN = isNaN(rewriter, loc, operand: first); |
233 | auto secondIsNaN = isNaN(rewriter, loc, operand: second); |
234 | return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(), |
235 | firstIsNaN, secondIsNaN); |
236 | } |
237 | |
238 | /// Return a value that is true if the operands \p first and \p second are |
239 | /// both ordered (i.e., none one of them is NaN). |
240 | Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc, |
241 | Value first, Value second) const { |
242 | auto firstIsNotNaN = isNotNaN(rewriter, loc, operand: first); |
243 | auto secondIsNotNaN = isNotNaN(rewriter, loc, operand: second); |
244 | return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(), |
245 | firstIsNotNaN, secondIsNotNaN); |
246 | } |
247 | }; |
248 | |
249 | class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> { |
250 | public: |
251 | using OpConversionPattern::OpConversionPattern; |
252 | |
253 | bool needsUnsignedCmp(arith::CmpIPredicate pred) const { |
254 | switch (pred) { |
255 | case arith::CmpIPredicate::eq: |
256 | case arith::CmpIPredicate::ne: |
257 | case arith::CmpIPredicate::slt: |
258 | case arith::CmpIPredicate::sle: |
259 | case arith::CmpIPredicate::sgt: |
260 | case arith::CmpIPredicate::sge: |
261 | return false; |
262 | case arith::CmpIPredicate::ult: |
263 | case arith::CmpIPredicate::ule: |
264 | case arith::CmpIPredicate::ugt: |
265 | case arith::CmpIPredicate::uge: |
266 | return true; |
267 | } |
268 | llvm_unreachable("unknown cmpi predicate kind"); |
269 | } |
270 | |
271 | emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const { |
272 | switch (pred) { |
273 | case arith::CmpIPredicate::eq: |
274 | return emitc::CmpPredicate::eq; |
275 | case arith::CmpIPredicate::ne: |
276 | return emitc::CmpPredicate::ne; |
277 | case arith::CmpIPredicate::slt: |
278 | case arith::CmpIPredicate::ult: |
279 | return emitc::CmpPredicate::lt; |
280 | case arith::CmpIPredicate::sle: |
281 | case arith::CmpIPredicate::ule: |
282 | return emitc::CmpPredicate::le; |
283 | case arith::CmpIPredicate::sgt: |
284 | case arith::CmpIPredicate::ugt: |
285 | return emitc::CmpPredicate::gt; |
286 | case arith::CmpIPredicate::sge: |
287 | case arith::CmpIPredicate::uge: |
288 | return emitc::CmpPredicate::ge; |
289 | } |
290 | llvm_unreachable("unknown cmpi predicate kind"); |
291 | } |
292 | |
293 | LogicalResult |
294 | matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
295 | ConversionPatternRewriter &rewriter) const override { |
296 | |
297 | Type type = adaptor.getLhs().getType(); |
298 | if (!type || !(isa<IntegerType>(Val: type) || emitc::isPointerWideType(type))) { |
299 | return rewriter.notifyMatchFailure( |
300 | op, "expected integer or size_t/ssize_t/ptrdiff_t type"); |
301 | } |
302 | |
303 | bool needsUnsigned = needsUnsignedCmp(op.getPredicate()); |
304 | emitc::CmpPredicate pred = toEmitCPred(op.getPredicate()); |
305 | |
306 | Type arithmeticType = adaptIntegralTypeSignedness(ty: type, needsUnsigned); |
307 | Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); |
308 | Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); |
309 | |
310 | rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs); |
311 | return success(); |
312 | } |
313 | }; |
314 | |
315 | class NegFOpConversion : public OpConversionPattern<arith::NegFOp> { |
316 | public: |
317 | using OpConversionPattern::OpConversionPattern; |
318 | |
319 | LogicalResult |
320 | matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, |
321 | ConversionPatternRewriter &rewriter) const override { |
322 | |
323 | auto adaptedOp = adaptor.getOperand(); |
324 | auto adaptedOpType = adaptedOp.getType(); |
325 | |
326 | if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) { |
327 | return rewriter.notifyMatchFailure( |
328 | op.getLoc(), |
329 | "negf currently only supports scalar types, not vectors or tensors"); |
330 | } |
331 | |
332 | if (!emitc::isSupportedFloatType(type: adaptedOpType)) { |
333 | return rewriter.notifyMatchFailure( |
334 | op.getLoc(), "floating-point type is not supported by EmitC"); |
335 | } |
336 | |
337 | rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType, |
338 | adaptedOp); |
339 | return success(); |
340 | } |
341 | }; |
342 | |
343 | template <typename ArithOp, bool castToUnsigned> |
344 | class CastConversion : public OpConversionPattern<ArithOp> { |
345 | public: |
346 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
347 | |
348 | LogicalResult |
349 | matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, |
350 | ConversionPatternRewriter &rewriter) const override { |
351 | |
352 | Type opReturnType = this->getTypeConverter()->convertType(op.getType()); |
353 | if (!opReturnType || !(isa<IntegerType>(Val: opReturnType) || |
354 | emitc::isPointerWideType(type: opReturnType))) |
355 | return rewriter.notifyMatchFailure( |
356 | op, "expected integer or size_t/ssize_t/ptrdiff_t result type"); |
357 | |
358 | if (adaptor.getOperands().size() != 1) { |
359 | return rewriter.notifyMatchFailure( |
360 | op, "CastConversion only supports unary ops"); |
361 | } |
362 | |
363 | Type operandType = adaptor.getIn().getType(); |
364 | if (!operandType || !(isa<IntegerType>(Val: operandType) || |
365 | emitc::isPointerWideType(type: operandType))) |
366 | return rewriter.notifyMatchFailure( |
367 | op, "expected integer or size_t/ssize_t/ptrdiff_t operand type"); |
368 | |
369 | // Signed (sign-extending) casts from i1 are not supported. |
370 | if (operandType.isInteger(width: 1) && !castToUnsigned) |
371 | return rewriter.notifyMatchFailure(op, |
372 | "operation not supported on i1 type"); |
373 | |
374 | // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is |
375 | // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives |
376 | // truncation. |
377 | if (opReturnType.isInteger(width: 1)) { |
378 | Type attrType = (emitc::isPointerWideType(type: operandType)) |
379 | ? rewriter.getIndexType() |
380 | : operandType; |
381 | auto constOne = rewriter.create<emitc::ConstantOp>( |
382 | op.getLoc(), operandType, rewriter.getOneAttr(attrType)); |
383 | auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>( |
384 | op.getLoc(), operandType, adaptor.getIn(), constOne); |
385 | rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType, |
386 | oneAndOperand); |
387 | return success(); |
388 | } |
389 | |
390 | bool isTruncation = |
391 | (isa<IntegerType>(Val: operandType) && isa<IntegerType>(Val: opReturnType) && |
392 | operandType.getIntOrFloatBitWidth() > |
393 | opReturnType.getIntOrFloatBitWidth()); |
394 | bool doUnsigned = castToUnsigned || isTruncation; |
395 | |
396 | // Adapt the signedness of the result (bitwidth-preserving cast) |
397 | // This is needed e.g., if the return type is signless. |
398 | Type castDestType = adaptIntegralTypeSignedness(ty: opReturnType, needsUnsigned: doUnsigned); |
399 | |
400 | // Adapt the signedness of the operand (bitwidth-preserving cast) |
401 | Type castSrcType = adaptIntegralTypeSignedness(ty: operandType, needsUnsigned: doUnsigned); |
402 | Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); |
403 | |
404 | // Actual cast (may change bitwidth) |
405 | auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(), |
406 | castDestType, actualOp); |
407 | |
408 | // Cast to the expected output type |
409 | auto result = adaptValueType(cast, rewriter, opReturnType); |
410 | |
411 | rewriter.replaceOp(op, result); |
412 | return success(); |
413 | } |
414 | }; |
415 | |
416 | template <typename ArithOp> |
417 | class UnsignedCastConversion : public CastConversion<ArithOp, true> { |
418 | using CastConversion<ArithOp, true>::CastConversion; |
419 | }; |
420 | |
421 | template <typename ArithOp> |
422 | class SignedCastConversion : public CastConversion<ArithOp, false> { |
423 | using CastConversion<ArithOp, false>::CastConversion; |
424 | }; |
425 | |
426 | template <typename ArithOp, typename EmitCOp> |
427 | class ArithOpConversion final : public OpConversionPattern<ArithOp> { |
428 | public: |
429 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
430 | |
431 | LogicalResult |
432 | matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, |
433 | ConversionPatternRewriter &rewriter) const override { |
434 | |
435 | Type newTy = this->getTypeConverter()->convertType(arithOp.getType()); |
436 | if (!newTy) |
437 | return rewriter.notifyMatchFailure(arithOp, |
438 | "converting result type failed"); |
439 | rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy, |
440 | adaptor.getOperands()); |
441 | |
442 | return success(); |
443 | } |
444 | }; |
445 | |
446 | template <class ArithOp, class EmitCOp> |
447 | class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> { |
448 | public: |
449 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
450 | |
451 | LogicalResult |
452 | matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor, |
453 | ConversionPatternRewriter &rewriter) const override { |
454 | Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType()); |
455 | if (!newRetTy) |
456 | return rewriter.notifyMatchFailure(uiBinOp, |
457 | "converting result type failed"); |
458 | if (!isa<IntegerType>(Val: newRetTy)) { |
459 | return rewriter.notifyMatchFailure(uiBinOp, "expected integer type"); |
460 | } |
461 | Type unsignedType = |
462 | adaptIntegralTypeSignedness(ty: newRetTy, /*needsUnsigned=*/true); |
463 | if (!unsignedType) |
464 | return rewriter.notifyMatchFailure(uiBinOp, |
465 | "converting result type failed"); |
466 | Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType); |
467 | Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType); |
468 | |
469 | auto newDivOp = |
470 | rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType, |
471 | ArrayRef<Value>{lhsAdapted, rhsAdapted}); |
472 | Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy); |
473 | rewriter.replaceOp(uiBinOp, resultAdapted); |
474 | return success(); |
475 | } |
476 | }; |
477 | |
478 | template <typename ArithOp, typename EmitCOp> |
479 | class IntegerOpConversion final : public OpConversionPattern<ArithOp> { |
480 | public: |
481 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
482 | |
483 | LogicalResult |
484 | matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, |
485 | ConversionPatternRewriter &rewriter) const override { |
486 | |
487 | Type type = this->getTypeConverter()->convertType(op.getType()); |
488 | if (!type || !(isa<IntegerType>(Val: type) || emitc::isPointerWideType(type))) { |
489 | return rewriter.notifyMatchFailure( |
490 | op, "expected integer or size_t/ssize_t/ptrdiff_t type"); |
491 | } |
492 | |
493 | if (type.isInteger(width: 1)) { |
494 | // arith expects wrap-around arithmethic, which doesn't happen on `bool`. |
495 | return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); |
496 | } |
497 | |
498 | Type arithmeticType = type; |
499 | if ((type.isSignlessInteger() || type.isSignedInteger()) && |
500 | !bitEnumContainsAll(op.getOverflowFlags(), |
501 | arith::IntegerOverflowFlags::nsw)) { |
502 | // If the C type is signed and the op doesn't guarantee "No Signed Wrap", |
503 | // we compute in unsigned integers to avoid UB. |
504 | arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(), |
505 | /*isSigned=*/false); |
506 | } |
507 | |
508 | Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); |
509 | Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); |
510 | |
511 | Value arithmeticResult = rewriter.template create<EmitCOp>( |
512 | op.getLoc(), arithmeticType, lhs, rhs); |
513 | |
514 | Value result = adaptValueType(val: arithmeticResult, rewriter, ty: type); |
515 | |
516 | rewriter.replaceOp(op, result); |
517 | return success(); |
518 | } |
519 | }; |
520 | |
521 | template <typename ArithOp, typename EmitCOp> |
522 | class BitwiseOpConversion : public OpConversionPattern<ArithOp> { |
523 | public: |
524 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
525 | |
526 | LogicalResult |
527 | matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, |
528 | ConversionPatternRewriter &rewriter) const override { |
529 | |
530 | Type type = this->getTypeConverter()->convertType(op.getType()); |
531 | if (!isa_and_nonnull<IntegerType>(Val: type)) { |
532 | return rewriter.notifyMatchFailure( |
533 | op, |
534 | "expected integer type, vector/tensor support not yet implemented"); |
535 | } |
536 | |
537 | // Bitwise ops can be performed directly on booleans |
538 | if (type.isInteger(width: 1)) { |
539 | rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(), |
540 | adaptor.getRhs()); |
541 | return success(); |
542 | } |
543 | |
544 | // Bitwise ops are defined by the C standard on unsigned operands. |
545 | Type arithmeticType = |
546 | adaptIntegralTypeSignedness(ty: type, /*needsUnsigned=*/true); |
547 | |
548 | Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); |
549 | Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); |
550 | |
551 | Value arithmeticResult = rewriter.template create<EmitCOp>( |
552 | op.getLoc(), arithmeticType, lhs, rhs); |
553 | |
554 | Value result = adaptValueType(val: arithmeticResult, rewriter, ty: type); |
555 | |
556 | rewriter.replaceOp(op, result); |
557 | return success(); |
558 | } |
559 | }; |
560 | |
561 | template <typename ArithOp, typename EmitCOp, bool isUnsignedOp> |
562 | class ShiftOpConversion : public OpConversionPattern<ArithOp> { |
563 | public: |
564 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
565 | |
566 | LogicalResult |
567 | matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, |
568 | ConversionPatternRewriter &rewriter) const override { |
569 | |
570 | Type type = this->getTypeConverter()->convertType(op.getType()); |
571 | if (!type || !(isa<IntegerType>(Val: type) || emitc::isPointerWideType(type))) { |
572 | return rewriter.notifyMatchFailure( |
573 | op, "expected integer or size_t/ssize_t/ptrdiff_t type"); |
574 | } |
575 | |
576 | if (type.isInteger(width: 1)) { |
577 | return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); |
578 | } |
579 | |
580 | Type arithmeticType = adaptIntegralTypeSignedness(ty: type, needsUnsigned: isUnsignedOp); |
581 | |
582 | Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); |
583 | // Shift amount interpreted as unsigned per Arith dialect spec. |
584 | Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(), |
585 | /*needsUnsigned=*/true); |
586 | Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType); |
587 | |
588 | // Add a runtime check for overflow |
589 | Value width; |
590 | if (emitc::isPointerWideType(type)) { |
591 | Value eight = rewriter.create<emitc::ConstantOp>( |
592 | op.getLoc(), rhsType, rewriter.getIndexAttr(8)); |
593 | emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>( |
594 | op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight}); |
595 | width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight, |
596 | sizeOfCall.getResult(0)); |
597 | } else { |
598 | width = rewriter.create<emitc::ConstantOp>( |
599 | op.getLoc(), rhsType, |
600 | rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); |
601 | } |
602 | |
603 | Value excessCheck = rewriter.create<emitc::CmpOp>( |
604 | op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); |
605 | |
606 | // Any concrete value is a valid refinement of poison. |
607 | Value poison = rewriter.create<emitc::ConstantOp>( |
608 | op.getLoc(), arithmeticType, |
609 | (isa<IntegerType>(arithmeticType) |
610 | ? rewriter.getIntegerAttr(arithmeticType, 0) |
611 | : rewriter.getIndexAttr(0))); |
612 | |
613 | emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>( |
614 | op.getLoc(), arithmeticType, /*do_not_inline=*/false); |
615 | Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); |
616 | auto currentPoint = rewriter.getInsertionPoint(); |
617 | rewriter.setInsertionPointToStart(&bodyBlock); |
618 | Value arithmeticResult = |
619 | rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs); |
620 | Value resultOrPoison = rewriter.create<emitc::ConditionalOp>( |
621 | op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); |
622 | rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison); |
623 | rewriter.setInsertionPoint(op->getBlock(), currentPoint); |
624 | |
625 | Value result = adaptValueType(ternary, rewriter, type); |
626 | |
627 | rewriter.replaceOp(op, result); |
628 | return success(); |
629 | } |
630 | }; |
631 | |
632 | template <typename ArithOp, typename EmitCOp> |
633 | class SignedShiftOpConversion final |
634 | : public ShiftOpConversion<ArithOp, EmitCOp, false> { |
635 | using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion; |
636 | }; |
637 | |
638 | template <typename ArithOp, typename EmitCOp> |
639 | class UnsignedShiftOpConversion final |
640 | : public ShiftOpConversion<ArithOp, EmitCOp, true> { |
641 | using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion; |
642 | }; |
643 | |
644 | class SelectOpConversion : public OpConversionPattern<arith::SelectOp> { |
645 | public: |
646 | using OpConversionPattern<arith::SelectOp>::OpConversionPattern; |
647 | |
648 | LogicalResult |
649 | matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor, |
650 | ConversionPatternRewriter &rewriter) const override { |
651 | |
652 | Type dstType = getTypeConverter()->convertType(selectOp.getType()); |
653 | if (!dstType) |
654 | return rewriter.notifyMatchFailure(selectOp, "type conversion failed"); |
655 | |
656 | if (!adaptor.getCondition().getType().isInteger(1)) |
657 | return rewriter.notifyMatchFailure( |
658 | selectOp, |
659 | "can only be converted if condition is a scalar of type i1"); |
660 | |
661 | rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType, |
662 | adaptor.getOperands()); |
663 | |
664 | return success(); |
665 | } |
666 | }; |
667 | |
668 | // Floating-point to integer conversions. |
669 | template <typename CastOp> |
670 | class FtoICastOpConversion : public OpConversionPattern<CastOp> { |
671 | public: |
672 | FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) |
673 | : OpConversionPattern<CastOp>(typeConverter, context) {} |
674 | |
675 | LogicalResult |
676 | matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, |
677 | ConversionPatternRewriter &rewriter) const override { |
678 | |
679 | Type operandType = adaptor.getIn().getType(); |
680 | if (!emitc::isSupportedFloatType(type: operandType)) |
681 | return rewriter.notifyMatchFailure(castOp, |
682 | "unsupported cast source type"); |
683 | |
684 | Type dstType = this->getTypeConverter()->convertType(castOp.getType()); |
685 | if (!dstType) |
686 | return rewriter.notifyMatchFailure(castOp, "type conversion failed"); |
687 | |
688 | // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be |
689 | // truncated to 0, whereas a boolean conversion would return true. |
690 | if (!emitc::isSupportedIntegerType(type: dstType) || dstType.isInteger(width: 1)) |
691 | return rewriter.notifyMatchFailure(castOp, |
692 | "unsupported cast destination type"); |
693 | |
694 | // Convert to unsigned if it's the "ui" variant |
695 | // Signless is interpreted as signed, so no need to cast for "si" |
696 | Type actualResultType = dstType; |
697 | if (isa<arith::FPToUIOp>(castOp)) { |
698 | actualResultType = |
699 | rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(), |
700 | /*isSigned=*/false); |
701 | } |
702 | |
703 | Value result = rewriter.create<emitc::CastOp>( |
704 | castOp.getLoc(), actualResultType, adaptor.getOperands()); |
705 | |
706 | if (isa<arith::FPToUIOp>(castOp)) { |
707 | result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result); |
708 | } |
709 | rewriter.replaceOp(castOp, result); |
710 | |
711 | return success(); |
712 | } |
713 | }; |
714 | |
715 | // Integer to floating-point conversions. |
716 | template <typename CastOp> |
717 | class ItoFCastOpConversion : public OpConversionPattern<CastOp> { |
718 | public: |
719 | ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) |
720 | : OpConversionPattern<CastOp>(typeConverter, context) {} |
721 | |
722 | LogicalResult |
723 | matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, |
724 | ConversionPatternRewriter &rewriter) const override { |
725 | // Vectors in particular are not supported |
726 | Type operandType = adaptor.getIn().getType(); |
727 | if (!emitc::isSupportedIntegerType(type: operandType)) |
728 | return rewriter.notifyMatchFailure(castOp, |
729 | "unsupported cast source type"); |
730 | |
731 | Type dstType = this->getTypeConverter()->convertType(castOp.getType()); |
732 | if (!dstType) |
733 | return rewriter.notifyMatchFailure(castOp, "type conversion failed"); |
734 | |
735 | if (!emitc::isSupportedFloatType(type: dstType)) |
736 | return rewriter.notifyMatchFailure(castOp, |
737 | "unsupported cast destination type"); |
738 | |
739 | // Convert to unsigned if it's the "ui" variant |
740 | // Signless is interpreted as signed, so no need to cast for "si" |
741 | Type actualOperandType = operandType; |
742 | if (isa<arith::UIToFPOp>(castOp)) { |
743 | actualOperandType = |
744 | rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), |
745 | /*isSigned=*/false); |
746 | } |
747 | Value fpCastOperand = adaptor.getIn(); |
748 | if (actualOperandType != operandType) { |
749 | fpCastOperand = rewriter.template create<emitc::CastOp>( |
750 | castOp.getLoc(), actualOperandType, fpCastOperand); |
751 | } |
752 | rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); |
753 | |
754 | return success(); |
755 | } |
756 | }; |
757 | |
758 | // Floating-point to floating-point conversions. |
759 | template <typename CastOp> |
760 | class FpCastOpConversion : public OpConversionPattern<CastOp> { |
761 | public: |
762 | FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) |
763 | : OpConversionPattern<CastOp>(typeConverter, context) {} |
764 | |
765 | LogicalResult |
766 | matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, |
767 | ConversionPatternRewriter &rewriter) const override { |
768 | // Vectors in particular are not supported. |
769 | Type operandType = adaptor.getIn().getType(); |
770 | if (!emitc::isSupportedFloatType(type: operandType)) |
771 | return rewriter.notifyMatchFailure(castOp, |
772 | "unsupported cast source type"); |
773 | if (auto roundingModeOp = |
774 | dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) { |
775 | // Only supporting default rounding mode as of now. |
776 | if (roundingModeOp.getRoundingModeAttr()) |
777 | return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode"); |
778 | } |
779 | |
780 | Type dstType = this->getTypeConverter()->convertType(castOp.getType()); |
781 | if (!dstType) |
782 | return rewriter.notifyMatchFailure(castOp, "type conversion failed"); |
783 | |
784 | if (!emitc::isSupportedFloatType(type: dstType)) |
785 | return rewriter.notifyMatchFailure(castOp, |
786 | "unsupported cast destination type"); |
787 | |
788 | Value fpCastOperand = adaptor.getIn(); |
789 | rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); |
790 | |
791 | return success(); |
792 | } |
793 | }; |
794 | |
795 | } // namespace |
796 | |
797 | //===----------------------------------------------------------------------===// |
798 | // Pattern population |
799 | //===----------------------------------------------------------------------===// |
800 | |
801 | void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, |
802 | RewritePatternSet &patterns) { |
803 | MLIRContext *ctx = patterns.getContext(); |
804 | |
805 | mlir::populateEmitCSizeTTypeConversions(converter&: typeConverter); |
806 | |
807 | // clang-format off |
808 | patterns.add< |
809 | ArithConstantOpConversionPattern, |
810 | ArithOpConversion<arith::AddFOp, emitc::AddOp>, |
811 | ArithOpConversion<arith::DivFOp, emitc::DivOp>, |
812 | ArithOpConversion<arith::DivSIOp, emitc::DivOp>, |
813 | ArithOpConversion<arith::MulFOp, emitc::MulOp>, |
814 | ArithOpConversion<arith::RemSIOp, emitc::RemOp>, |
815 | ArithOpConversion<arith::SubFOp, emitc::SubOp>, |
816 | BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>, |
817 | BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>, |
818 | IntegerOpConversion<arith::AddIOp, emitc::AddOp>, |
819 | IntegerOpConversion<arith::MulIOp, emitc::MulOp>, |
820 | IntegerOpConversion<arith::SubIOp, emitc::SubOp>, |
821 | BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>, |
822 | BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>, |
823 | BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>, |
824 | UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>, |
825 | SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>, |
826 | UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>, |
827 | CmpFOpConversion, |
828 | CmpIOpConversion, |
829 | NegFOpConversion, |
830 | SelectOpConversion, |
831 | // Truncation is guaranteed for unsigned types. |
832 | UnsignedCastConversion<arith::TruncIOp>, |
833 | SignedCastConversion<arith::ExtSIOp>, |
834 | UnsignedCastConversion<arith::ExtUIOp>, |
835 | SignedCastConversion<arith::IndexCastOp>, |
836 | UnsignedCastConversion<arith::IndexCastUIOp>, |
837 | ItoFCastOpConversion<arith::SIToFPOp>, |
838 | ItoFCastOpConversion<arith::UIToFPOp>, |
839 | FtoICastOpConversion<arith::FPToSIOp>, |
840 | FtoICastOpConversion<arith::FPToUIOp>, |
841 | FpCastOpConversion<arith::ExtFOp>, |
842 | FpCastOpConversion<arith::TruncFOp> |
843 | >(typeConverter, ctx); |
844 | // clang-format on |
845 | } |
846 |
Definitions
- ArithToEmitCDialectInterface
- populateConvertToEmitCConversionPatterns
- registerConvertArithToEmitCInterface
- ArithConstantOpConversionPattern
- matchAndRewrite
- adaptIntegralTypeSignedness
- adaptValueType
- CmpFOpConversion
- matchAndRewrite
- isNaN
- isNotNaN
- createCheckIsUnordered
- createCheckIsOrdered
- CmpIOpConversion
- needsUnsignedCmp
- toEmitCPred
- matchAndRewrite
- NegFOpConversion
- matchAndRewrite
- CastConversion
- matchAndRewrite
- UnsignedCastConversion
- SignedCastConversion
- ArithOpConversion
- matchAndRewrite
- BinaryUIOpConversion
- matchAndRewrite
- IntegerOpConversion
- matchAndRewrite
- BitwiseOpConversion
- matchAndRewrite
- ShiftOpConversion
- matchAndRewrite
- SignedShiftOpConversion
- UnsignedShiftOpConversion
- SelectOpConversion
- matchAndRewrite
- FtoICastOpConversion
- FtoICastOpConversion
- matchAndRewrite
- ItoFCastOpConversion
- ItoFCastOpConversion
- matchAndRewrite
- FpCastOpConversion
- FpCastOpConversion
- matchAndRewrite
Learn to use CMake with our Intro Training
Find out more