| 1 | //===-- lib/Evaluate/complex.cpp ------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "flang/Evaluate/complex.h" |
| 10 | #include "llvm/Support/raw_ostream.h" |
| 11 | |
| 12 | namespace Fortran::evaluate::value { |
| 13 | |
| 14 | template <typename R> |
| 15 | ValueWithRealFlags<Complex<R>> Complex<R>::Add( |
| 16 | const Complex &that, Rounding rounding) const { |
| 17 | RealFlags flags; |
| 18 | Part reSum{re_.Add(that.re_, rounding).AccumulateFlags(flags)}; |
| 19 | Part imSum{im_.Add(that.im_, rounding).AccumulateFlags(flags)}; |
| 20 | return {Complex{reSum, imSum}, flags}; |
| 21 | } |
| 22 | |
| 23 | template <typename R> |
| 24 | ValueWithRealFlags<Complex<R>> Complex<R>::Subtract( |
| 25 | const Complex &that, Rounding rounding) const { |
| 26 | RealFlags flags; |
| 27 | Part reDiff{re_.Subtract(that.re_, rounding).AccumulateFlags(flags)}; |
| 28 | Part imDiff{im_.Subtract(that.im_, rounding).AccumulateFlags(flags)}; |
| 29 | return {Complex{reDiff, imDiff}, flags}; |
| 30 | } |
| 31 | |
| 32 | template <typename R> |
| 33 | ValueWithRealFlags<Complex<R>> Complex<R>::Multiply( |
| 34 | const Complex &that, Rounding rounding) const { |
| 35 | // (a + ib)*(c + id) -> ac - bd + i(ad + bc) |
| 36 | RealFlags flags; |
| 37 | Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)}; |
| 38 | Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)}; |
| 39 | Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)}; |
| 40 | Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)}; |
| 41 | Part acbd{ac.Subtract(bd, rounding).AccumulateFlags(flags)}; |
| 42 | Part adbc{ad.Add(bc, rounding).AccumulateFlags(flags)}; |
| 43 | return {Complex{acbd, adbc}, flags}; |
| 44 | } |
| 45 | |
| 46 | template <typename R> |
| 47 | ValueWithRealFlags<Complex<R>> Complex<R>::Divide( |
| 48 | const Complex &that, Rounding rounding) const { |
| 49 | // (a + ib)/(c + id) -> [(a+ib)*(c-id)] / [(c+id)*(c-id)] |
| 50 | // -> [ac+bd+i(bc-ad)] / (cc+dd) -- note (cc+dd) is real |
| 51 | // -> ((ac+bd)/(cc+dd)) + i((bc-ad)/(cc+dd)) |
| 52 | RealFlags flags; |
| 53 | Part cc{that.re_.Multiply(that.re_, rounding).AccumulateFlags(flags)}; |
| 54 | Part dd{that.im_.Multiply(that.im_, rounding).AccumulateFlags(flags)}; |
| 55 | Part ccPdd{cc.Add(dd, rounding).AccumulateFlags(flags)}; |
| 56 | if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) { |
| 57 | // den = (cc+dd) did not overflow or underflow; try the naive |
| 58 | // sequence without scaling to avoid extra roundings. |
| 59 | Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)}; |
| 60 | Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)}; |
| 61 | Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)}; |
| 62 | Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)}; |
| 63 | Part acPbd{ac.Add(bd, rounding).AccumulateFlags(flags)}; |
| 64 | Part bcSad{bc.Subtract(ad, rounding).AccumulateFlags(flags)}; |
| 65 | Part re{acPbd.Divide(ccPdd, rounding).AccumulateFlags(flags)}; |
| 66 | Part im{bcSad.Divide(ccPdd, rounding).AccumulateFlags(flags)}; |
| 67 | if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) { |
| 68 | return {Complex{re, im}, flags}; |
| 69 | } |
| 70 | } |
| 71 | // Scale numerator and denominator by d/c (if c>=d) or c/d (if c<d) |
| 72 | flags.clear(); |
| 73 | Part scale; // will be <= 1.0 in magnitude |
| 74 | bool cGEd{that.re_.ABS().Compare(that.im_.ABS()) != Relation::Less}; |
| 75 | if (cGEd) { |
| 76 | scale = that.im_.Divide(that.re_, rounding).AccumulateFlags(flags); |
| 77 | } else { |
| 78 | scale = that.re_.Divide(that.im_, rounding).AccumulateFlags(flags); |
| 79 | } |
| 80 | Part den; |
| 81 | if (cGEd) { |
| 82 | Part dS{scale.Multiply(that.im_, rounding).AccumulateFlags(flags)}; |
| 83 | den = dS.Add(that.re_, rounding).AccumulateFlags(flags); |
| 84 | } else { |
| 85 | Part cS{scale.Multiply(that.re_, rounding).AccumulateFlags(flags)}; |
| 86 | den = cS.Add(that.im_, rounding).AccumulateFlags(flags); |
| 87 | } |
| 88 | Part aS{scale.Multiply(re_, rounding).AccumulateFlags(flags)}; |
| 89 | Part bS{scale.Multiply(im_, rounding).AccumulateFlags(flags)}; |
| 90 | Part re1, im1; |
| 91 | if (cGEd) { |
| 92 | re1 = re_.Add(bS, rounding).AccumulateFlags(flags); |
| 93 | im1 = im_.Subtract(aS, rounding).AccumulateFlags(flags); |
| 94 | } else { |
| 95 | re1 = aS.Add(im_, rounding).AccumulateFlags(flags); |
| 96 | im1 = bS.Subtract(re_, rounding).AccumulateFlags(flags); |
| 97 | } |
| 98 | Part re{re1.Divide(den, rounding).AccumulateFlags(flags)}; |
| 99 | Part im{im1.Divide(den, rounding).AccumulateFlags(flags)}; |
| 100 | return {Complex{re, im}, flags}; |
| 101 | } |
| 102 | |
| 103 | template <typename R> std::string Complex<R>::DumpHexadecimal() const { |
| 104 | std::string result{'('}; |
| 105 | result += re_.DumpHexadecimal(); |
| 106 | result += ','; |
| 107 | result += im_.DumpHexadecimal(); |
| 108 | result += ')'; |
| 109 | return result; |
| 110 | } |
| 111 | |
| 112 | template <typename R> |
| 113 | llvm::raw_ostream &Complex<R>::AsFortran(llvm::raw_ostream &o, int kind) const { |
| 114 | re_.AsFortran(o << '(', kind); |
| 115 | im_.AsFortran(o << ',', kind); |
| 116 | return o << ')'; |
| 117 | } |
| 118 | |
| 119 | template class Complex<Real<Integer<16>, 11>>; |
| 120 | template class Complex<Real<Integer<16>, 8>>; |
| 121 | template class Complex<Real<Integer<32>, 24>>; |
| 122 | template class Complex<Real<Integer<64>, 53>>; |
| 123 | template class Complex<Real<X87IntegerContainer, 64>>; |
| 124 | template class Complex<Real<Integer<128>, 113>>; |
| 125 | } // namespace Fortran::evaluate::value |
| 126 | |