| 1 | //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 10 | #include "mlir/Dialect/Complex/IR/Complex.h" |
| 11 | #include "mlir/IR/Builders.h" |
| 12 | #include "mlir/IR/BuiltinTypes.h" |
| 13 | #include "mlir/IR/Matchers.h" |
| 14 | #include "mlir/IR/PatternMatch.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | using namespace mlir::complex; |
| 18 | |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | // ConstantOp |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | |
| 23 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { |
| 24 | return getValue(); |
| 25 | } |
| 26 | |
| 27 | void ConstantOp::getAsmResultNames( |
| 28 | function_ref<void(Value, StringRef)> setNameFn) { |
| 29 | setNameFn(getResult(), "cst" ); |
| 30 | } |
| 31 | |
| 32 | bool ConstantOp::isBuildableWith(Attribute value, Type type) { |
| 33 | if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) { |
| 34 | auto complexTy = llvm::dyn_cast<ComplexType>(type); |
| 35 | if (!complexTy || arrAttr.size() != 2) |
| 36 | return false; |
| 37 | auto complexEltTy = complexTy.getElementType(); |
| 38 | if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) { |
| 39 | auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]); |
| 40 | return im && fre.getType() == complexEltTy && |
| 41 | im.getType() == complexEltTy; |
| 42 | } |
| 43 | if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) { |
| 44 | auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]); |
| 45 | return im && ire.getType() == complexEltTy && |
| 46 | im.getType() == complexEltTy; |
| 47 | } |
| 48 | } |
| 49 | return false; |
| 50 | } |
| 51 | |
| 52 | LogicalResult ConstantOp::verify() { |
| 53 | ArrayAttr arrayAttr = getValue(); |
| 54 | if (arrayAttr.size() != 2) { |
| 55 | return emitOpError( |
| 56 | "requires 'value' to be a complex constant, represented as array of " |
| 57 | "two values" ); |
| 58 | } |
| 59 | |
| 60 | auto complexEltTy = getType().getElementType(); |
| 61 | if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) || |
| 62 | !isa<FloatAttr, IntegerAttr>(arrayAttr[1])) |
| 63 | return emitOpError( |
| 64 | "requires attribute's elements to be float or integer attributes" ); |
| 65 | auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]); |
| 66 | auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]); |
| 67 | if (complexEltTy != re.getType() || complexEltTy != im.getType()) { |
| 68 | return emitOpError() |
| 69 | << "requires attribute's element types (" << re.getType() << ", " |
| 70 | << im.getType() |
| 71 | << ") to match the element type of the op's return type (" |
| 72 | << complexEltTy << ")" ; |
| 73 | } |
| 74 | return success(); |
| 75 | } |
| 76 | |
| 77 | //===----------------------------------------------------------------------===// |
| 78 | // BitcastOp |
| 79 | //===----------------------------------------------------------------------===// |
| 80 | |
| 81 | OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) { |
| 82 | if (getOperand().getType() == getType()) |
| 83 | return getOperand(); |
| 84 | |
| 85 | return {}; |
| 86 | } |
| 87 | |
| 88 | LogicalResult BitcastOp::verify() { |
| 89 | auto operandType = getOperand().getType(); |
| 90 | auto resultType = getType(); |
| 91 | |
| 92 | // We allow this to be legal as it can be folded away. |
| 93 | if (operandType == resultType) |
| 94 | return success(); |
| 95 | |
| 96 | if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) { |
| 97 | return emitOpError("operand must be int/float/complex" ); |
| 98 | } |
| 99 | |
| 100 | if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) { |
| 101 | return emitOpError("result must be int/float/complex" ); |
| 102 | } |
| 103 | |
| 104 | if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) { |
| 105 | return emitOpError( |
| 106 | "requires that either input or output has a complex type" ); |
| 107 | } |
| 108 | |
| 109 | if (isa<ComplexType>(resultType)) |
| 110 | std::swap(operandType, resultType); |
| 111 | |
| 112 | int32_t operandBitwidth = dyn_cast<ComplexType>(operandType) |
| 113 | .getElementType() |
| 114 | .getIntOrFloatBitWidth() * |
| 115 | 2; |
| 116 | int32_t resultBitwidth = resultType.getIntOrFloatBitWidth(); |
| 117 | |
| 118 | if (operandBitwidth != resultBitwidth) { |
| 119 | return emitOpError("casting bitwidths do not match" ); |
| 120 | } |
| 121 | |
| 122 | return success(); |
| 123 | } |
| 124 | |
| 125 | struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> { |
| 126 | using OpRewritePattern<BitcastOp>::OpRewritePattern; |
| 127 | |
| 128 | LogicalResult matchAndRewrite(BitcastOp op, |
| 129 | PatternRewriter &rewriter) const override { |
| 130 | if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) { |
| 131 | if (isa<ComplexType>(op.getType()) || |
| 132 | isa<ComplexType>(defining.getOperand().getType())) { |
| 133 | // complex.bitcast requires that input or output is complex. |
| 134 | rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(), |
| 135 | defining.getOperand()); |
| 136 | } else { |
| 137 | rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(), |
| 138 | defining.getOperand()); |
| 139 | } |
| 140 | return success(); |
| 141 | } |
| 142 | |
| 143 | if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) { |
| 144 | rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(), |
| 145 | defining.getOperand()); |
| 146 | return success(); |
| 147 | } |
| 148 | |
| 149 | return failure(); |
| 150 | } |
| 151 | }; |
| 152 | |
| 153 | struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> { |
| 154 | using OpRewritePattern<arith::BitcastOp>::OpRewritePattern; |
| 155 | |
| 156 | LogicalResult matchAndRewrite(arith::BitcastOp op, |
| 157 | PatternRewriter &rewriter) const override { |
| 158 | if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) { |
| 159 | rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(), |
| 160 | defining.getOperand()); |
| 161 | return success(); |
| 162 | } |
| 163 | |
| 164 | return failure(); |
| 165 | } |
| 166 | }; |
| 167 | |
| 168 | void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 169 | MLIRContext *context) { |
| 170 | results.add<MergeComplexBitcast, MergeArithBitcast>(context); |
| 171 | } |
| 172 | |
| 173 | //===----------------------------------------------------------------------===// |
| 174 | // CreateOp |
| 175 | //===----------------------------------------------------------------------===// |
| 176 | |
| 177 | OpFoldResult CreateOp::fold(FoldAdaptor adaptor) { |
| 178 | // Fold complex.create(complex.re(op), complex.im(op)). |
| 179 | if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) { |
| 180 | if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) { |
| 181 | if (reOp.getOperand() == imOp.getOperand()) { |
| 182 | return reOp.getOperand(); |
| 183 | } |
| 184 | } |
| 185 | } |
| 186 | return {}; |
| 187 | } |
| 188 | |
| 189 | //===----------------------------------------------------------------------===// |
| 190 | // ImOp |
| 191 | //===----------------------------------------------------------------------===// |
| 192 | |
| 193 | OpFoldResult ImOp::fold(FoldAdaptor adaptor) { |
| 194 | ArrayAttr arrayAttr = |
| 195 | llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex()); |
| 196 | if (arrayAttr && arrayAttr.size() == 2) |
| 197 | return arrayAttr[1]; |
| 198 | if (auto createOp = getOperand().getDefiningOp<CreateOp>()) |
| 199 | return createOp.getOperand(1); |
| 200 | return {}; |
| 201 | } |
| 202 | |
| 203 | namespace { |
| 204 | template <typename OpKind, int ComponentIndex> |
| 205 | struct FoldComponentNeg final : OpRewritePattern<OpKind> { |
| 206 | using OpRewritePattern<OpKind>::OpRewritePattern; |
| 207 | |
| 208 | LogicalResult matchAndRewrite(OpKind op, |
| 209 | PatternRewriter &rewriter) const override { |
| 210 | auto negOp = op.getOperand().template getDefiningOp<NegOp>(); |
| 211 | if (!negOp) |
| 212 | return failure(); |
| 213 | |
| 214 | auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>(); |
| 215 | if (!createOp) |
| 216 | return failure(); |
| 217 | |
| 218 | Type elementType = createOp.getType().getElementType(); |
| 219 | assert(isa<FloatType>(elementType)); |
| 220 | |
| 221 | rewriter.replaceOpWithNewOp<arith::NegFOp>( |
| 222 | op, elementType, createOp.getOperand(ComponentIndex)); |
| 223 | return success(); |
| 224 | } |
| 225 | }; |
| 226 | } // namespace |
| 227 | |
| 228 | void ImOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 229 | MLIRContext *context) { |
| 230 | results.add<FoldComponentNeg<ImOp, 1>>(context); |
| 231 | } |
| 232 | |
| 233 | //===----------------------------------------------------------------------===// |
| 234 | // ReOp |
| 235 | //===----------------------------------------------------------------------===// |
| 236 | |
| 237 | OpFoldResult ReOp::fold(FoldAdaptor adaptor) { |
| 238 | ArrayAttr arrayAttr = |
| 239 | llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex()); |
| 240 | if (arrayAttr && arrayAttr.size() == 2) |
| 241 | return arrayAttr[0]; |
| 242 | if (auto createOp = getOperand().getDefiningOp<CreateOp>()) |
| 243 | return createOp.getOperand(0); |
| 244 | return {}; |
| 245 | } |
| 246 | |
| 247 | void ReOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| 248 | MLIRContext *context) { |
| 249 | results.add<FoldComponentNeg<ReOp, 0>>(context); |
| 250 | } |
| 251 | |
| 252 | //===----------------------------------------------------------------------===// |
| 253 | // AddOp |
| 254 | //===----------------------------------------------------------------------===// |
| 255 | |
| 256 | OpFoldResult AddOp::fold(FoldAdaptor adaptor) { |
| 257 | // complex.add(complex.sub(a, b), b) -> a |
| 258 | if (auto sub = getLhs().getDefiningOp<SubOp>()) |
| 259 | if (getRhs() == sub.getRhs()) |
| 260 | return sub.getLhs(); |
| 261 | |
| 262 | // complex.add(b, complex.sub(a, b)) -> a |
| 263 | if (auto sub = getRhs().getDefiningOp<SubOp>()) |
| 264 | if (getLhs() == sub.getRhs()) |
| 265 | return sub.getLhs(); |
| 266 | |
| 267 | // complex.add(a, complex.constant<0.0, 0.0>) -> a |
| 268 | if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) { |
| 269 | auto arrayAttr = constantOp.getValue(); |
| 270 | if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
| 271 | llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) { |
| 272 | return getLhs(); |
| 273 | } |
| 274 | } |
| 275 | |
| 276 | return {}; |
| 277 | } |
| 278 | |
| 279 | //===----------------------------------------------------------------------===// |
| 280 | // SubOp |
| 281 | //===----------------------------------------------------------------------===// |
| 282 | |
| 283 | OpFoldResult SubOp::fold(FoldAdaptor adaptor) { |
| 284 | // complex.sub(complex.add(a, b), b) -> a |
| 285 | if (auto add = getLhs().getDefiningOp<AddOp>()) |
| 286 | if (getRhs() == add.getRhs()) |
| 287 | return add.getLhs(); |
| 288 | |
| 289 | // complex.sub(a, complex.constant<0.0, 0.0>) -> a |
| 290 | if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) { |
| 291 | auto arrayAttr = constantOp.getValue(); |
| 292 | if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() && |
| 293 | llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) { |
| 294 | return getLhs(); |
| 295 | } |
| 296 | } |
| 297 | |
| 298 | return {}; |
| 299 | } |
| 300 | |
| 301 | //===----------------------------------------------------------------------===// |
| 302 | // NegOp |
| 303 | //===----------------------------------------------------------------------===// |
| 304 | |
| 305 | OpFoldResult NegOp::fold(FoldAdaptor adaptor) { |
| 306 | // complex.neg(complex.neg(a)) -> a |
| 307 | if (auto negOp = getOperand().getDefiningOp<NegOp>()) |
| 308 | return negOp.getOperand(); |
| 309 | |
| 310 | return {}; |
| 311 | } |
| 312 | |
| 313 | //===----------------------------------------------------------------------===// |
| 314 | // LogOp |
| 315 | //===----------------------------------------------------------------------===// |
| 316 | |
| 317 | OpFoldResult LogOp::fold(FoldAdaptor adaptor) { |
| 318 | // complex.log(complex.exp(a)) -> a |
| 319 | if (auto expOp = getOperand().getDefiningOp<ExpOp>()) |
| 320 | return expOp.getOperand(); |
| 321 | |
| 322 | return {}; |
| 323 | } |
| 324 | |
| 325 | //===----------------------------------------------------------------------===// |
| 326 | // ExpOp |
| 327 | //===----------------------------------------------------------------------===// |
| 328 | |
| 329 | OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { |
| 330 | // complex.exp(complex.log(a)) -> a |
| 331 | if (auto logOp = getOperand().getDefiningOp<LogOp>()) |
| 332 | return logOp.getOperand(); |
| 333 | |
| 334 | return {}; |
| 335 | } |
| 336 | |
| 337 | //===----------------------------------------------------------------------===// |
| 338 | // ConjOp |
| 339 | //===----------------------------------------------------------------------===// |
| 340 | |
| 341 | OpFoldResult ConjOp::fold(FoldAdaptor adaptor) { |
| 342 | // complex.conj(complex.conj(a)) -> a |
| 343 | if (auto conjOp = getOperand().getDefiningOp<ConjOp>()) |
| 344 | return conjOp.getOperand(); |
| 345 | |
| 346 | return {}; |
| 347 | } |
| 348 | |
| 349 | //===----------------------------------------------------------------------===// |
| 350 | // MulOp |
| 351 | //===----------------------------------------------------------------------===// |
| 352 | |
| 353 | OpFoldResult MulOp::fold(FoldAdaptor adaptor) { |
| 354 | auto constant = getRhs().getDefiningOp<ConstantOp>(); |
| 355 | if (!constant) |
| 356 | return {}; |
| 357 | |
| 358 | ArrayAttr arrayAttr = constant.getValue(); |
| 359 | APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue(); |
| 360 | APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue(); |
| 361 | |
| 362 | if (!imag.isZero()) |
| 363 | return {}; |
| 364 | |
| 365 | // complex.mul(a, complex.constant<1.0, 0.0>) -> a |
| 366 | if (real == APFloat(real.getSemantics(), 1)) |
| 367 | return getLhs(); |
| 368 | |
| 369 | return {}; |
| 370 | } |
| 371 | |
| 372 | //===----------------------------------------------------------------------===// |
| 373 | // DivOp |
| 374 | //===----------------------------------------------------------------------===// |
| 375 | |
| 376 | OpFoldResult DivOp::fold(FoldAdaptor adaptor) { |
| 377 | auto rhs = adaptor.getRhs(); |
| 378 | if (!rhs) |
| 379 | return {}; |
| 380 | |
| 381 | ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs); |
| 382 | if (!arrayAttr || arrayAttr.size() != 2) |
| 383 | return {}; |
| 384 | |
| 385 | APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue(); |
| 386 | APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue(); |
| 387 | |
| 388 | if (!imag.isZero()) |
| 389 | return {}; |
| 390 | |
| 391 | // complex.div(a, complex.constant<1.0, 0.0>) -> a |
| 392 | if (real == APFloat(real.getSemantics(), 1)) |
| 393 | return getLhs(); |
| 394 | |
| 395 | return {}; |
| 396 | } |
| 397 | |
| 398 | //===----------------------------------------------------------------------===// |
| 399 | // TableGen'd op method definitions |
| 400 | //===----------------------------------------------------------------------===// |
| 401 | |
| 402 | #define GET_OP_CLASSES |
| 403 | #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc" |
| 404 | |