| 1 | //===-- lib/Semantics/check-case.cpp --------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "check-case.h" |
| 10 | #include "flang/Common/idioms.h" |
| 11 | #include "flang/Common/reference.h" |
| 12 | #include "flang/Common/template.h" |
| 13 | #include "flang/Evaluate/fold.h" |
| 14 | #include "flang/Evaluate/type.h" |
| 15 | #include "flang/Parser/parse-tree.h" |
| 16 | #include "flang/Semantics/semantics.h" |
| 17 | #include "flang/Semantics/tools.h" |
| 18 | #include <tuple> |
| 19 | |
| 20 | namespace Fortran::semantics { |
| 21 | |
| 22 | template <typename T> class CaseValues { |
| 23 | public: |
| 24 | CaseValues(SemanticsContext &c, const evaluate::DynamicType &t) |
| 25 | : context_{c}, caseExprType_{t} {} |
| 26 | |
| 27 | void Check(const std::list<parser::CaseConstruct::Case> &cases) { |
| 28 | for (const parser::CaseConstruct::Case &c : cases) { |
| 29 | AddCase(c); |
| 30 | } |
| 31 | if (!hasErrors_) { |
| 32 | cases_.sort(Comparator{}); |
| 33 | if (!AreCasesDisjoint()) { // C1149 |
| 34 | ReportConflictingCases(); |
| 35 | } |
| 36 | } |
| 37 | } |
| 38 | |
| 39 | private: |
| 40 | using Value = evaluate::Scalar<T>; |
| 41 | |
| 42 | void AddCase(const parser::CaseConstruct::Case &c) { |
| 43 | const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)}; |
| 44 | const parser::CaseStmt &caseStmt{stmt.statement}; |
| 45 | const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)}; |
| 46 | common::visit( |
| 47 | common::visitors{ |
| 48 | [&](const std::list<parser::CaseValueRange> &ranges) { |
| 49 | for (const auto &range : ranges) { |
| 50 | auto pair{ComputeBounds(range)}; |
| 51 | if (pair.first && pair.second && *pair.first > *pair.second) { |
| 52 | context_.Warn(common::UsageWarning::EmptyCase, stmt.source, |
| 53 | "CASE has lower bound greater than upper bound"_warn_en_US ); |
| 54 | } else { |
| 55 | if constexpr (T::category == TypeCategory::Logical) { // C1148 |
| 56 | if ((pair.first || pair.second) && |
| 57 | (!pair.first || !pair.second || |
| 58 | *pair.first != *pair.second)) { |
| 59 | context_.Say(stmt.source, |
| 60 | "CASE range is not allowed for LOGICAL"_err_en_US ); |
| 61 | } |
| 62 | } |
| 63 | cases_.emplace_back(stmt); |
| 64 | cases_.back().lower = std::move(pair.first); |
| 65 | cases_.back().upper = std::move(pair.second); |
| 66 | } |
| 67 | } |
| 68 | }, |
| 69 | [&](const parser::Default &) { cases_.emplace_front(stmt); }, |
| 70 | }, |
| 71 | selector.u); |
| 72 | } |
| 73 | |
| 74 | std::optional<Value> GetValue(const parser::CaseValue &caseValue) { |
| 75 | const parser::Expr &expr{caseValue.thing.thing.value()}; |
| 76 | auto *x{expr.typedExpr.get()}; |
| 77 | if (x && x->v) { // C1147 |
| 78 | auto type{x->v->GetType()}; |
| 79 | if (type && type->category() == caseExprType_.category() && |
| 80 | (type->category() != TypeCategory::Character || |
| 81 | type->kind() == caseExprType_.kind())) { |
| 82 | parser::Messages buffer; // discarded folding messages |
| 83 | parser::ContextualMessages foldingMessages{expr.source, &buffer}; |
| 84 | evaluate::FoldingContext foldingContext{ |
| 85 | context_.foldingContext(), foldingMessages}; |
| 86 | auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})}; |
| 87 | if (auto converted{evaluate::Fold(foldingContext, |
| 88 | evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) { |
| 89 | if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) { |
| 90 | auto back{evaluate::Fold(foldingContext, |
| 91 | evaluate::ConvertToType(*type, SomeExpr{*converted}))}; |
| 92 | if (back == folded) { |
| 93 | x->v = converted; |
| 94 | return value; |
| 95 | } else { |
| 96 | context_.Warn(common::UsageWarning::CaseOverflow, expr.source, |
| 97 | "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US , |
| 98 | folded.AsFortran(), caseExprType_.AsFortran()); |
| 99 | hasErrors_ = true; |
| 100 | return std::nullopt; |
| 101 | } |
| 102 | } |
| 103 | } |
| 104 | context_.Say(expr.source, |
| 105 | "CASE value (%s) must be a constant scalar"_err_en_US , |
| 106 | x->v->AsFortran()); |
| 107 | } else { |
| 108 | std::string typeStr{type ? type->AsFortran() : "typeless"s }; |
| 109 | context_.Say(expr.source, |
| 110 | "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US , |
| 111 | typeStr, caseExprType_.AsFortran()); |
| 112 | } |
| 113 | hasErrors_ = true; |
| 114 | } |
| 115 | return std::nullopt; |
| 116 | } |
| 117 | |
| 118 | using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>; |
| 119 | PairOfValues ComputeBounds(const parser::CaseValueRange &range) { |
| 120 | return common::visit( |
| 121 | common::visitors{ |
| 122 | [&](const parser::CaseValue &x) { |
| 123 | auto value{GetValue(x)}; |
| 124 | return PairOfValues{value, value}; |
| 125 | }, |
| 126 | [&](const parser::CaseValueRange::Range &x) { |
| 127 | std::optional<Value> lo, hi; |
| 128 | if (x.lower) { |
| 129 | lo = GetValue(*x.lower); |
| 130 | } |
| 131 | if (x.upper) { |
| 132 | hi = GetValue(*x.upper); |
| 133 | } |
| 134 | if ((x.lower && !lo) || (x.upper && !hi)) { |
| 135 | return PairOfValues{}; // error case |
| 136 | } |
| 137 | return PairOfValues{std::move(lo), std::move(hi)}; |
| 138 | }, |
| 139 | }, |
| 140 | range.u); |
| 141 | } |
| 142 | |
| 143 | struct Case { |
| 144 | explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {} |
| 145 | bool IsDefault() const { return !lower && !upper; } |
| 146 | std::string AsFortran() const { |
| 147 | std::string result; |
| 148 | { |
| 149 | llvm::raw_string_ostream bs{result}; |
| 150 | if (lower) { |
| 151 | evaluate::Constant<T>{*lower}.AsFortran(bs << '('); |
| 152 | if (!upper) { |
| 153 | bs << ':'; |
| 154 | } else if (*lower != *upper) { |
| 155 | evaluate::Constant<T>{*upper}.AsFortran(bs << ':'); |
| 156 | } |
| 157 | bs << ')'; |
| 158 | } else if (upper) { |
| 159 | evaluate::Constant<T>{*upper}.AsFortran(bs << "(:" ) << ')'; |
| 160 | } else { |
| 161 | bs << "DEFAULT" ; |
| 162 | } |
| 163 | } |
| 164 | return result; |
| 165 | } |
| 166 | |
| 167 | const parser::Statement<parser::CaseStmt> &stmt; |
| 168 | std::optional<Value> lower, upper; |
| 169 | }; |
| 170 | |
| 171 | // Defines a comparator for use with std::list<>::sort(). |
| 172 | // Returns true if and only if the highest value in range x is less |
| 173 | // than the least value in range y. The DEFAULT case is arbitrarily |
| 174 | // defined to be less than all others. When two ranges overlap, |
| 175 | // neither is less than the other. |
| 176 | struct Comparator { |
| 177 | bool operator()(const Case &x, const Case &y) const { |
| 178 | if (x.IsDefault()) { |
| 179 | return !y.IsDefault(); |
| 180 | } else { |
| 181 | return x.upper && y.lower && *x.upper < *y.lower; |
| 182 | } |
| 183 | } |
| 184 | }; |
| 185 | |
| 186 | bool AreCasesDisjoint() const { |
| 187 | auto endIter{cases_.end()}; |
| 188 | for (auto iter{cases_.begin()}; iter != endIter; ++iter) { |
| 189 | auto next{iter}; |
| 190 | if (++next != endIter && !Comparator{}(*iter, *next)) { |
| 191 | return false; |
| 192 | } |
| 193 | } |
| 194 | return true; |
| 195 | } |
| 196 | |
| 197 | // This has quadratic time, but only runs in error cases |
| 198 | void ReportConflictingCases() { |
| 199 | for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) { |
| 200 | parser::Message *msg{nullptr}; |
| 201 | for (auto p{cases_.begin()}; p != cases_.end(); ++p) { |
| 202 | if (p->stmt.source.begin() < iter->stmt.source.begin() && |
| 203 | !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) { |
| 204 | if (!msg) { |
| 205 | msg = &context_.Say(iter->stmt.source, |
| 206 | "CASE %s conflicts with previous cases"_err_en_US , |
| 207 | iter->AsFortran()); |
| 208 | } |
| 209 | msg->Attach( |
| 210 | p->stmt.source, "Conflicting CASE %s"_en_US , p->AsFortran()); |
| 211 | } |
| 212 | } |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | SemanticsContext &context_; |
| 217 | const evaluate::DynamicType &caseExprType_; |
| 218 | std::list<Case> cases_; |
| 219 | bool hasErrors_{false}; |
| 220 | }; |
| 221 | |
| 222 | template <TypeCategory CAT> struct TypeVisitor { |
| 223 | using Result = bool; |
| 224 | using Types = evaluate::CategoryTypes<CAT>; |
| 225 | template <typename T> Result Test() { |
| 226 | if (T::kind == exprType.kind()) { |
| 227 | CaseValues<T>(context, exprType).Check(caseList); |
| 228 | return true; |
| 229 | } else { |
| 230 | return false; |
| 231 | } |
| 232 | } |
| 233 | SemanticsContext &context; |
| 234 | const evaluate::DynamicType &exprType; |
| 235 | const std::list<parser::CaseConstruct::Case> &caseList; |
| 236 | }; |
| 237 | |
| 238 | void CaseChecker::Enter(const parser::CaseConstruct &construct) { |
| 239 | const auto &selectCaseStmt{ |
| 240 | std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)}; |
| 241 | const auto &selectCase{selectCaseStmt.statement}; |
| 242 | const auto &selectExpr{ |
| 243 | std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing}; |
| 244 | const auto *x{GetExpr(context_, selectExpr)}; |
| 245 | if (!x) { |
| 246 | return; // expression semantics failed |
| 247 | } |
| 248 | if (auto exprType{x->GetType()}) { |
| 249 | const auto &caseList{ |
| 250 | std::get<std::list<parser::CaseConstruct::Case>>(construct.t)}; |
| 251 | switch (exprType->category()) { |
| 252 | case TypeCategory::Integer: |
| 253 | common::SearchTypes( |
| 254 | TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList}); |
| 255 | return; |
| 256 | case TypeCategory::Unsigned: |
| 257 | common::SearchTypes( |
| 258 | TypeVisitor<TypeCategory::Unsigned>{context_, *exprType, caseList}); |
| 259 | return; |
| 260 | case TypeCategory::Logical: |
| 261 | CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType} |
| 262 | .Check(caseList); |
| 263 | return; |
| 264 | case TypeCategory::Character: |
| 265 | common::SearchTypes( |
| 266 | TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList}); |
| 267 | return; |
| 268 | default: |
| 269 | break; |
| 270 | } |
| 271 | } |
| 272 | context_.Say(selectExpr.source, |
| 273 | context_.IsEnabled(common::LanguageFeature::Unsigned) |
| 274 | ? "SELECT CASE expression must be integer, unsigned, logical, or character"_err_en_US |
| 275 | : "SELECT CASE expression must be integer, logical, or character"_err_en_US ); |
| 276 | } |
| 277 | } // namespace Fortran::semantics |
| 278 | |