| 1 | //===-- lib/Evaluate/fold-real.cpp ----------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "fold-implementation.h" |
| 10 | #include "fold-matmul.h" |
| 11 | #include "fold-reduction.h" |
| 12 | |
| 13 | namespace Fortran::evaluate { |
| 14 | |
| 15 | template <typename T> |
| 16 | static Expr<T> FoldTransformationalBessel( |
| 17 | FunctionRef<T> &&funcRef, FoldingContext &context) { |
| 18 | CHECK(funcRef.arguments().size() == 3); |
| 19 | /// Bessel runtime functions use `int` integer arguments. Convert integer |
| 20 | /// arguments to Int4, any overflow error will be reported during the |
| 21 | /// conversion folding. |
| 22 | using Int4 = Type<TypeCategory::Integer, 4>; |
| 23 | if (auto args{GetConstantArguments<Int4, Int4, T>( |
| 24 | context, funcRef.arguments(), /*hasOptionalArgument=*/false)}) { |
| 25 | const std::string &name{std::get<SpecificIntrinsic>(funcRef.proc().u).name}; |
| 26 | if (auto elementalBessel{GetHostRuntimeWrapper<T, Int4, T>(name)}) { |
| 27 | std::vector<Scalar<T>> results; |
| 28 | int n1{static_cast<int>( |
| 29 | std::get<0>(*args)->GetScalarValue().value().ToInt64())}; |
| 30 | int n2{static_cast<int>( |
| 31 | std::get<1>(*args)->GetScalarValue().value().ToInt64())}; |
| 32 | Scalar<T> x{std::get<2>(*args)->GetScalarValue().value()}; |
| 33 | for (int i{n1}; i <= n2; ++i) { |
| 34 | results.emplace_back((*elementalBessel)(context, Scalar<Int4>{i}, x)); |
| 35 | } |
| 36 | return Expr<T>{Constant<T>{ |
| 37 | std::move(results), ConstantSubscripts{std::max(n2 - n1 + 1, 0)}}}; |
| 38 | } else if (context.languageFeatures().ShouldWarn( |
| 39 | common::UsageWarning::FoldingFailure)) { |
| 40 | context.messages().Say(common::UsageWarning::FoldingFailure, |
| 41 | "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US , |
| 42 | name, T::kind); |
| 43 | } |
| 44 | } |
| 45 | return Expr<T>{std::move(funcRef)}; |
| 46 | } |
| 47 | |
| 48 | // NORM2 |
| 49 | template <int KIND> class Norm2Accumulator { |
| 50 | using T = Type<TypeCategory::Real, KIND>; |
| 51 | |
| 52 | public: |
| 53 | Norm2Accumulator( |
| 54 | const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding) |
| 55 | : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}; |
| 56 | void operator()( |
| 57 | Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) { |
| 58 | // Summation of scaled elements: |
| 59 | // Naively, |
| 60 | // NORM2(A(:)) = SQRT(SUM(A(:)**2)) |
| 61 | // For any T > 0, we have mathematically |
| 62 | // SQRT(SUM(A(:)**2)) |
| 63 | // = SQRT(T**2 * (SUM(A(:)**2) / T**2)) |
| 64 | // = SQRT(T**2 * SUM(A(:)**2 / T**2)) |
| 65 | // = SQRT(T**2 * SUM((A(:)/T)**2)) |
| 66 | // = SQRT(T**2) * SQRT(SUM((A(:)/T)**2)) |
| 67 | // = T * SQRT(SUM((A(:)/T)**2)) |
| 68 | // By letting T = MAXVAL(ABS(A)), we ensure that |
| 69 | // ALL(ABS(A(:)/T) <= 1), so ALL((A(:)/T)**2 <= 1), and the SUM will |
| 70 | // not overflow unless absolutely necessary. |
| 71 | auto scale{maxAbs_.At(maxAbsAt_)}; |
| 72 | if (scale.IsZero()) { |
| 73 | // Maximum value is zero, and so will the result be. |
| 74 | // Avoid division by zero below. |
| 75 | element = scale; |
| 76 | } else { |
| 77 | auto item{array_.At(at)}; |
| 78 | auto scaled{item.Divide(scale).value}; |
| 79 | auto square{scaled.Multiply(scaled).value}; |
| 80 | if constexpr (useKahanSummation) { |
| 81 | auto next{square.Subtract(correction_, rounding_)}; |
| 82 | overflow_ |= next.flags.test(RealFlag::Overflow); |
| 83 | auto sum{element.Add(next.value, rounding_)}; |
| 84 | overflow_ |= sum.flags.test(RealFlag::Overflow); |
| 85 | correction_ = sum.value.Subtract(element, rounding_) |
| 86 | .value.Subtract(next.value, rounding_) |
| 87 | .value; |
| 88 | element = sum.value; |
| 89 | } else { |
| 90 | auto sum{element.Add(square, rounding_)}; |
| 91 | overflow_ |= sum.flags.test(RealFlag::Overflow); |
| 92 | element = sum.value; |
| 93 | } |
| 94 | } |
| 95 | } |
| 96 | bool overflow() const { return overflow_; } |
| 97 | void Done(Scalar<T> &result) { |
| 98 | // incoming result = SUM((data(:)/maxAbs)**2) |
| 99 | // outgoing result = maxAbs * SQRT(result) |
| 100 | auto root{result.SQRT().value}; |
| 101 | auto product{root.Multiply(maxAbs_.At(maxAbsAt_))}; |
| 102 | maxAbs_.IncrementSubscripts(maxAbsAt_); |
| 103 | overflow_ |= product.flags.test(RealFlag::Overflow); |
| 104 | result = product.value; |
| 105 | } |
| 106 | |
| 107 | private: |
| 108 | const Constant<T> &array_; |
| 109 | const Constant<T> &maxAbs_; |
| 110 | const Rounding rounding_; |
| 111 | bool overflow_{false}; |
| 112 | Scalar<T> correction_{}; |
| 113 | ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()}; |
| 114 | }; |
| 115 | |
| 116 | template <int KIND> |
| 117 | static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context, |
| 118 | FunctionRef<Type<TypeCategory::Real, KIND>> &&funcRef) { |
| 119 | using T = Type<TypeCategory::Real, KIND>; |
| 120 | using Element = typename Constant<T>::Element; |
| 121 | std::optional<int> dim; |
| 122 | if (std::optional<ArrayAndMask<T>> arrayAndMask{ |
| 123 | ProcessReductionArgs<T>(context, funcRef.arguments(), dim, |
| 124 | /*X=*/0, /*DIM=*/1)}) { |
| 125 | MaxvalMinvalAccumulator<T, /*ABS=*/true> maxAbsAccumulator{ |
| 126 | RelationalOperator::GT, context, arrayAndMask->array}; |
| 127 | const Element identity{}; |
| 128 | Constant<T> maxAbs{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask, |
| 129 | dim, identity, maxAbsAccumulator)}; |
| 130 | Norm2Accumulator norm2Accumulator{arrayAndMask->array, maxAbs, |
| 131 | context.targetCharacteristics().roundingMode()}; |
| 132 | Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask, |
| 133 | dim, identity, norm2Accumulator)}; |
| 134 | if (norm2Accumulator.overflow() && |
| 135 | context.languageFeatures().ShouldWarn( |
| 136 | common::UsageWarning::FoldingException)) { |
| 137 | context.messages().Say(common::UsageWarning::FoldingException, |
| 138 | "NORM2() of REAL(%d) data overflowed"_warn_en_US , KIND); |
| 139 | } |
| 140 | return Expr<T>{std::move(result)}; |
| 141 | } |
| 142 | return Expr<T>{std::move(funcRef)}; |
| 143 | } |
| 144 | |
| 145 | template <int KIND> |
| 146 | Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction( |
| 147 | FoldingContext &context, |
| 148 | FunctionRef<Type<TypeCategory::Real, KIND>> &&funcRef) { |
| 149 | using T = Type<TypeCategory::Real, KIND>; |
| 150 | using ComplexT = Type<TypeCategory::Complex, KIND>; |
| 151 | using Int4 = Type<TypeCategory::Integer, 4>; |
| 152 | ActualArguments &args{funcRef.arguments()}; |
| 153 | auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}; |
| 154 | CHECK(intrinsic); |
| 155 | std::string name{intrinsic->name}; |
| 156 | if (name == "acos" || name == "acosh" || name == "asin" || name == "asinh" || |
| 157 | (name == "atan" && args.size() == 1) || name == "atanh" || |
| 158 | name == "bessel_j0" || name == "bessel_j1" || name == "bessel_y0" || |
| 159 | name == "bessel_y1" || name == "cos" || name == "cosh" || name == "erf" || |
| 160 | name == "erfc" || name == "erfc_scaled" || name == "exp" || |
| 161 | name == "gamma" || name == "log" || name == "log10" || |
| 162 | name == "log_gamma" || name == "sin" || name == "sinh" || name == "tan" || |
| 163 | name == "tanh" ) { |
| 164 | CHECK(args.size() == 1); |
| 165 | if (auto callable{GetHostRuntimeWrapper<T, T>(name)}) { |
| 166 | return FoldElementalIntrinsic<T, T>( |
| 167 | context, std::move(funcRef), *callable); |
| 168 | } else if (context.languageFeatures().ShouldWarn( |
| 169 | common::UsageWarning::FoldingFailure)) { |
| 170 | context.messages().Say(common::UsageWarning::FoldingFailure, |
| 171 | "%s(real(kind=%d)) cannot be folded on host"_warn_en_US , name, KIND); |
| 172 | } |
| 173 | } else if (name == "amax0" || name == "amin0" || name == "amin1" || |
| 174 | name == "amax1" || name == "dmin1" || name == "dmax1" ) { |
| 175 | return RewriteSpecificMINorMAX(context, std::move(funcRef)); |
| 176 | } else if (name == "atan" || name == "atan2" ) { |
| 177 | std::string localName{name == "atan" ? "atan2" : name}; |
| 178 | CHECK(args.size() == 2); |
| 179 | if (auto callable{GetHostRuntimeWrapper<T, T, T>(localName)}) { |
| 180 | return FoldElementalIntrinsic<T, T, T>( |
| 181 | context, std::move(funcRef), *callable); |
| 182 | } else if (context.languageFeatures().ShouldWarn( |
| 183 | common::UsageWarning::FoldingFailure)) { |
| 184 | context.messages().Say(common::UsageWarning::FoldingFailure, |
| 185 | "%s(real(kind=%d), real(kind%d)) cannot be folded on host"_warn_en_US , |
| 186 | name, KIND, KIND); |
| 187 | } |
| 188 | } else if (name == "bessel_jn" || name == "bessel_yn" ) { |
| 189 | if (args.size() == 2) { // elemental |
| 190 | // runtime functions use int arg |
| 191 | if (auto callable{GetHostRuntimeWrapper<T, Int4, T>(name)}) { |
| 192 | return FoldElementalIntrinsic<T, Int4, T>( |
| 193 | context, std::move(funcRef), *callable); |
| 194 | } else if (context.languageFeatures().ShouldWarn( |
| 195 | common::UsageWarning::FoldingFailure)) { |
| 196 | context.messages().Say(common::UsageWarning::FoldingFailure, |
| 197 | "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US , |
| 198 | name, KIND); |
| 199 | } |
| 200 | } else { |
| 201 | return FoldTransformationalBessel<T>(std::move(funcRef), context); |
| 202 | } |
| 203 | } else if (name == "abs" ) { // incl. zabs & cdabs |
| 204 | // Argument can be complex or real |
| 205 | if (UnwrapExpr<Expr<SomeReal>>(args[0])) { |
| 206 | return FoldElementalIntrinsic<T, T>( |
| 207 | context, std::move(funcRef), &Scalar<T>::ABS); |
| 208 | } else if (UnwrapExpr<Expr<SomeComplex>>(args[0])) { |
| 209 | return FoldElementalIntrinsic<T, ComplexT>(context, std::move(funcRef), |
| 210 | ScalarFunc<T, ComplexT>([&name, &context]( |
| 211 | const Scalar<ComplexT> &z) -> Scalar<T> { |
| 212 | ValueWithRealFlags<Scalar<T>> y{z.ABS()}; |
| 213 | if (y.flags.test(RealFlag::Overflow) && |
| 214 | context.languageFeatures().ShouldWarn( |
| 215 | common::UsageWarning::FoldingException)) { |
| 216 | context.messages().Say(common::UsageWarning::FoldingException, |
| 217 | "complex ABS intrinsic folding overflow"_warn_en_US , name); |
| 218 | } |
| 219 | return y.value; |
| 220 | })); |
| 221 | } else { |
| 222 | common::die(" unexpected argument type inside abs" ); |
| 223 | } |
| 224 | } else if (name == "aimag" ) { |
| 225 | if (auto *zExpr{UnwrapExpr<Expr<ComplexT>>(args[0])}) { |
| 226 | return Fold(context, Expr<T>{ComplexComponent{true, std::move(*zExpr)}}); |
| 227 | } |
| 228 | } else if (name == "aint" || name == "anint" ) { |
| 229 | // ANINT rounds ties away from zero, not to even |
| 230 | common::RoundingMode mode{name == "aint" |
| 231 | ? common::RoundingMode::ToZero |
| 232 | : common::RoundingMode::TiesAwayFromZero}; |
| 233 | return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), |
| 234 | ScalarFunc<T, T>( |
| 235 | [&name, &context, mode](const Scalar<T> &x) -> Scalar<T> { |
| 236 | ValueWithRealFlags<Scalar<T>> y{x.ToWholeNumber(mode)}; |
| 237 | if (y.flags.test(RealFlag::Overflow) && |
| 238 | context.languageFeatures().ShouldWarn( |
| 239 | common::UsageWarning::FoldingException)) { |
| 240 | context.messages().Say(common::UsageWarning::FoldingException, |
| 241 | "%s intrinsic folding overflow"_warn_en_US , name); |
| 242 | } |
| 243 | return y.value; |
| 244 | })); |
| 245 | } else if (name == "dim" ) { |
| 246 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
| 247 | ScalarFunc<T, T, T>([&context](const Scalar<T> &x, |
| 248 | const Scalar<T> &y) -> Scalar<T> { |
| 249 | ValueWithRealFlags<Scalar<T>> result{x.DIM(y)}; |
| 250 | if (result.flags.test(RealFlag::Overflow) && |
| 251 | context.languageFeatures().ShouldWarn( |
| 252 | common::UsageWarning::FoldingException)) { |
| 253 | context.messages().Say(common::UsageWarning::FoldingException, |
| 254 | "DIM intrinsic folding overflow"_warn_en_US ); |
| 255 | } |
| 256 | return result.value; |
| 257 | })); |
| 258 | } else if (name == "dot_product" ) { |
| 259 | return FoldDotProduct<T>(context, std::move(funcRef)); |
| 260 | } else if (name == "dprod" ) { |
| 261 | // Rewrite DPROD(x,y) -> DBLE(x)*DBLE(y) |
| 262 | if (args.at(0) && args.at(1)) { |
| 263 | const auto *xExpr{args[0]->UnwrapExpr()}; |
| 264 | const auto *yExpr{args[1]->UnwrapExpr()}; |
| 265 | if (xExpr && yExpr) { |
| 266 | return Fold(context, |
| 267 | ToReal<T::kind>(context, common::Clone(*xExpr)) * |
| 268 | ToReal<T::kind>(context, common::Clone(*yExpr))); |
| 269 | } |
| 270 | } |
| 271 | } else if (name == "epsilon" ) { |
| 272 | return Expr<T>{Scalar<T>::EPSILON()}; |
| 273 | } else if (name == "fraction" ) { |
| 274 | return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), |
| 275 | ScalarFunc<T, T>( |
| 276 | [](const Scalar<T> &x) -> Scalar<T> { return x.FRACTION(); })); |
| 277 | } else if (name == "huge" ) { |
| 278 | return Expr<T>{Scalar<T>::HUGE()}; |
| 279 | } else if (name == "hypot" ) { |
| 280 | CHECK(args.size() == 2); |
| 281 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
| 282 | ScalarFunc<T, T, T>( |
| 283 | [&](const Scalar<T> &x, const Scalar<T> &y) -> Scalar<T> { |
| 284 | ValueWithRealFlags<Scalar<T>> result{x.HYPOT(y)}; |
| 285 | if (result.flags.test(RealFlag::Overflow) && |
| 286 | context.languageFeatures().ShouldWarn( |
| 287 | common::UsageWarning::FoldingException)) { |
| 288 | context.messages().Say(common::UsageWarning::FoldingException, |
| 289 | "HYPOT intrinsic folding overflow"_warn_en_US ); |
| 290 | } |
| 291 | return result.value; |
| 292 | })); |
| 293 | } else if (name == "matmul" ) { |
| 294 | return FoldMatmul(context, std::move(funcRef)); |
| 295 | } else if (name == "max" ) { |
| 296 | return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater); |
| 297 | } else if (name == "maxval" ) { |
| 298 | return FoldMaxvalMinval<T>(context, std::move(funcRef), |
| 299 | RelationalOperator::GT, T::Scalar::HUGE().Negate()); |
| 300 | } else if (name == "min" ) { |
| 301 | return FoldMINorMAX(context, std::move(funcRef), Ordering::Less); |
| 302 | } else if (name == "minval" ) { |
| 303 | return FoldMaxvalMinval<T>( |
| 304 | context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE()); |
| 305 | } else if (name == "mod" ) { |
| 306 | CHECK(args.size() == 2); |
| 307 | bool badPConst{false}; |
| 308 | if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) { |
| 309 | *pExpr = Fold(context, std::move(*pExpr)); |
| 310 | if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst && |
| 311 | pConst->IsZero() && |
| 312 | context.languageFeatures().ShouldWarn( |
| 313 | common::UsageWarning::FoldingAvoidsRuntimeCrash)) { |
| 314 | context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash, |
| 315 | "MOD: P argument is zero"_warn_en_US ); |
| 316 | badPConst = true; |
| 317 | } |
| 318 | } |
| 319 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
| 320 | ScalarFunc<T, T, T>([&context, badPConst](const Scalar<T> &x, |
| 321 | const Scalar<T> &y) -> Scalar<T> { |
| 322 | auto result{x.MOD(y)}; |
| 323 | if (!badPConst && result.flags.test(RealFlag::DivideByZero) && |
| 324 | context.languageFeatures().ShouldWarn( |
| 325 | common::UsageWarning::FoldingAvoidsRuntimeCrash)) { |
| 326 | context.messages().Say( |
| 327 | common::UsageWarning::FoldingAvoidsRuntimeCrash, |
| 328 | "second argument to MOD must not be zero"_warn_en_US ); |
| 329 | } |
| 330 | return result.value; |
| 331 | })); |
| 332 | } else if (name == "modulo" ) { |
| 333 | CHECK(args.size() == 2); |
| 334 | bool badPConst{false}; |
| 335 | if (auto *pExpr{UnwrapExpr<Expr<T>>(args[1])}) { |
| 336 | *pExpr = Fold(context, std::move(*pExpr)); |
| 337 | if (auto pConst{GetScalarConstantValue<T>(*pExpr)}; pConst && |
| 338 | pConst->IsZero() && |
| 339 | context.languageFeatures().ShouldWarn( |
| 340 | common::UsageWarning::FoldingAvoidsRuntimeCrash)) { |
| 341 | context.messages().Say(common::UsageWarning::FoldingAvoidsRuntimeCrash, |
| 342 | "MODULO: P argument is zero"_warn_en_US ); |
| 343 | badPConst = true; |
| 344 | } |
| 345 | } |
| 346 | return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef), |
| 347 | ScalarFunc<T, T, T>([&context, badPConst](const Scalar<T> &x, |
| 348 | const Scalar<T> &y) -> Scalar<T> { |
| 349 | auto result{x.MODULO(y)}; |
| 350 | if (!badPConst && result.flags.test(RealFlag::DivideByZero) && |
| 351 | context.languageFeatures().ShouldWarn( |
| 352 | common::UsageWarning::FoldingAvoidsRuntimeCrash)) { |
| 353 | context.messages().Say( |
| 354 | common::UsageWarning::FoldingAvoidsRuntimeCrash, |
| 355 | "second argument to MODULO must not be zero"_warn_en_US ); |
| 356 | } |
| 357 | return result.value; |
| 358 | })); |
| 359 | } else if (name == "nearest" ) { |
| 360 | if (auto *sExpr{UnwrapExpr<Expr<SomeReal>>(args[1])}) { |
| 361 | *sExpr = Fold(context, std::move(*sExpr)); |
| 362 | return common::visit( |
| 363 | [&](const auto &sVal) { |
| 364 | using TS = ResultType<decltype(sVal)>; |
| 365 | bool badSConst{false}; |
| 366 | if (auto sConst{GetScalarConstantValue<TS>(sVal)}; sConst && |
| 367 | (sConst->IsZero() || sConst->IsNotANumber()) && |
| 368 | context.languageFeatures().ShouldWarn( |
| 369 | common::UsageWarning::FoldingValueChecks)) { |
| 370 | context.messages().Say(common::UsageWarning::FoldingValueChecks, |
| 371 | "NEAREST: S argument is %s"_warn_en_US , |
| 372 | sConst->IsZero() ? "zero" : "NaN" ); |
| 373 | badSConst = true; |
| 374 | } |
| 375 | return FoldElementalIntrinsic<T, T, TS>(context, std::move(funcRef), |
| 376 | ScalarFunc<T, T, TS>([&](const Scalar<T> &x, |
| 377 | const Scalar<TS> &s) -> Scalar<T> { |
| 378 | if (!badSConst && (s.IsZero() || s.IsNotANumber()) && |
| 379 | context.languageFeatures().ShouldWarn( |
| 380 | common::UsageWarning::FoldingValueChecks)) { |
| 381 | context.messages().Say( |
| 382 | common::UsageWarning::FoldingValueChecks, |
| 383 | "NEAREST: S argument is %s"_warn_en_US , |
| 384 | s.IsZero() ? "zero" : "NaN" ); |
| 385 | } |
| 386 | auto result{x.NEAREST(!s.IsNegative())}; |
| 387 | if (context.languageFeatures().ShouldWarn( |
| 388 | common::UsageWarning::FoldingException)) { |
| 389 | if (result.flags.test(RealFlag::InvalidArgument)) { |
| 390 | context.messages().Say( |
| 391 | common::UsageWarning::FoldingException, |
| 392 | "NEAREST intrinsic folding: bad argument"_warn_en_US ); |
| 393 | } |
| 394 | } |
| 395 | return result.value; |
| 396 | })); |
| 397 | }, |
| 398 | sExpr->u); |
| 399 | } |
| 400 | } else if (name == "norm2" ) { |
| 401 | return FoldNorm2<T::kind>(context, std::move(funcRef)); |
| 402 | } else if (name == "product" ) { |
| 403 | auto one{Scalar<T>::FromInteger(value::Integer<8>{1}).value}; |
| 404 | return FoldProduct<T>(context, std::move(funcRef), one); |
| 405 | } else if (name == "real" || name == "dble" ) { |
| 406 | if (auto *expr{args[0].value().UnwrapExpr()}) { |
| 407 | return ToReal<KIND>(context, std::move(*expr)); |
| 408 | } |
| 409 | } else if (name == "rrspacing" ) { |
| 410 | return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), |
| 411 | ScalarFunc<T, T>( |
| 412 | [](const Scalar<T> &x) -> Scalar<T> { return x.RRSPACING(); })); |
| 413 | } else if (name == "scale" ) { |
| 414 | if (const auto *byExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])}) { |
| 415 | return common::visit( |
| 416 | [&](const auto &byVal) { |
| 417 | using TBY = ResultType<decltype(byVal)>; |
| 418 | return FoldElementalIntrinsic<T, T, TBY>(context, |
| 419 | std::move(funcRef), |
| 420 | ScalarFunc<T, T, TBY>( |
| 421 | [&](const Scalar<T> &x, const Scalar<TBY> &y) -> Scalar<T> { |
| 422 | ValueWithRealFlags<Scalar<T>> result{ |
| 423 | x. |
| 424 | // MSVC chokes on the keyword "template" here in a call to a |
| 425 | // member function template. |
| 426 | #ifndef _MSC_VER |
| 427 | template |
| 428 | #endif |
| 429 | SCALE<Scalar<TBY>>(y)}; |
| 430 | if (result.flags.test(RealFlag::Overflow) && |
| 431 | context.languageFeatures().ShouldWarn( |
| 432 | common::UsageWarning::FoldingException)) { |
| 433 | context.messages().Say( |
| 434 | common::UsageWarning::FoldingException, |
| 435 | "SCALE/IEEE_SCALB intrinsic folding overflow"_warn_en_US ); |
| 436 | } |
| 437 | return result.value; |
| 438 | })); |
| 439 | }, |
| 440 | byExpr->u); |
| 441 | } |
| 442 | } else if (name == "set_exponent" ) { |
| 443 | if (const auto *iExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])}) { |
| 444 | return common::visit( |
| 445 | [&](const auto &iVal) { |
| 446 | using TY = ResultType<decltype(iVal)>; |
| 447 | return FoldElementalIntrinsic<T, T, TY>(context, std::move(funcRef), |
| 448 | ScalarFunc<T, T, TY>( |
| 449 | [&](const Scalar<T> &x, const Scalar<TY> &i) -> Scalar<T> { |
| 450 | return x.SET_EXPONENT(i.ToInt64()); |
| 451 | })); |
| 452 | }, |
| 453 | iExpr->u); |
| 454 | } |
| 455 | } else if (name == "sign" ) { |
| 456 | return FoldElementalIntrinsic<T, T, T>( |
| 457 | context, std::move(funcRef), &Scalar<T>::SIGN); |
| 458 | } else if (name == "spacing" ) { |
| 459 | return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), |
| 460 | ScalarFunc<T, T>( |
| 461 | [](const Scalar<T> &x) -> Scalar<T> { return x.SPACING(); })); |
| 462 | } else if (name == "sqrt" ) { |
| 463 | return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), |
| 464 | ScalarFunc<T, T>( |
| 465 | [](const Scalar<T> &x) -> Scalar<T> { return x.SQRT().value; })); |
| 466 | } else if (name == "sum" ) { |
| 467 | return FoldSum<T>(context, std::move(funcRef)); |
| 468 | } else if (name == "tiny" ) { |
| 469 | return Expr<T>{Scalar<T>::TINY()}; |
| 470 | } else if (name == "__builtin_fma" ) { |
| 471 | CHECK(args.size() == 3); |
| 472 | } else if (name == "__builtin_ieee_next_after" ) { |
| 473 | if (const auto *yExpr{UnwrapExpr<Expr<SomeReal>>(args[1])}) { |
| 474 | return common::visit( |
| 475 | [&](const auto &yVal) { |
| 476 | using TY = ResultType<decltype(yVal)>; |
| 477 | return FoldElementalIntrinsic<T, T, TY>(context, std::move(funcRef), |
| 478 | ScalarFunc<T, T, TY>([&](const Scalar<T> &x, |
| 479 | const Scalar<TY> &y) -> Scalar<T> { |
| 480 | auto xBig{Scalar<LargestReal>::Convert(x).value}; |
| 481 | auto yBig{Scalar<LargestReal>::Convert(y).value}; |
| 482 | switch (xBig.Compare(yBig)) { |
| 483 | case Relation::Unordered: |
| 484 | if (context.languageFeatures().ShouldWarn( |
| 485 | common::UsageWarning::FoldingValueChecks)) { |
| 486 | context.messages().Say( |
| 487 | common::UsageWarning::FoldingValueChecks, |
| 488 | "IEEE_NEXT_AFTER intrinsic folding: arguments are unordered"_warn_en_US ); |
| 489 | } |
| 490 | return x.NotANumber(); |
| 491 | case Relation::Equal: |
| 492 | break; |
| 493 | case Relation::Less: |
| 494 | return x.NEAREST(true).value; |
| 495 | case Relation::Greater: |
| 496 | return x.NEAREST(false).value; |
| 497 | } |
| 498 | return x; // dodge bogus "missing return" GCC warning |
| 499 | })); |
| 500 | }, |
| 501 | yExpr->u); |
| 502 | } |
| 503 | } else if (name == "__builtin_ieee_next_up" || |
| 504 | name == "__builtin_ieee_next_down" ) { |
| 505 | bool upward{name == "__builtin_ieee_next_up" }; |
| 506 | const char *iName{upward ? "IEEE_NEXT_UP" : "IEEE_NEXT_DOWN" }; |
| 507 | return FoldElementalIntrinsic<T, T>(context, std::move(funcRef), |
| 508 | ScalarFunc<T, T>([&](const Scalar<T> &x) -> Scalar<T> { |
| 509 | auto result{x.NEAREST(upward)}; |
| 510 | if (context.languageFeatures().ShouldWarn( |
| 511 | common::UsageWarning::FoldingException)) { |
| 512 | if (result.flags.test(RealFlag::InvalidArgument)) { |
| 513 | context.messages().Say(common::UsageWarning::FoldingException, |
| 514 | "%s intrinsic folding: argument is NaN"_warn_en_US , iName); |
| 515 | } |
| 516 | } |
| 517 | return result.value; |
| 518 | })); |
| 519 | } |
| 520 | return Expr<T>{std::move(funcRef)}; |
| 521 | } |
| 522 | |
// MSVC only: turn off warning C4661 for the explicit instantiations below.
#ifdef _MSC_VER // disable bogus warning about missing definitions
#pragma warning(disable : 4661)
#endif
// Explicit instantiations of ExpressionBase for real-typed expressions.
// NOTE(review): FOR_EACH_REAL_KIND is defined elsewhere; it presumably
// repeats the given prefix once per supported REAL kind -- confirm.
FOR_EACH_REAL_KIND(template class ExpressionBase, )
template class ExpressionBase<SomeReal>;
| 528 | } // namespace Fortran::evaluate |
| 529 | |