1 | //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert the Arith dialect to the EmitC |
10 | // dialect. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/EmitC/IR/EmitC.h" |
18 | #include "mlir/Transforms/DialectConversion.h" |
19 | |
20 | using namespace mlir; |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // Conversion Patterns |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | namespace { |
27 | class ArithConstantOpConversionPattern |
28 | : public OpConversionPattern<arith::ConstantOp> { |
29 | public: |
30 | using OpConversionPattern::OpConversionPattern; |
31 | |
32 | LogicalResult |
33 | matchAndRewrite(arith::ConstantOp arithConst, |
34 | arith::ConstantOp::Adaptor adaptor, |
35 | ConversionPatternRewriter &rewriter) const override { |
36 | rewriter.replaceOpWithNewOp<emitc::ConstantOp>( |
37 | arithConst, arithConst.getType(), adaptor.getValue()); |
38 | return success(); |
39 | } |
40 | }; |
41 | |
42 | class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> { |
43 | public: |
44 | using OpConversionPattern::OpConversionPattern; |
45 | |
46 | bool needsUnsignedCmp(arith::CmpIPredicate pred) const { |
47 | switch (pred) { |
48 | case arith::CmpIPredicate::eq: |
49 | case arith::CmpIPredicate::ne: |
50 | case arith::CmpIPredicate::slt: |
51 | case arith::CmpIPredicate::sle: |
52 | case arith::CmpIPredicate::sgt: |
53 | case arith::CmpIPredicate::sge: |
54 | return false; |
55 | case arith::CmpIPredicate::ult: |
56 | case arith::CmpIPredicate::ule: |
57 | case arith::CmpIPredicate::ugt: |
58 | case arith::CmpIPredicate::uge: |
59 | return true; |
60 | } |
61 | llvm_unreachable("unknown cmpi predicate kind" ); |
62 | } |
63 | |
64 | emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const { |
65 | switch (pred) { |
66 | case arith::CmpIPredicate::eq: |
67 | return emitc::CmpPredicate::eq; |
68 | case arith::CmpIPredicate::ne: |
69 | return emitc::CmpPredicate::ne; |
70 | case arith::CmpIPredicate::slt: |
71 | case arith::CmpIPredicate::ult: |
72 | return emitc::CmpPredicate::lt; |
73 | case arith::CmpIPredicate::sle: |
74 | case arith::CmpIPredicate::ule: |
75 | return emitc::CmpPredicate::le; |
76 | case arith::CmpIPredicate::sgt: |
77 | case arith::CmpIPredicate::ugt: |
78 | return emitc::CmpPredicate::gt; |
79 | case arith::CmpIPredicate::sge: |
80 | case arith::CmpIPredicate::uge: |
81 | return emitc::CmpPredicate::ge; |
82 | } |
83 | llvm_unreachable("unknown cmpi predicate kind" ); |
84 | } |
85 | |
86 | LogicalResult |
87 | matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
88 | ConversionPatternRewriter &rewriter) const override { |
89 | |
90 | Type type = adaptor.getLhs().getType(); |
91 | if (!isa_and_nonnull<IntegerType, IndexType>(Val: type)) { |
92 | return rewriter.notifyMatchFailure(op, "expected integer or index type" ); |
93 | } |
94 | |
95 | bool needsUnsigned = needsUnsignedCmp(op.getPredicate()); |
96 | emitc::CmpPredicate pred = toEmitCPred(op.getPredicate()); |
97 | Type arithmeticType = type; |
98 | if (type.isUnsignedInteger() != needsUnsigned) { |
99 | arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(), |
100 | /*isSigned=*/!needsUnsigned); |
101 | } |
102 | Value lhs = adaptor.getLhs(); |
103 | Value rhs = adaptor.getRhs(); |
104 | if (arithmeticType != type) { |
105 | lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType, |
106 | lhs); |
107 | rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType, |
108 | rhs); |
109 | } |
110 | rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs); |
111 | return success(); |
112 | } |
113 | }; |
114 | |
115 | template <typename ArithOp, typename EmitCOp> |
116 | class ArithOpConversion final : public OpConversionPattern<ArithOp> { |
117 | public: |
118 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
119 | |
120 | LogicalResult |
121 | matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, |
122 | ConversionPatternRewriter &rewriter) const override { |
123 | |
124 | rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(), |
125 | adaptor.getOperands()); |
126 | |
127 | return success(); |
128 | } |
129 | }; |
130 | |
131 | template <typename ArithOp, typename EmitCOp> |
132 | class IntegerOpConversion final : public OpConversionPattern<ArithOp> { |
133 | public: |
134 | using OpConversionPattern<ArithOp>::OpConversionPattern; |
135 | |
136 | LogicalResult |
137 | matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, |
138 | ConversionPatternRewriter &rewriter) const override { |
139 | |
140 | Type type = this->getTypeConverter()->convertType(op.getType()); |
141 | if (!isa_and_nonnull<IntegerType, IndexType>(Val: type)) { |
142 | return rewriter.notifyMatchFailure(op, "expected integer type" ); |
143 | } |
144 | |
145 | if (type.isInteger(width: 1)) { |
146 | // arith expects wrap-around arithmethic, which doesn't happen on `bool`. |
147 | return rewriter.notifyMatchFailure(op, "i1 type is not implemented" ); |
148 | } |
149 | |
150 | Value lhs = adaptor.getLhs(); |
151 | Value rhs = adaptor.getRhs(); |
152 | Type arithmeticType = type; |
153 | if ((type.isSignlessInteger() || type.isSignedInteger()) && |
154 | !bitEnumContainsAll(op.getOverflowFlags(), |
155 | arith::IntegerOverflowFlags::nsw)) { |
156 | // If the C type is signed and the op doesn't guarantee "No Signed Wrap", |
157 | // we compute in unsigned integers to avoid UB. |
158 | arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(), |
159 | /*isSigned=*/false); |
160 | } |
161 | if (arithmeticType != type) { |
162 | lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType, |
163 | lhs); |
164 | rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType, |
165 | rhs); |
166 | } |
167 | |
168 | Value result = rewriter.template create<EmitCOp>(op.getLoc(), |
169 | arithmeticType, lhs, rhs); |
170 | |
171 | if (arithmeticType != type) { |
172 | result = |
173 | rewriter.template create<emitc::CastOp>(op.getLoc(), type, result); |
174 | } |
175 | rewriter.replaceOp(op, result); |
176 | return success(); |
177 | } |
178 | }; |
179 | |
180 | class SelectOpConversion : public OpConversionPattern<arith::SelectOp> { |
181 | public: |
182 | using OpConversionPattern<arith::SelectOp>::OpConversionPattern; |
183 | |
184 | LogicalResult |
185 | matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor, |
186 | ConversionPatternRewriter &rewriter) const override { |
187 | |
188 | Type dstType = getTypeConverter()->convertType(selectOp.getType()); |
189 | if (!dstType) |
190 | return rewriter.notifyMatchFailure(selectOp, "type conversion failed" ); |
191 | |
192 | if (!adaptor.getCondition().getType().isInteger(1)) |
193 | return rewriter.notifyMatchFailure( |
194 | selectOp, |
195 | "can only be converted if condition is a scalar of type i1" ); |
196 | |
197 | rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType, |
198 | adaptor.getOperands()); |
199 | |
200 | return success(); |
201 | } |
202 | }; |
203 | |
204 | } // namespace |
205 | |
206 | //===----------------------------------------------------------------------===// |
207 | // Pattern population |
208 | //===----------------------------------------------------------------------===// |
209 | |
210 | void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, |
211 | RewritePatternSet &patterns) { |
212 | MLIRContext *ctx = patterns.getContext(); |
213 | |
214 | // clang-format off |
215 | patterns.add< |
216 | ArithConstantOpConversionPattern, |
217 | ArithOpConversion<arith::AddFOp, emitc::AddOp>, |
218 | ArithOpConversion<arith::DivFOp, emitc::DivOp>, |
219 | ArithOpConversion<arith::MulFOp, emitc::MulOp>, |
220 | ArithOpConversion<arith::SubFOp, emitc::SubOp>, |
221 | IntegerOpConversion<arith::AddIOp, emitc::AddOp>, |
222 | IntegerOpConversion<arith::MulIOp, emitc::MulOp>, |
223 | IntegerOpConversion<arith::SubIOp, emitc::SubOp>, |
224 | CmpIOpConversion, |
225 | SelectOpConversion |
226 | >(typeConverter, ctx); |
227 | // clang-format on |
228 | } |
229 | |