1//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the Arith dialect to the EmitC
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
15
16#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/EmitC/IR/EmitC.h"
19#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24using namespace mlir;
25
26namespace {
27/// Implement the interface to convert Arith to EmitC.
28struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
29 using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface;
30
31 /// Hook for derived dialect interface to provide conversion patterns
32 /// and mark dialect legal for the conversion target.
33 void populateConvertToEmitCConversionPatterns(
34 ConversionTarget &target, TypeConverter &typeConverter,
35 RewritePatternSet &patterns) const final {
36 populateArithToEmitCPatterns(typeConverter, patterns);
37 }
38};
39} // namespace
40
41void mlir::registerConvertArithToEmitCInterface(DialectRegistry &registry) {
42 registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) {
43 dialect->addInterfaces<ArithToEmitCDialectInterface>();
44 });
45}
46
47//===----------------------------------------------------------------------===//
48// Conversion Patterns
49//===----------------------------------------------------------------------===//
50
51namespace {
52class ArithConstantOpConversionPattern
53 : public OpConversionPattern<arith::ConstantOp> {
54public:
55 using OpConversionPattern::OpConversionPattern;
56
57 LogicalResult
58 matchAndRewrite(arith::ConstantOp arithConst,
59 arith::ConstantOp::Adaptor adaptor,
60 ConversionPatternRewriter &rewriter) const override {
61 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
62 if (!newTy)
63 return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
64 rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
65 adaptor.getValue());
66 return success();
67 }
68};
69
70/// Get the signed or unsigned type corresponding to \p ty.
71Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
72 if (isa<IntegerType>(Val: ty)) {
73 if (ty.isUnsignedInteger() != needsUnsigned) {
74 auto signedness = needsUnsigned
75 ? IntegerType::SignednessSemantics::Unsigned
76 : IntegerType::SignednessSemantics::Signed;
77 return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
78 signedness);
79 }
80 } else if (emitc::isPointerWideType(type: ty)) {
81 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
82 if (needsUnsigned)
83 return emitc::SizeTType::get(ty.getContext());
84 return emitc::PtrDiffTType::get(ty.getContext());
85 }
86 }
87 return ty;
88}
89
90/// Insert a cast operation to type \p ty if \p val does not have this type.
91Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
92 return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
93}
94
95class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
96public:
97 using OpConversionPattern::OpConversionPattern;
98
99 LogicalResult
100 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
101 ConversionPatternRewriter &rewriter) const override {
102
103 if (!isa<FloatType>(adaptor.getRhs().getType())) {
104 return rewriter.notifyMatchFailure(op.getLoc(),
105 "cmpf currently only supported on "
106 "floats, not tensors/vectors thereof");
107 }
108
109 bool unordered = false;
110 emitc::CmpPredicate predicate;
111 switch (op.getPredicate()) {
112 case arith::CmpFPredicate::AlwaysFalse: {
113 auto constant = rewriter.create<emitc::ConstantOp>(
114 op.getLoc(), rewriter.getI1Type(),
115 rewriter.getBoolAttr(/*value=*/false));
116 rewriter.replaceOp(op, constant);
117 return success();
118 }
119 case arith::CmpFPredicate::OEQ:
120 unordered = false;
121 predicate = emitc::CmpPredicate::eq;
122 break;
123 case arith::CmpFPredicate::OGT:
124 unordered = false;
125 predicate = emitc::CmpPredicate::gt;
126 break;
127 case arith::CmpFPredicate::OGE:
128 unordered = false;
129 predicate = emitc::CmpPredicate::ge;
130 break;
131 case arith::CmpFPredicate::OLT:
132 unordered = false;
133 predicate = emitc::CmpPredicate::lt;
134 break;
135 case arith::CmpFPredicate::OLE:
136 unordered = false;
137 predicate = emitc::CmpPredicate::le;
138 break;
139 case arith::CmpFPredicate::ONE:
140 unordered = false;
141 predicate = emitc::CmpPredicate::ne;
142 break;
143 case arith::CmpFPredicate::ORD: {
144 // ordered, i.e. none of the operands is NaN
145 auto cmp = createCheckIsOrdered(rewriter, loc: op.getLoc(), first: adaptor.getLhs(),
146 second: adaptor.getRhs());
147 rewriter.replaceOp(op, cmp);
148 return success();
149 }
150 case arith::CmpFPredicate::UEQ:
151 unordered = true;
152 predicate = emitc::CmpPredicate::eq;
153 break;
154 case arith::CmpFPredicate::UGT:
155 unordered = true;
156 predicate = emitc::CmpPredicate::gt;
157 break;
158 case arith::CmpFPredicate::UGE:
159 unordered = true;
160 predicate = emitc::CmpPredicate::ge;
161 break;
162 case arith::CmpFPredicate::ULT:
163 unordered = true;
164 predicate = emitc::CmpPredicate::lt;
165 break;
166 case arith::CmpFPredicate::ULE:
167 unordered = true;
168 predicate = emitc::CmpPredicate::le;
169 break;
170 case arith::CmpFPredicate::UNE:
171 unordered = true;
172 predicate = emitc::CmpPredicate::ne;
173 break;
174 case arith::CmpFPredicate::UNO: {
175 // unordered, i.e. either operand is nan
176 auto cmp = createCheckIsUnordered(rewriter, loc: op.getLoc(), first: adaptor.getLhs(),
177 second: adaptor.getRhs());
178 rewriter.replaceOp(op, cmp);
179 return success();
180 }
181 case arith::CmpFPredicate::AlwaysTrue: {
182 auto constant = rewriter.create<emitc::ConstantOp>(
183 op.getLoc(), rewriter.getI1Type(),
184 rewriter.getBoolAttr(/*value=*/true));
185 rewriter.replaceOp(op, constant);
186 return success();
187 }
188 }
189
190 // Compare the values naively
191 auto cmpResult =
192 rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
193 adaptor.getLhs(), adaptor.getRhs());
194
195 // Adjust the results for unordered/ordered semantics
196 if (unordered) {
197 auto isUnordered = createCheckIsUnordered(
198 rewriter, loc: op.getLoc(), first: adaptor.getLhs(), second: adaptor.getRhs());
199 rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
200 isUnordered, cmpResult);
201 return success();
202 }
203
204 auto isOrdered = createCheckIsOrdered(rewriter, loc: op.getLoc(),
205 first: adaptor.getLhs(), second: adaptor.getRhs());
206 rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
207 isOrdered, cmpResult);
208 return success();
209 }
210
211private:
212 /// Return a value that is true if \p operand is NaN.
213 Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
214 Value operand) const {
215 // A value is NaN exactly when it compares unequal to itself.
216 return rewriter.create<emitc::CmpOp>(
217 loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
218 }
219
220 /// Return a value that is true if \p operand is not NaN.
221 Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
222 Value operand) const {
223 // A value is not NaN exactly when it compares equal to itself.
224 return rewriter.create<emitc::CmpOp>(
225 loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
226 }
227
228 /// Return a value that is true if the operands \p first and \p second are
229 /// unordered (i.e., at least one of them is NaN).
230 Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
231 Location loc, Value first, Value second) const {
232 auto firstIsNaN = isNaN(rewriter, loc, operand: first);
233 auto secondIsNaN = isNaN(rewriter, loc, operand: second);
234 return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
235 firstIsNaN, secondIsNaN);
236 }
237
238 /// Return a value that is true if the operands \p first and \p second are
239 /// both ordered (i.e., none one of them is NaN).
240 Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
241 Value first, Value second) const {
242 auto firstIsNotNaN = isNotNaN(rewriter, loc, operand: first);
243 auto secondIsNotNaN = isNotNaN(rewriter, loc, operand: second);
244 return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
245 firstIsNotNaN, secondIsNotNaN);
246 }
247};
248
249class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250public:
251 using OpConversionPattern::OpConversionPattern;
252
253 bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
254 switch (pred) {
255 case arith::CmpIPredicate::eq:
256 case arith::CmpIPredicate::ne:
257 case arith::CmpIPredicate::slt:
258 case arith::CmpIPredicate::sle:
259 case arith::CmpIPredicate::sgt:
260 case arith::CmpIPredicate::sge:
261 return false;
262 case arith::CmpIPredicate::ult:
263 case arith::CmpIPredicate::ule:
264 case arith::CmpIPredicate::ugt:
265 case arith::CmpIPredicate::uge:
266 return true;
267 }
268 llvm_unreachable("unknown cmpi predicate kind");
269 }
270
271 emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
272 switch (pred) {
273 case arith::CmpIPredicate::eq:
274 return emitc::CmpPredicate::eq;
275 case arith::CmpIPredicate::ne:
276 return emitc::CmpPredicate::ne;
277 case arith::CmpIPredicate::slt:
278 case arith::CmpIPredicate::ult:
279 return emitc::CmpPredicate::lt;
280 case arith::CmpIPredicate::sle:
281 case arith::CmpIPredicate::ule:
282 return emitc::CmpPredicate::le;
283 case arith::CmpIPredicate::sgt:
284 case arith::CmpIPredicate::ugt:
285 return emitc::CmpPredicate::gt;
286 case arith::CmpIPredicate::sge:
287 case arith::CmpIPredicate::uge:
288 return emitc::CmpPredicate::ge;
289 }
290 llvm_unreachable("unknown cmpi predicate kind");
291 }
292
293 LogicalResult
294 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override {
296
297 Type type = adaptor.getLhs().getType();
298 if (!type || !(isa<IntegerType>(Val: type) || emitc::isPointerWideType(type))) {
299 return rewriter.notifyMatchFailure(
300 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
301 }
302
303 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
304 emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
305
306 Type arithmeticType = adaptIntegralTypeSignedness(ty: type, needsUnsigned);
307 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
309
310 rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
311 return success();
312 }
313};
314
315class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
316public:
317 using OpConversionPattern::OpConversionPattern;
318
319 LogicalResult
320 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter) const override {
322
323 auto adaptedOp = adaptor.getOperand();
324 auto adaptedOpType = adaptedOp.getType();
325
326 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
327 return rewriter.notifyMatchFailure(
328 op.getLoc(),
329 "negf currently only supports scalar types, not vectors or tensors");
330 }
331
332 if (!emitc::isSupportedFloatType(type: adaptedOpType)) {
333 return rewriter.notifyMatchFailure(
334 op.getLoc(), "floating-point type is not supported by EmitC");
335 }
336
337 rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
338 adaptedOp);
339 return success();
340 }
341};
342
343template <typename ArithOp, bool castToUnsigned>
344class CastConversion : public OpConversionPattern<ArithOp> {
345public:
346 using OpConversionPattern<ArithOp>::OpConversionPattern;
347
348 LogicalResult
349 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351
352 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353 if (!opReturnType || !(isa<IntegerType>(Val: opReturnType) ||
354 emitc::isPointerWideType(type: opReturnType)))
355 return rewriter.notifyMatchFailure(
356 op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
357
358 if (adaptor.getOperands().size() != 1) {
359 return rewriter.notifyMatchFailure(
360 op, "CastConversion only supports unary ops");
361 }
362
363 Type operandType = adaptor.getIn().getType();
364 if (!operandType || !(isa<IntegerType>(Val: operandType) ||
365 emitc::isPointerWideType(type: operandType)))
366 return rewriter.notifyMatchFailure(
367 op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
368
369 // Signed (sign-extending) casts from i1 are not supported.
370 if (operandType.isInteger(width: 1) && !castToUnsigned)
371 return rewriter.notifyMatchFailure(op,
372 "operation not supported on i1 type");
373
374 // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
375 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
376 // truncation.
377 if (opReturnType.isInteger(width: 1)) {
378 Type attrType = (emitc::isPointerWideType(type: operandType))
379 ? rewriter.getIndexType()
380 : operandType;
381 auto constOne = rewriter.create<emitc::ConstantOp>(
382 op.getLoc(), operandType, rewriter.getOneAttr(attrType));
383 auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
384 op.getLoc(), operandType, adaptor.getIn(), constOne);
385 rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
386 oneAndOperand);
387 return success();
388 }
389
390 bool isTruncation =
391 (isa<IntegerType>(Val: operandType) && isa<IntegerType>(Val: opReturnType) &&
392 operandType.getIntOrFloatBitWidth() >
393 opReturnType.getIntOrFloatBitWidth());
394 bool doUnsigned = castToUnsigned || isTruncation;
395
396 // Adapt the signedness of the result (bitwidth-preserving cast)
397 // This is needed e.g., if the return type is signless.
398 Type castDestType = adaptIntegralTypeSignedness(ty: opReturnType, needsUnsigned: doUnsigned);
399
400 // Adapt the signedness of the operand (bitwidth-preserving cast)
401 Type castSrcType = adaptIntegralTypeSignedness(ty: operandType, needsUnsigned: doUnsigned);
402 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
403
404 // Actual cast (may change bitwidth)
405 auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
406 castDestType, actualOp);
407
408 // Cast to the expected output type
409 auto result = adaptValueType(cast, rewriter, opReturnType);
410
411 rewriter.replaceOp(op, result);
412 return success();
413 }
414};
415
416template <typename ArithOp>
417class UnsignedCastConversion : public CastConversion<ArithOp, true> {
418 using CastConversion<ArithOp, true>::CastConversion;
419};
420
421template <typename ArithOp>
422class SignedCastConversion : public CastConversion<ArithOp, false> {
423 using CastConversion<ArithOp, false>::CastConversion;
424};
425
426template <typename ArithOp, typename EmitCOp>
427class ArithOpConversion final : public OpConversionPattern<ArithOp> {
428public:
429 using OpConversionPattern<ArithOp>::OpConversionPattern;
430
431 LogicalResult
432 matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
433 ConversionPatternRewriter &rewriter) const override {
434
435 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
436 if (!newTy)
437 return rewriter.notifyMatchFailure(arithOp,
438 "converting result type failed");
439 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440 adaptor.getOperands());
441
442 return success();
443 }
444};
445
446template <class ArithOp, class EmitCOp>
447class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
448public:
449 using OpConversionPattern<ArithOp>::OpConversionPattern;
450
451 LogicalResult
452 matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
453 ConversionPatternRewriter &rewriter) const override {
454 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
455 if (!newRetTy)
456 return rewriter.notifyMatchFailure(uiBinOp,
457 "converting result type failed");
458 if (!isa<IntegerType>(Val: newRetTy)) {
459 return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
460 }
461 Type unsignedType =
462 adaptIntegralTypeSignedness(ty: newRetTy, /*needsUnsigned=*/true);
463 if (!unsignedType)
464 return rewriter.notifyMatchFailure(uiBinOp,
465 "converting result type failed");
466 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
468
469 auto newDivOp =
470 rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
471 ArrayRef<Value>{lhsAdapted, rhsAdapted});
472 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
473 rewriter.replaceOp(uiBinOp, resultAdapted);
474 return success();
475 }
476};
477
478template <typename ArithOp, typename EmitCOp>
479class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
480public:
481 using OpConversionPattern<ArithOp>::OpConversionPattern;
482
483 LogicalResult
484 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486
487 Type type = this->getTypeConverter()->convertType(op.getType());
488 if (!type || !(isa<IntegerType>(Val: type) || emitc::isPointerWideType(type))) {
489 return rewriter.notifyMatchFailure(
490 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
491 }
492
493 if (type.isInteger(width: 1)) {
494 // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
495 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
496 }
497
498 Type arithmeticType = type;
499 if ((type.isSignlessInteger() || type.isSignedInteger()) &&
500 !bitEnumContainsAll(op.getOverflowFlags(),
501 arith::IntegerOverflowFlags::nsw)) {
502 // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
503 // we compute in unsigned integers to avoid UB.
504 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
505 /*isSigned=*/false);
506 }
507
508 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
509 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
510
511 Value arithmeticResult = rewriter.template create<EmitCOp>(
512 op.getLoc(), arithmeticType, lhs, rhs);
513
514 Value result = adaptValueType(val: arithmeticResult, rewriter, ty: type);
515
516 rewriter.replaceOp(op, result);
517 return success();
518 }
519};
520
521template <typename ArithOp, typename EmitCOp>
522class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
523public:
524 using OpConversionPattern<ArithOp>::OpConversionPattern;
525
526 LogicalResult
527 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529
530 Type type = this->getTypeConverter()->convertType(op.getType());
531 if (!isa_and_nonnull<IntegerType>(Val: type)) {
532 return rewriter.notifyMatchFailure(
533 op,
534 "expected integer type, vector/tensor support not yet implemented");
535 }
536
537 // Bitwise ops can be performed directly on booleans
538 if (type.isInteger(width: 1)) {
539 rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
540 adaptor.getRhs());
541 return success();
542 }
543
544 // Bitwise ops are defined by the C standard on unsigned operands.
545 Type arithmeticType =
546 adaptIntegralTypeSignedness(ty: type, /*needsUnsigned=*/true);
547
548 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
549 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
550
551 Value arithmeticResult = rewriter.template create<EmitCOp>(
552 op.getLoc(), arithmeticType, lhs, rhs);
553
554 Value result = adaptValueType(val: arithmeticResult, rewriter, ty: type);
555
556 rewriter.replaceOp(op, result);
557 return success();
558 }
559};
560
561template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
562class ShiftOpConversion : public OpConversionPattern<ArithOp> {
563public:
564 using OpConversionPattern<ArithOp>::OpConversionPattern;
565
566 LogicalResult
567 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
568 ConversionPatternRewriter &rewriter) const override {
569
570 Type type = this->getTypeConverter()->convertType(op.getType());
571 if (!type || !(isa<IntegerType>(Val: type) || emitc::isPointerWideType(type))) {
572 return rewriter.notifyMatchFailure(
573 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
574 }
575
576 if (type.isInteger(width: 1)) {
577 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
578 }
579
580 Type arithmeticType = adaptIntegralTypeSignedness(ty: type, needsUnsigned: isUnsignedOp);
581
582 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
583 // Shift amount interpreted as unsigned per Arith dialect spec.
584 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
585 /*needsUnsigned=*/true);
586 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
587
588 // Add a runtime check for overflow
589 Value width;
590 if (emitc::isPointerWideType(type)) {
591 Value eight = rewriter.create<emitc::ConstantOp>(
592 op.getLoc(), rhsType, rewriter.getIndexAttr(8));
593 emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
594 op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
595 width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
596 sizeOfCall.getResult(0));
597 } else {
598 width = rewriter.create<emitc::ConstantOp>(
599 op.getLoc(), rhsType,
600 rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
601 }
602
603 Value excessCheck = rewriter.create<emitc::CmpOp>(
604 op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
605
606 // Any concrete value is a valid refinement of poison.
607 Value poison = rewriter.create<emitc::ConstantOp>(
608 op.getLoc(), arithmeticType,
609 (isa<IntegerType>(arithmeticType)
610 ? rewriter.getIntegerAttr(arithmeticType, 0)
611 : rewriter.getIndexAttr(0)));
612
613 emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
614 op.getLoc(), arithmeticType, /*do_not_inline=*/false);
615 Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
616 auto currentPoint = rewriter.getInsertionPoint();
617 rewriter.setInsertionPointToStart(&bodyBlock);
618 Value arithmeticResult =
619 rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
620 Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
621 op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
622 rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
623 rewriter.setInsertionPoint(op->getBlock(), currentPoint);
624
625 Value result = adaptValueType(ternary, rewriter, type);
626
627 rewriter.replaceOp(op, result);
628 return success();
629 }
630};
631
632template <typename ArithOp, typename EmitCOp>
633class SignedShiftOpConversion final
634 : public ShiftOpConversion<ArithOp, EmitCOp, false> {
635 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
636};
637
638template <typename ArithOp, typename EmitCOp>
639class UnsignedShiftOpConversion final
640 : public ShiftOpConversion<ArithOp, EmitCOp, true> {
641 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
642};
643
644class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
645public:
646 using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
647
648 LogicalResult
649 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
650 ConversionPatternRewriter &rewriter) const override {
651
652 Type dstType = getTypeConverter()->convertType(selectOp.getType());
653 if (!dstType)
654 return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
655
656 if (!adaptor.getCondition().getType().isInteger(1))
657 return rewriter.notifyMatchFailure(
658 selectOp,
659 "can only be converted if condition is a scalar of type i1");
660
661 rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
662 adaptor.getOperands());
663
664 return success();
665 }
666};
667
668// Floating-point to integer conversions.
669template <typename CastOp>
670class FtoICastOpConversion : public OpConversionPattern<CastOp> {
671public:
672 FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
673 : OpConversionPattern<CastOp>(typeConverter, context) {}
674
675 LogicalResult
676 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
677 ConversionPatternRewriter &rewriter) const override {
678
679 Type operandType = adaptor.getIn().getType();
680 if (!emitc::isSupportedFloatType(type: operandType))
681 return rewriter.notifyMatchFailure(castOp,
682 "unsupported cast source type");
683
684 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
685 if (!dstType)
686 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
687
688 // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
689 // truncated to 0, whereas a boolean conversion would return true.
690 if (!emitc::isSupportedIntegerType(type: dstType) || dstType.isInteger(width: 1))
691 return rewriter.notifyMatchFailure(castOp,
692 "unsupported cast destination type");
693
694 // Convert to unsigned if it's the "ui" variant
695 // Signless is interpreted as signed, so no need to cast for "si"
696 Type actualResultType = dstType;
697 if (isa<arith::FPToUIOp>(castOp)) {
698 actualResultType =
699 rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
700 /*isSigned=*/false);
701 }
702
703 Value result = rewriter.create<emitc::CastOp>(
704 castOp.getLoc(), actualResultType, adaptor.getOperands());
705
706 if (isa<arith::FPToUIOp>(castOp)) {
707 result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
708 }
709 rewriter.replaceOp(castOp, result);
710
711 return success();
712 }
713};
714
715// Integer to floating-point conversions.
716template <typename CastOp>
717class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
718public:
719 ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
720 : OpConversionPattern<CastOp>(typeConverter, context) {}
721
722 LogicalResult
723 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
724 ConversionPatternRewriter &rewriter) const override {
725 // Vectors in particular are not supported
726 Type operandType = adaptor.getIn().getType();
727 if (!emitc::isSupportedIntegerType(type: operandType))
728 return rewriter.notifyMatchFailure(castOp,
729 "unsupported cast source type");
730
731 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
732 if (!dstType)
733 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
734
735 if (!emitc::isSupportedFloatType(type: dstType))
736 return rewriter.notifyMatchFailure(castOp,
737 "unsupported cast destination type");
738
739 // Convert to unsigned if it's the "ui" variant
740 // Signless is interpreted as signed, so no need to cast for "si"
741 Type actualOperandType = operandType;
742 if (isa<arith::UIToFPOp>(castOp)) {
743 actualOperandType =
744 rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
745 /*isSigned=*/false);
746 }
747 Value fpCastOperand = adaptor.getIn();
748 if (actualOperandType != operandType) {
749 fpCastOperand = rewriter.template create<emitc::CastOp>(
750 castOp.getLoc(), actualOperandType, fpCastOperand);
751 }
752 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
753
754 return success();
755 }
756};
757
758// Floating-point to floating-point conversions.
759template <typename CastOp>
760class FpCastOpConversion : public OpConversionPattern<CastOp> {
761public:
762 FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
763 : OpConversionPattern<CastOp>(typeConverter, context) {}
764
765 LogicalResult
766 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
767 ConversionPatternRewriter &rewriter) const override {
768 // Vectors in particular are not supported.
769 Type operandType = adaptor.getIn().getType();
770 if (!emitc::isSupportedFloatType(type: operandType))
771 return rewriter.notifyMatchFailure(castOp,
772 "unsupported cast source type");
773 if (auto roundingModeOp =
774 dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
775 // Only supporting default rounding mode as of now.
776 if (roundingModeOp.getRoundingModeAttr())
777 return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
778 }
779
780 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
781 if (!dstType)
782 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
783
784 if (!emitc::isSupportedFloatType(type: dstType))
785 return rewriter.notifyMatchFailure(castOp,
786 "unsupported cast destination type");
787
788 Value fpCastOperand = adaptor.getIn();
789 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
790
791 return success();
792 }
793};
794
795} // namespace
796
797//===----------------------------------------------------------------------===//
798// Pattern population
799//===----------------------------------------------------------------------===//
800
801void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
802 RewritePatternSet &patterns) {
803 MLIRContext *ctx = patterns.getContext();
804
805 mlir::populateEmitCSizeTTypeConversions(converter&: typeConverter);
806
807 // clang-format off
808 patterns.add<
809 ArithConstantOpConversionPattern,
810 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
811 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
812 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
813 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
814 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
815 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
816 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
817 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
818 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
819 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
820 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
821 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
822 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
823 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
824 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
825 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
826 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
827 CmpFOpConversion,
828 CmpIOpConversion,
829 NegFOpConversion,
830 SelectOpConversion,
831 // Truncation is guaranteed for unsigned types.
832 UnsignedCastConversion<arith::TruncIOp>,
833 SignedCastConversion<arith::ExtSIOp>,
834 UnsignedCastConversion<arith::ExtUIOp>,
835 SignedCastConversion<arith::IndexCastOp>,
836 UnsignedCastConversion<arith::IndexCastUIOp>,
837 ItoFCastOpConversion<arith::SIToFPOp>,
838 ItoFCastOpConversion<arith::UIToFPOp>,
839 FtoICastOpConversion<arith::FPToSIOp>,
840 FtoICastOpConversion<arith::FPToUIOp>,
841 FpCastOpConversion<arith::ExtFOp>,
842 FpCastOpConversion<arith::TruncFOp>
843 >(typeConverter, ctx);
844 // clang-format on
845}
846

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp