1 | //===-- lib/Semantics/check-case.cpp --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "check-case.h" |
10 | #include "flang/Common/idioms.h" |
11 | #include "flang/Common/reference.h" |
12 | #include "flang/Common/template.h" |
13 | #include "flang/Evaluate/fold.h" |
14 | #include "flang/Evaluate/type.h" |
15 | #include "flang/Parser/parse-tree.h" |
16 | #include "flang/Semantics/semantics.h" |
17 | #include "flang/Semantics/tools.h" |
18 | #include <tuple> |
19 | |
20 | namespace Fortran::semantics { |
21 | |
22 | template <typename T> class CaseValues { |
23 | public: |
24 | CaseValues(SemanticsContext &c, const evaluate::DynamicType &t) |
25 | : context_{c}, caseExprType_{t} {} |
26 | |
27 | void Check(const std::list<parser::CaseConstruct::Case> &cases) { |
28 | for (const parser::CaseConstruct::Case &c : cases) { |
29 | AddCase(c); |
30 | } |
31 | if (!hasErrors_) { |
32 | cases_.sort(Comparator{}); |
33 | if (!AreCasesDisjoint()) { // C1149 |
34 | ReportConflictingCases(); |
35 | } |
36 | } |
37 | } |
38 | |
39 | private: |
40 | using Value = evaluate::Scalar<T>; |
41 | |
42 | void AddCase(const parser::CaseConstruct::Case &c) { |
43 | const auto &stmt{std::get<parser::Statement<parser::CaseStmt>>(c.t)}; |
44 | const parser::CaseStmt &caseStmt{stmt.statement}; |
45 | const auto &selector{std::get<parser::CaseSelector>(caseStmt.t)}; |
46 | common::visit( |
47 | common::visitors{ |
48 | [&](const std::list<parser::CaseValueRange> &ranges) { |
49 | for (const auto &range : ranges) { |
50 | auto pair{ComputeBounds(range)}; |
51 | if (pair.first && pair.second && *pair.first > *pair.second) { |
52 | context_.Say(stmt.source, |
53 | "CASE has lower bound greater than upper bound"_warn_en_US ); |
54 | } else { |
55 | if constexpr (T::category == TypeCategory::Logical) { // C1148 |
56 | if ((pair.first || pair.second) && |
57 | (!pair.first || !pair.second || |
58 | *pair.first != *pair.second)) { |
59 | context_.Say(stmt.source, |
60 | "CASE range is not allowed for LOGICAL"_err_en_US ); |
61 | } |
62 | } |
63 | cases_.emplace_back(stmt); |
64 | cases_.back().lower = std::move(pair.first); |
65 | cases_.back().upper = std::move(pair.second); |
66 | } |
67 | } |
68 | }, |
69 | [&](const parser::Default &) { cases_.emplace_front(stmt); }, |
70 | }, |
71 | selector.u); |
72 | } |
73 | |
74 | std::optional<Value> GetValue(const parser::CaseValue &caseValue) { |
75 | const parser::Expr &expr{caseValue.thing.thing.value()}; |
76 | auto *x{expr.typedExpr.get()}; |
77 | if (x && x->v) { // C1147 |
78 | auto type{x->v->GetType()}; |
79 | if (type && type->category() == caseExprType_.category() && |
80 | (type->category() != TypeCategory::Character || |
81 | type->kind() == caseExprType_.kind())) { |
82 | parser::Messages buffer; // discarded folding messages |
83 | parser::ContextualMessages foldingMessages{expr.source, &buffer}; |
84 | evaluate::FoldingContext foldingContext{ |
85 | context_.foldingContext(), foldingMessages}; |
86 | auto folded{evaluate::Fold(foldingContext, SomeExpr{*x->v})}; |
87 | if (auto converted{evaluate::Fold(foldingContext, |
88 | evaluate::ConvertToType(T::GetType(), SomeExpr{folded}))}) { |
89 | if (auto value{evaluate::GetScalarConstantValue<T>(*converted)}) { |
90 | auto back{evaluate::Fold(foldingContext, |
91 | evaluate::ConvertToType(*type, SomeExpr{*converted}))}; |
92 | if (back == folded) { |
93 | x->v = converted; |
94 | return value; |
95 | } else { |
96 | context_.Say(expr.source, |
97 | "CASE value (%s) overflows type (%s) of SELECT CASE expression"_warn_en_US , |
98 | folded.AsFortran(), caseExprType_.AsFortran()); |
99 | hasErrors_ = true; |
100 | return std::nullopt; |
101 | } |
102 | } |
103 | } |
104 | context_.Say(expr.source, |
105 | "CASE value (%s) must be a constant scalar"_err_en_US , |
106 | x->v->AsFortran()); |
107 | } else { |
108 | std::string typeStr{type ? type->AsFortran() : "typeless"s }; |
109 | context_.Say(expr.source, |
110 | "CASE value has type '%s' which is not compatible with the SELECT CASE expression's type '%s'"_err_en_US , |
111 | typeStr, caseExprType_.AsFortran()); |
112 | } |
113 | hasErrors_ = true; |
114 | } |
115 | return std::nullopt; |
116 | } |
117 | |
118 | using PairOfValues = std::pair<std::optional<Value>, std::optional<Value>>; |
119 | PairOfValues ComputeBounds(const parser::CaseValueRange &range) { |
120 | return common::visit( |
121 | common::visitors{ |
122 | [&](const parser::CaseValue &x) { |
123 | auto value{GetValue(x)}; |
124 | return PairOfValues{value, value}; |
125 | }, |
126 | [&](const parser::CaseValueRange::Range &x) { |
127 | std::optional<Value> lo, hi; |
128 | if (x.lower) { |
129 | lo = GetValue(*x.lower); |
130 | } |
131 | if (x.upper) { |
132 | hi = GetValue(*x.upper); |
133 | } |
134 | if ((x.lower && !lo) || (x.upper && !hi)) { |
135 | return PairOfValues{}; // error case |
136 | } |
137 | return PairOfValues{std::move(lo), std::move(hi)}; |
138 | }, |
139 | }, |
140 | range.u); |
141 | } |
142 | |
143 | struct Case { |
144 | explicit Case(const parser::Statement<parser::CaseStmt> &s) : stmt{s} {} |
145 | bool IsDefault() const { return !lower && !upper; } |
146 | std::string AsFortran() const { |
147 | std::string result; |
148 | { |
149 | llvm::raw_string_ostream bs{result}; |
150 | if (lower) { |
151 | evaluate::Constant<T>{*lower}.AsFortran(bs << '('); |
152 | if (!upper) { |
153 | bs << ':'; |
154 | } else if (*lower != *upper) { |
155 | evaluate::Constant<T>{*upper}.AsFortran(bs << ':'); |
156 | } |
157 | bs << ')'; |
158 | } else if (upper) { |
159 | evaluate::Constant<T>{*upper}.AsFortran(bs << "(:" ) << ')'; |
160 | } else { |
161 | bs << "DEFAULT" ; |
162 | } |
163 | } |
164 | return result; |
165 | } |
166 | |
167 | const parser::Statement<parser::CaseStmt> &stmt; |
168 | std::optional<Value> lower, upper; |
169 | }; |
170 | |
171 | // Defines a comparator for use with std::list<>::sort(). |
172 | // Returns true if and only if the highest value in range x is less |
173 | // than the least value in range y. The DEFAULT case is arbitrarily |
174 | // defined to be less than all others. When two ranges overlap, |
175 | // neither is less than the other. |
176 | struct Comparator { |
177 | bool operator()(const Case &x, const Case &y) const { |
178 | if (x.IsDefault()) { |
179 | return !y.IsDefault(); |
180 | } else { |
181 | return x.upper && y.lower && *x.upper < *y.lower; |
182 | } |
183 | } |
184 | }; |
185 | |
186 | bool AreCasesDisjoint() const { |
187 | auto endIter{cases_.end()}; |
188 | for (auto iter{cases_.begin()}; iter != endIter; ++iter) { |
189 | auto next{iter}; |
190 | if (++next != endIter && !Comparator{}(*iter, *next)) { |
191 | return false; |
192 | } |
193 | } |
194 | return true; |
195 | } |
196 | |
197 | // This has quadratic time, but only runs in error cases |
198 | void ReportConflictingCases() { |
199 | for (auto iter{cases_.begin()}; iter != cases_.end(); ++iter) { |
200 | parser::Message *msg{nullptr}; |
201 | for (auto p{cases_.begin()}; p != cases_.end(); ++p) { |
202 | if (p->stmt.source.begin() < iter->stmt.source.begin() && |
203 | !Comparator{}(*p, *iter) && !Comparator{}(*iter, *p)) { |
204 | if (!msg) { |
205 | msg = &context_.Say(iter->stmt.source, |
206 | "CASE %s conflicts with previous cases"_err_en_US , |
207 | iter->AsFortran()); |
208 | } |
209 | msg->Attach( |
210 | p->stmt.source, "Conflicting CASE %s"_en_US , p->AsFortran()); |
211 | } |
212 | } |
213 | } |
214 | } |
215 | |
216 | SemanticsContext &context_; |
217 | const evaluate::DynamicType &caseExprType_; |
218 | std::list<Case> cases_; |
219 | bool hasErrors_{false}; |
220 | }; |
221 | |
222 | template <TypeCategory CAT> struct TypeVisitor { |
223 | using Result = bool; |
224 | using Types = evaluate::CategoryTypes<CAT>; |
225 | template <typename T> Result Test() { |
226 | if (T::kind == exprType.kind()) { |
227 | CaseValues<T>(context, exprType).Check(caseList); |
228 | return true; |
229 | } else { |
230 | return false; |
231 | } |
232 | } |
233 | SemanticsContext &context; |
234 | const evaluate::DynamicType &exprType; |
235 | const std::list<parser::CaseConstruct::Case> &caseList; |
236 | }; |
237 | |
238 | void CaseChecker::Enter(const parser::CaseConstruct &construct) { |
239 | const auto &selectCaseStmt{ |
240 | std::get<parser::Statement<parser::SelectCaseStmt>>(construct.t)}; |
241 | const auto &selectCase{selectCaseStmt.statement}; |
242 | const auto &selectExpr{ |
243 | std::get<parser::Scalar<parser::Expr>>(selectCase.t).thing}; |
244 | const auto *x{GetExpr(context_, selectExpr)}; |
245 | if (!x) { |
246 | return; // expression semantics failed |
247 | } |
248 | if (auto exprType{x->GetType()}) { |
249 | const auto &caseList{ |
250 | std::get<std::list<parser::CaseConstruct::Case>>(construct.t)}; |
251 | switch (exprType->category()) { |
252 | case TypeCategory::Integer: |
253 | common::SearchTypes( |
254 | TypeVisitor<TypeCategory::Integer>{context_, *exprType, caseList}); |
255 | return; |
256 | case TypeCategory::Logical: |
257 | CaseValues<evaluate::Type<TypeCategory::Logical, 1>>{context_, *exprType} |
258 | .Check(caseList); |
259 | return; |
260 | case TypeCategory::Character: |
261 | common::SearchTypes( |
262 | TypeVisitor<TypeCategory::Character>{context_, *exprType, caseList}); |
263 | return; |
264 | default: |
265 | break; |
266 | } |
267 | } |
268 | context_.Say(selectExpr.source, |
269 | "SELECT CASE expression must be integer, logical, or character"_err_en_US ); |
270 | } |
271 | } // namespace Fortran::semantics |
272 | |