1//===-- lib/Evaluate/complex.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Evaluate/complex.h"
10#include "llvm/Support/raw_ostream.h"
11
12namespace Fortran::evaluate::value {
13
14template <typename R>
15ValueWithRealFlags<Complex<R>> Complex<R>::Add(
16 const Complex &that, Rounding rounding) const {
17 RealFlags flags;
18 Part reSum{re_.Add(that.re_, rounding).AccumulateFlags(flags)};
19 Part imSum{im_.Add(that.im_, rounding).AccumulateFlags(flags)};
20 return {Complex{reSum, imSum}, flags};
21}
22
23template <typename R>
24ValueWithRealFlags<Complex<R>> Complex<R>::Subtract(
25 const Complex &that, Rounding rounding) const {
26 RealFlags flags;
27 Part reDiff{re_.Subtract(that.re_, rounding).AccumulateFlags(flags)};
28 Part imDiff{im_.Subtract(that.im_, rounding).AccumulateFlags(flags)};
29 return {Complex{reDiff, imDiff}, flags};
30}
31
32template <typename R>
33ValueWithRealFlags<Complex<R>> Complex<R>::Multiply(
34 const Complex &that, Rounding rounding) const {
35 // (a + ib)*(c + id) -> ac - bd + i(ad + bc)
36 RealFlags flags;
37 Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
38 Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
39 Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
40 Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
41 Part acbd{ac.Subtract(bd, rounding).AccumulateFlags(flags)};
42 Part adbc{ad.Add(bc, rounding).AccumulateFlags(flags)};
43 return {Complex{acbd, adbc}, flags};
44}
45
46template <typename R>
47ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
48 const Complex &that, Rounding rounding) const {
49 // (a + ib)/(c + id) -> [(a+ib)*(c-id)] / [(c+id)*(c-id)]
50 // -> [ac+bd+i(bc-ad)] / (cc+dd) -- note (cc+dd) is real
51 // -> ((ac+bd)/(cc+dd)) + i((bc-ad)/(cc+dd))
52 RealFlags flags;
53 Part cc{that.re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
54 Part dd{that.im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
55 Part ccPdd{cc.Add(dd, rounding).AccumulateFlags(flags)};
56 if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) {
57 // den = (cc+dd) did not overflow or underflow; try the naive
58 // sequence without scaling to avoid extra roundings.
59 Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
60 Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
61 Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
62 Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
63 Part acPbd{ac.Add(bd, rounding).AccumulateFlags(flags)};
64 Part bcSad{bc.Subtract(ad, rounding).AccumulateFlags(flags)};
65 Part re{acPbd.Divide(ccPdd, rounding).AccumulateFlags(flags)};
66 Part im{bcSad.Divide(ccPdd, rounding).AccumulateFlags(flags)};
67 if (!flags.test(RealFlag::Overflow) && !flags.test(RealFlag::Underflow)) {
68 return {Complex{re, im}, flags};
69 }
70 }
71 // Scale numerator and denominator by d/c (if c>=d) or c/d (if c<d)
72 flags.clear();
73 Part scale; // will be <= 1.0 in magnitude
74 bool cGEd{that.re_.ABS().Compare(that.im_.ABS()) != Relation::Less};
75 if (cGEd) {
76 scale = that.im_.Divide(that.re_, rounding).AccumulateFlags(flags);
77 } else {
78 scale = that.re_.Divide(that.im_, rounding).AccumulateFlags(flags);
79 }
80 Part den;
81 if (cGEd) {
82 Part dS{scale.Multiply(that.im_, rounding).AccumulateFlags(flags)};
83 den = dS.Add(that.re_, rounding).AccumulateFlags(flags);
84 } else {
85 Part cS{scale.Multiply(that.re_, rounding).AccumulateFlags(flags)};
86 den = cS.Add(that.im_, rounding).AccumulateFlags(flags);
87 }
88 Part aS{scale.Multiply(re_, rounding).AccumulateFlags(flags)};
89 Part bS{scale.Multiply(im_, rounding).AccumulateFlags(flags)};
90 Part re1, im1;
91 if (cGEd) {
92 re1 = re_.Add(bS, rounding).AccumulateFlags(flags);
93 im1 = im_.Subtract(aS, rounding).AccumulateFlags(flags);
94 } else {
95 re1 = aS.Add(im_, rounding).AccumulateFlags(flags);
96 im1 = bS.Subtract(re_, rounding).AccumulateFlags(flags);
97 }
98 Part re{re1.Divide(den, rounding).AccumulateFlags(flags)};
99 Part im{im1.Divide(den, rounding).AccumulateFlags(flags)};
100 return {Complex{re, im}, flags};
101}
102
103template <typename R> std::string Complex<R>::DumpHexadecimal() const {
104 std::string result{'('};
105 result += re_.DumpHexadecimal();
106 result += ',';
107 result += im_.DumpHexadecimal();
108 result += ')';
109 return result;
110}
111
112template <typename R>
113llvm::raw_ostream &Complex<R>::AsFortran(llvm::raw_ostream &o, int kind) const {
114 re_.AsFortran(o << '(', kind);
115 im_.AsFortran(o << ',', kind);
116 return o << ')';
117}
118
119template class Complex<Real<Integer<16>, 11>>;
120template class Complex<Real<Integer<16>, 8>>;
121template class Complex<Real<Integer<32>, 24>>;
122template class Complex<Real<Integer<64>, 53>>;
123template class Complex<Real<X87IntegerContainer, 64>>;
124template class Complex<Real<Integer<128>, 113>>;
125} // namespace Fortran::evaluate::value
126

source code of flang/lib/Evaluate/complex.cpp