1 | //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
10 | #include "mlir/Dialect/Complex/IR/Complex.h" |
11 | #include "mlir/IR/Builders.h" |
12 | #include "mlir/IR/BuiltinTypes.h" |
13 | #include "mlir/IR/Matchers.h" |
14 | #include "mlir/IR/PatternMatch.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::complex; |
18 | |
19 | //===----------------------------------------------------------------------===// |
20 | // ConstantOp |
21 | //===----------------------------------------------------------------------===// |
22 | |
23 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { |
24 | return getValue(); |
25 | } |
26 | |
27 | void ConstantOp::getAsmResultNames( |
28 | function_ref<void(Value, StringRef)> setNameFn) { |
29 | setNameFn(getResult(), "cst" ); |
30 | } |
31 | |
32 | bool ConstantOp::isBuildableWith(Attribute value, Type type) { |
33 | if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) { |
34 | auto complexTy = llvm::dyn_cast<ComplexType>(type); |
35 | if (!complexTy || arrAttr.size() != 2) |
36 | return false; |
37 | auto complexEltTy = complexTy.getElementType(); |
38 | if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) { |
39 | auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]); |
40 | return im && fre.getType() == complexEltTy && |
41 | im.getType() == complexEltTy; |
42 | } |
43 | if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) { |
44 | auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]); |
45 | return im && ire.getType() == complexEltTy && |
46 | im.getType() == complexEltTy; |
47 | } |
48 | } |
49 | return false; |
50 | } |
51 | |
52 | LogicalResult ConstantOp::verify() { |
53 | ArrayAttr arrayAttr = getValue(); |
54 | if (arrayAttr.size() != 2) { |
55 | return emitOpError( |
56 | "requires 'value' to be a complex constant, represented as array of " |
57 | "two values" ); |
58 | } |
59 | |
60 | auto complexEltTy = getType().getElementType(); |
61 | if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) || |
62 | !isa<FloatAttr, IntegerAttr>(arrayAttr[1])) |
63 | return emitOpError( |
64 | "requires attribute's elements to be float or integer attributes" ); |
65 | auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]); |
66 | auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]); |
67 | if (complexEltTy != re.getType() || complexEltTy != im.getType()) { |
68 | return emitOpError() |
69 | << "requires attribute's element types (" << re.getType() << ", " |
70 | << im.getType() |
71 | << ") to match the element type of the op's return type (" |
72 | << complexEltTy << ")" ; |
73 | } |
74 | return success(); |
75 | } |
76 | |
77 | //===----------------------------------------------------------------------===// |
78 | // BitcastOp |
79 | //===----------------------------------------------------------------------===// |
80 | |
81 | OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) { |
82 | if (getOperand().getType() == getType()) |
83 | return getOperand(); |
84 | |
85 | return {}; |
86 | } |
87 | |
88 | LogicalResult BitcastOp::verify() { |
89 | auto operandType = getOperand().getType(); |
90 | auto resultType = getType(); |
91 | |
92 | // We allow this to be legal as it can be folded away. |
93 | if (operandType == resultType) |
94 | return success(); |
95 | |
96 | if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) { |
97 | return emitOpError("operand must be int/float/complex" ); |
98 | } |
99 | |
100 | if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) { |
101 | return emitOpError("result must be int/float/complex" ); |
102 | } |
103 | |
104 | if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) { |
105 | return emitOpError( |
106 | "requires that either input or output has a complex type" ); |
107 | } |
108 | |
109 | if (isa<ComplexType>(resultType)) |
110 | std::swap(operandType, resultType); |
111 | |
112 | int32_t operandBitwidth = dyn_cast<ComplexType>(operandType) |
113 | .getElementType() |
114 | .getIntOrFloatBitWidth() * |
115 | 2; |
116 | int32_t resultBitwidth = resultType.getIntOrFloatBitWidth(); |
117 | |
118 | if (operandBitwidth != resultBitwidth) { |
119 | return emitOpError("casting bitwidths do not match" ); |
120 | } |
121 | |
122 | return success(); |
123 | } |
124 | |
125 | struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> { |
126 | using OpRewritePattern<BitcastOp>::OpRewritePattern; |
127 | |
128 | LogicalResult matchAndRewrite(BitcastOp op, |
129 | PatternRewriter &rewriter) const override { |
130 | if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) { |
131 | if (isa<ComplexType>(op.getType()) || |
132 | isa<ComplexType>(defining.getOperand().getType())) { |
133 | // complex.bitcast requires that input or output is complex. |
134 | rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(), |
135 | defining.getOperand()); |
136 | } else { |
137 | rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(), |
138 | defining.getOperand()); |
139 | } |
140 | return success(); |
141 | } |
142 | |
143 | if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) { |
144 | rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(), |
145 | defining.getOperand()); |
146 | return success(); |
147 | } |
148 | |
149 | return failure(); |
150 | } |
151 | }; |
152 | |
153 | struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> { |
154 | using OpRewritePattern<arith::BitcastOp>::OpRewritePattern; |
155 | |
156 | LogicalResult matchAndRewrite(arith::BitcastOp op, |
157 | PatternRewriter &rewriter) const override { |
158 | if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) { |
159 | rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(), |
160 | defining.getOperand()); |
161 | return success(); |
162 | } |
163 | |
164 | return failure(); |
165 | } |
166 | }; |
167 | |
168 | void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results, |
169 | MLIRContext *context) { |
170 | results.add<MergeComplexBitcast, MergeArithBitcast>(context); |
171 | } |
172 | |
173 | //===----------------------------------------------------------------------===// |
174 | // CreateOp |
175 | //===----------------------------------------------------------------------===// |
176 | |
177 | OpFoldResult CreateOp::fold(FoldAdaptor adaptor) { |
178 | // Fold complex.create(complex.re(op), complex.im(op)). |
179 | if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) { |
180 | if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) { |
181 | if (reOp.getOperand() == imOp.getOperand()) { |
182 | return reOp.getOperand(); |
183 | } |
184 | } |
185 | } |
186 | return {}; |
187 | } |
188 | |
189 | //===----------------------------------------------------------------------===// |
190 | // ImOp |
191 | //===----------------------------------------------------------------------===// |
192 | |
193 | OpFoldResult ImOp::fold(FoldAdaptor adaptor) { |
194 | ArrayAttr arrayAttr = |
195 | llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex()); |
196 | if (arrayAttr && arrayAttr.size() == 2) |
197 | return arrayAttr[1]; |
198 | if (auto createOp = getOperand().getDefiningOp<CreateOp>()) |
199 | return createOp.getOperand(1); |
200 | return {}; |
201 | } |
202 | |
203 | namespace { |
204 | template <typename OpKind, int ComponentIndex> |
205 | struct FoldComponentNeg final : OpRewritePattern<OpKind> { |
206 | using OpRewritePattern<OpKind>::OpRewritePattern; |
207 | |
208 | LogicalResult matchAndRewrite(OpKind op, |
209 | PatternRewriter &rewriter) const override { |
210 | auto negOp = op.getOperand().template getDefiningOp<NegOp>(); |
211 | if (!negOp) |
212 | return failure(); |
213 | |
214 | auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>(); |
215 | if (!createOp) |
216 | return failure(); |
217 | |
218 | Type elementType = createOp.getType().getElementType(); |
219 | assert(isa<FloatType>(elementType)); |
220 | |
221 | rewriter.replaceOpWithNewOp<arith::NegFOp>( |
222 | op, elementType, createOp.getOperand(ComponentIndex)); |
223 | return success(); |
224 | } |
225 | }; |
226 | } // namespace |
227 | |
228 | void ImOp::getCanonicalizationPatterns(RewritePatternSet &results, |
229 | MLIRContext *context) { |
230 | results.add<FoldComponentNeg<ImOp, 1>>(context); |
231 | } |
232 | |
233 | //===----------------------------------------------------------------------===// |
234 | // ReOp |
235 | //===----------------------------------------------------------------------===// |
236 | |
237 | OpFoldResult ReOp::fold(FoldAdaptor adaptor) { |
238 | ArrayAttr arrayAttr = |
239 | llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex()); |
240 | if (arrayAttr && arrayAttr.size() == 2) |
241 | return arrayAttr[0]; |
242 | if (auto createOp = getOperand().getDefiningOp<CreateOp>()) |
243 | return createOp.getOperand(0); |
244 | return {}; |
245 | } |
246 | |
247 | void ReOp::getCanonicalizationPatterns(RewritePatternSet &results, |
248 | MLIRContext *context) { |
249 | results.add<FoldComponentNeg<ReOp, 0>>(context); |
250 | } |
251 | |
252 | //===----------------------------------------------------------------------===// |
253 | // AddOp |
254 | //===----------------------------------------------------------------------===// |
255 | |
256 | OpFoldResult AddOp::fold(FoldAdaptor adaptor) { |
257 | // complex.add(complex.sub(a, b), b) -> a |
258 | if (auto sub = getLhs().getDefiningOp<SubOp>()) |
259 | if (getRhs() == sub.getRhs()) |
260 | return sub.getLhs(); |
261 | |
262 | // complex.add(b, complex.sub(a, b)) -> a |
263 | if (auto sub = getRhs().getDefiningOp<SubOp>()) |
264 | if (getLhs() == sub.getRhs()) |
265 | return sub.getLhs(); |
266 | |
267 | // complex.add(a, complex.constant<0.0, 0.0>) -> a |
268 | if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) { |
269 | auto arrayAttr = constantOp.getValue(); |
270 | if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
271 | llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) { |
272 | return getLhs(); |
273 | } |
274 | } |
275 | |
276 | return {}; |
277 | } |
278 | |
279 | //===----------------------------------------------------------------------===// |
280 | // SubOp |
281 | //===----------------------------------------------------------------------===// |
282 | |
283 | OpFoldResult SubOp::fold(FoldAdaptor adaptor) { |
284 | // complex.sub(complex.add(a, b), b) -> a |
285 | if (auto add = getLhs().getDefiningOp<AddOp>()) |
286 | if (getRhs() == add.getRhs()) |
287 | return add.getLhs(); |
288 | |
289 | // complex.sub(a, complex.constant<0.0, 0.0>) -> a |
290 | if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) { |
291 | auto arrayAttr = constantOp.getValue(); |
292 | if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
293 | llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) { |
294 | return getLhs(); |
295 | } |
296 | } |
297 | |
298 | return {}; |
299 | } |
300 | |
301 | //===----------------------------------------------------------------------===// |
302 | // NegOp |
303 | //===----------------------------------------------------------------------===// |
304 | |
305 | OpFoldResult NegOp::fold(FoldAdaptor adaptor) { |
306 | // complex.neg(complex.neg(a)) -> a |
307 | if (auto negOp = getOperand().getDefiningOp<NegOp>()) |
308 | return negOp.getOperand(); |
309 | |
310 | return {}; |
311 | } |
312 | |
313 | //===----------------------------------------------------------------------===// |
314 | // LogOp |
315 | //===----------------------------------------------------------------------===// |
316 | |
317 | OpFoldResult LogOp::fold(FoldAdaptor adaptor) { |
318 | // complex.log(complex.exp(a)) -> a |
319 | if (auto expOp = getOperand().getDefiningOp<ExpOp>()) |
320 | return expOp.getOperand(); |
321 | |
322 | return {}; |
323 | } |
324 | |
325 | //===----------------------------------------------------------------------===// |
326 | // ExpOp |
327 | //===----------------------------------------------------------------------===// |
328 | |
329 | OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { |
330 | // complex.exp(complex.log(a)) -> a |
331 | if (auto logOp = getOperand().getDefiningOp<LogOp>()) |
332 | return logOp.getOperand(); |
333 | |
334 | return {}; |
335 | } |
336 | |
337 | //===----------------------------------------------------------------------===// |
338 | // ConjOp |
339 | //===----------------------------------------------------------------------===// |
340 | |
341 | OpFoldResult ConjOp::fold(FoldAdaptor adaptor) { |
342 | // complex.conj(complex.conj(a)) -> a |
343 | if (auto conjOp = getOperand().getDefiningOp<ConjOp>()) |
344 | return conjOp.getOperand(); |
345 | |
346 | return {}; |
347 | } |
348 | |
349 | //===----------------------------------------------------------------------===// |
350 | // MulOp |
351 | //===----------------------------------------------------------------------===// |
352 | |
353 | OpFoldResult MulOp::fold(FoldAdaptor adaptor) { |
354 | auto constant = getRhs().getDefiningOp<ConstantOp>(); |
355 | if (!constant) |
356 | return {}; |
357 | |
358 | ArrayAttr arrayAttr = constant.getValue(); |
359 | APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue(); |
360 | APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue(); |
361 | |
362 | if (!imag.isZero()) |
363 | return {}; |
364 | |
365 | // complex.mul(a, complex.constant<1.0, 0.0>) -> a |
366 | if (real == APFloat(real.getSemantics(), 1)) |
367 | return getLhs(); |
368 | |
369 | return {}; |
370 | } |
371 | |
372 | //===----------------------------------------------------------------------===// |
373 | // DivOp |
374 | //===----------------------------------------------------------------------===// |
375 | |
376 | OpFoldResult DivOp::fold(FoldAdaptor adaptor) { |
377 | auto rhs = adaptor.getRhs(); |
378 | if (!rhs) |
379 | return {}; |
380 | |
381 | ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs); |
382 | if (!arrayAttr || arrayAttr.size() != 2) |
383 | return {}; |
384 | |
385 | APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue(); |
386 | APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue(); |
387 | |
388 | if (!imag.isZero()) |
389 | return {}; |
390 | |
391 | // complex.div(a, complex.constant<1.0, 0.0>) -> a |
392 | if (real == APFloat(real.getSemantics(), 1)) |
393 | return getLhs(); |
394 | |
395 | return {}; |
396 | } |
397 | |
398 | //===----------------------------------------------------------------------===// |
399 | // TableGen'd op method definitions |
400 | //===----------------------------------------------------------------------===// |
401 | |
402 | #define GET_OP_CLASSES |
403 | #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc" |
404 | |