1//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the Arith dialect to the EmitC
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/EmitC/IR/EmitC.h"
18#include "mlir/Transforms/DialectConversion.h"
19
20using namespace mlir;
21
22//===----------------------------------------------------------------------===//
23// Conversion Patterns
24//===----------------------------------------------------------------------===//
25
26namespace {
27class ArithConstantOpConversionPattern
28 : public OpConversionPattern<arith::ConstantOp> {
29public:
30 using OpConversionPattern::OpConversionPattern;
31
32 LogicalResult
33 matchAndRewrite(arith::ConstantOp arithConst,
34 arith::ConstantOp::Adaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
37 arithConst, arithConst.getType(), adaptor.getValue());
38 return success();
39 }
40};
41
42class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
43public:
44 using OpConversionPattern::OpConversionPattern;
45
46 bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
47 switch (pred) {
48 case arith::CmpIPredicate::eq:
49 case arith::CmpIPredicate::ne:
50 case arith::CmpIPredicate::slt:
51 case arith::CmpIPredicate::sle:
52 case arith::CmpIPredicate::sgt:
53 case arith::CmpIPredicate::sge:
54 return false;
55 case arith::CmpIPredicate::ult:
56 case arith::CmpIPredicate::ule:
57 case arith::CmpIPredicate::ugt:
58 case arith::CmpIPredicate::uge:
59 return true;
60 }
61 llvm_unreachable("unknown cmpi predicate kind");
62 }
63
64 emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
65 switch (pred) {
66 case arith::CmpIPredicate::eq:
67 return emitc::CmpPredicate::eq;
68 case arith::CmpIPredicate::ne:
69 return emitc::CmpPredicate::ne;
70 case arith::CmpIPredicate::slt:
71 case arith::CmpIPredicate::ult:
72 return emitc::CmpPredicate::lt;
73 case arith::CmpIPredicate::sle:
74 case arith::CmpIPredicate::ule:
75 return emitc::CmpPredicate::le;
76 case arith::CmpIPredicate::sgt:
77 case arith::CmpIPredicate::ugt:
78 return emitc::CmpPredicate::gt;
79 case arith::CmpIPredicate::sge:
80 case arith::CmpIPredicate::uge:
81 return emitc::CmpPredicate::ge;
82 }
83 llvm_unreachable("unknown cmpi predicate kind");
84 }
85
86 LogicalResult
87 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override {
89
90 Type type = adaptor.getLhs().getType();
91 if (!isa_and_nonnull<IntegerType, IndexType>(Val: type)) {
92 return rewriter.notifyMatchFailure(op, "expected integer or index type");
93 }
94
95 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
96 emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
97 Type arithmeticType = type;
98 if (type.isUnsignedInteger() != needsUnsigned) {
99 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
100 /*isSigned=*/!needsUnsigned);
101 }
102 Value lhs = adaptor.getLhs();
103 Value rhs = adaptor.getRhs();
104 if (arithmeticType != type) {
105 lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
106 lhs);
107 rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
108 rhs);
109 }
110 rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
111 return success();
112 }
113};
114
115template <typename ArithOp, typename EmitCOp>
116class ArithOpConversion final : public OpConversionPattern<ArithOp> {
117public:
118 using OpConversionPattern<ArithOp>::OpConversionPattern;
119
120 LogicalResult
121 matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
122 ConversionPatternRewriter &rewriter) const override {
123
124 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
125 adaptor.getOperands());
126
127 return success();
128 }
129};
130
131template <typename ArithOp, typename EmitCOp>
132class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
133public:
134 using OpConversionPattern<ArithOp>::OpConversionPattern;
135
136 LogicalResult
137 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
138 ConversionPatternRewriter &rewriter) const override {
139
140 Type type = this->getTypeConverter()->convertType(op.getType());
141 if (!isa_and_nonnull<IntegerType, IndexType>(Val: type)) {
142 return rewriter.notifyMatchFailure(op, "expected integer type");
143 }
144
145 if (type.isInteger(width: 1)) {
146 // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
147 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
148 }
149
150 Value lhs = adaptor.getLhs();
151 Value rhs = adaptor.getRhs();
152 Type arithmeticType = type;
153 if ((type.isSignlessInteger() || type.isSignedInteger()) &&
154 !bitEnumContainsAll(op.getOverflowFlags(),
155 arith::IntegerOverflowFlags::nsw)) {
156 // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
157 // we compute in unsigned integers to avoid UB.
158 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
159 /*isSigned=*/false);
160 }
161 if (arithmeticType != type) {
162 lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
163 lhs);
164 rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
165 rhs);
166 }
167
168 Value result = rewriter.template create<EmitCOp>(op.getLoc(),
169 arithmeticType, lhs, rhs);
170
171 if (arithmeticType != type) {
172 result =
173 rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
174 }
175 rewriter.replaceOp(op, result);
176 return success();
177 }
178};
179
180class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
181public:
182 using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
183
184 LogicalResult
185 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter) const override {
187
188 Type dstType = getTypeConverter()->convertType(selectOp.getType());
189 if (!dstType)
190 return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
191
192 if (!adaptor.getCondition().getType().isInteger(1))
193 return rewriter.notifyMatchFailure(
194 selectOp,
195 "can only be converted if condition is a scalar of type i1");
196
197 rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
198 adaptor.getOperands());
199
200 return success();
201 }
202};
203
204} // namespace
205
206//===----------------------------------------------------------------------===//
207// Pattern population
208//===----------------------------------------------------------------------===//
209
210void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
211 RewritePatternSet &patterns) {
212 MLIRContext *ctx = patterns.getContext();
213
214 // clang-format off
215 patterns.add<
216 ArithConstantOpConversionPattern,
217 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
218 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
219 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
220 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
221 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
222 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
223 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
224 CmpIOpConversion,
225 SelectOpConversion
226 >(typeConverter, ctx);
227 // clang-format on
228}
229

source code of mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp