//===-- lib/Semantics/check-case.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "check-case.h"
#include "flang/Common/idioms.h"
#include "flang/Common/reference.h"
#include "flang/Common/template.h"
#include "flang/Evaluate/fold.h"
#include "flang/Evaluate/type.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/tools.h"
#include <cstddef>
#include <list>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
19
20namespace Fortran::semantics {
21
22template <typename T> class CaseValues {
23public:
24 CaseValues(SemanticsContext &c, const evaluate::DynamicType &t)
25 : context_{c}, caseExprType_{t} {}
26
27 void Check(const std::list<parser::CaseConstruct::Case> &cases) {
28 for (const parser::CaseConstruct::Case &c : cases) {
29 AddCase(c);
30 }
31 if (!hasErrors_) {
32 cases_.sort(Comparator{});
33 if (!AreCasesDisjoint()) { // C1149
34 ReportConflictingCases();
35 }
36 }
37 }
38
39private:
40 using Value = evaluate::Scalar<T>;
41
42 void AddCase(const parser::CaseConstruct::Case &c) {
43 const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)};
44 const parser::CaseStmt &caseStmt{stmt.statement};
45 const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)};
46 common::visit(
47 common::visitors{
48 [&](const std::list<parser::CaseValueRange> &ranges) {
49 for (const auto &range : ranges) {
50 auto pair{ComputeBounds(range)};
51 if (pair.first && pair.second && *pair.first > *pair.second) {
52 context_.Say(stmt.source,
53 "CASE has lower bound greater than upper bound"_warn_en_US);
54 } else {
55 if constexpr (T::category == TypeCategory::Logical) { // C1148
56 if ((pair.first || pair.second) &&
57 (!pair.first || !pair.second ||
58 *pair.first != *pair.second)) {
59 context_.Say(stmt.source,
60 "CASE range is not allowed for LOGICAL"_err_en_US);
61 }
62 }
63 cases_.emplace_back(stmt);
64 cases_.back().lower = std::move(pair.first);
65 cases_.back().upper = std::move(pair.second);
66 }
67 }
68 },
69 [&](const parser::Default &) { cases_.emplace_front(stmt); },
70 },
71 selector.u);
72 }
73
74 std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
75 const parser::Expr &expr{caseValue.thing.thing.value()};
76 auto *x{expr.typedExpr.get()};
77 if (x && x->v) { // C1147
78 auto type{x->v->GetType()};
79 if (type && type->category() == caseExprType_.category() &&
80 (type->category() != TypeCategory::Character ||
81 type->kind() == caseExprType_.kind())) {
82 parser::Messages buffer; // discarded folding messages
83 parser::ContextualMessages foldingMessages{expr.source, &buffer};
84 evaluate::FoldingContext foldingContext{
85 context_.foldingContext(), foldingMessages};
86 auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})};
87 if (auto converted{evaluate::Fold(foldingContext,
88 evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) {
89 if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) {
90 auto back{evaluate::Fold(foldingContext,
91 evaluate::ConvertToType(*type, SomeExpr{*converted}))};
92 if (back == folded) {
93 x->v = converted;
94 return value;
95 } else {
96 context_.Say(expr.source,
97 "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US,
98 folded.AsFortran(), caseExprType_.AsFortran());
99 hasErrors_ = true;
100 return std::nullopt;
101 }
102 }
103 }
104 context_.Say(expr.source,
105 "CASE value (%s) must be a constant scalar"_err_en_US,
106 x->v->AsFortran());
107 } else {
108 std::string typeStr{type ? type->AsFortran() : "typeless"s};
109 context_.Say(expr.source,
110 "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US,
111 typeStr, caseExprType_.AsFortran());
112 }
113 hasErrors_ = true;
114 }
115 return std::nullopt;
116 }
117
118 using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>;
119 PairOfValues ComputeBounds(const parser::CaseValueRange &range) {
120 return common::visit(
121 common::visitors{
122 [&](const parser::CaseValue &x) {
123 auto value{GetValue(x)};
124 return PairOfValues{value, value};
125 },
126 [&](const parser::CaseValueRange::Range &x) {
127 std::optional<Value> lo, hi;
128 if (x.lower) {
129 lo = GetValue(*x.lower);
130 }
131 if (x.upper) {
132 hi = GetValue(*x.upper);
133 }
134 if ((x.lower && !lo) || (x.upper && !hi)) {
135 return PairOfValues{}; // error case
136 }
137 return PairOfValues{std::move(lo), std::move(hi)};
138 },
139 },
140 range.u);
141 }
142
143 struct Case {
144 explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {}
145 bool IsDefault() const { return !lower && !upper; }
146 std::string AsFortran() const {
147 std::string result;
148 {
149 llvm::raw_string_ostream bs{result};
150 if (lower) {
151 evaluate::Constant<T>{*lower}.AsFortran(bs << '(');
152 if (!upper) {
153 bs << ':';
154 } else if (*lower != *upper) {
155 evaluate::Constant<T>{*upper}.AsFortran(bs << ':');
156 }
157 bs << ')';
158 } else if (upper) {
159 evaluate::Constant<T>{*upper}.AsFortran(bs << "(:") << ')';
160 } else {
161 bs << "DEFAULT";
162 }
163 }
164 return result;
165 }
166
167 const parser::Statement<parser::CaseStmt> &stmt;
168 std::optional<Value> lower, upper;
169 };
170
171 // Defines a comparator for use with std::list<>::sort().
172 // Returns true if and only if the highest value in range x is less
173 // than the least value in range y. The DEFAULT case is arbitrarily
174 // defined to be less than all others. When two ranges overlap,
175 // neither is less than the other.
176 struct Comparator {
177 bool operator()(const Case &x, const Case &y) const {
178 if (x.IsDefault()) {
179 return !y.IsDefault();
180 } else {
181 return x.upper && y.lower && *x.upper < *y.lower;
182 }
183 }
184 };
185
186 bool AreCasesDisjoint() const {
187 auto endIter{cases_.end()};
188 for (auto iter{cases_.begin()}; iter != endIter; ++iter) {
189 auto next{iter};
190 if (++next != endIter && !Comparator{}(*iter, *next)) {
191 return false;
192 }
193 }
194 return true;
195 }
196
197 // This has quadratic time, but only runs in error cases
198 void ReportConflictingCases() {
199 for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) {
200 parser::Message *msg{nullptr};
201 for (auto p{cases_.begin()}; p != cases_.end(); ++p) {
202 if (p->stmt.source.begin() < iter->stmt.source.begin() &&
203 !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) {
204 if (!msg) {
205 msg = &context_.Say(iter->stmt.source,
206 "CASE %s conflicts with previous cases"_err_en_US,
207 iter->AsFortran());
208 }
209 msg->Attach(
210 p->stmt.source, "Conflicting CASE %s"_en_US, p->AsFortran());
211 }
212 }
213 }
214 }
215
216 SemanticsContext &context_;
217 const evaluate::DynamicType &caseExprType_;
218 std::list<Case> cases_;
219 bool hasErrors_{false};
220};
221
222template <TypeCategory CAT> struct TypeVisitor {
223 using Result = bool;
224 using Types = evaluate::CategoryTypes<CAT>;
225 template <typename T> Result Test() {
226 if (T::kind == exprType.kind()) {
227 CaseValues<T>(context, exprType).Check(caseList);
228 return true;
229 } else {
230 return false;
231 }
232 }
233 SemanticsContext &context;
234 const evaluate::DynamicType &exprType;
235 const std::list<parser::CaseConstruct::Case> &caseList;
236};
237
238void CaseChecker::Enter(const parser::CaseConstruct &construct) {
239 const auto &selectCaseStmt{
240 std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)};
241 const auto &selectCase{selectCaseStmt.statement};
242 const auto &selectExpr{
243 std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing};
244 const auto *x{GetExpr(context_, selectExpr)};
245 if (!x) {
246 return; // expression semantics failed
247 }
248 if (auto exprType{x->GetType()}) {
249 const auto &caseList{
250 std::get<std::list<parser::CaseConstruct::Case>>(construct.t)};
251 switch (exprType->category()) {
252 case TypeCategory::Integer:
253 common::SearchTypes(
254 TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList});
255 return;
256 case TypeCategory::Logical:
257 CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType}
258 .Check(caseList);
259 return;
260 case TypeCategory::Character:
261 common::SearchTypes(
262 TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList});
263 return;
264 default:
265 break;
266 }
267 }
268 context_.Say(selectExpr.source,
269 "SELECT CASE expression must be integer, logical, or character"_err_en_US);
270}
271} // namespace Fortran::semantics
272

// source code of flang/lib/Semantics/check-case.cpp