1//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
12///
13//===----------------------------------------------------------------------===//
14
15#include "clang/ASTMatchers/ASTMatchers.h"
16#include "clang/ASTMatchers/ASTMatchFinder.h"
17#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
18#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
19#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
20
21using namespace clang;
22using namespace ento;
23using namespace clang::ast_matchers;
24
/// Largest accumulated step count for which a loop is still marked as
/// unrolled; loops whose step count exceeds this are tracked as normal loops.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
26
27struct LoopState {
28private:
29 enum Kind { Normal, Unrolled } K;
30 const Stmt *LoopStmt;
31 const LocationContext *LCtx;
32 unsigned maxStep;
33 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
34 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
35
36public:
37 static LoopState getNormal(const Stmt *S, const LocationContext *L,
38 unsigned N) {
39 return LoopState(Normal, S, L, N);
40 }
41 static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
42 unsigned N) {
43 return LoopState(Unrolled, S, L, N);
44 }
45 bool isUnrolled() const { return K == Unrolled; }
46 unsigned getMaxStep() const { return maxStep; }
47 const Stmt *getLoopStmt() const { return LoopStmt; }
48 const LocationContext *getLocationContext() const { return LCtx; }
49 bool operator==(const LoopState &X) const {
50 return K == X.K && LoopStmt == X.LoopStmt;
51 }
52 void Profile(llvm::FoldingSetNodeID &ID) const {
53 ID.AddInteger(K);
54 ID.AddPointer(LoopStmt);
55 ID.AddPointer(LCtx);
56 ID.AddInteger(maxStep);
57 }
58};
59
// The tracked stack of loops. The stack indicates which loops enclose the
// element that is currently being simulated. Each loop is marked depending on
// whether we decided to unroll it.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
67
68namespace clang {
69namespace ento {
70
71static bool isLoopStmt(const Stmt *S) {
72 return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(S);
73}
74
75ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
76 auto LS = State->get<LoopStack>();
77 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
78 State = State->set<LoopStack>(LS.getTail());
79 return State;
80}
81
82static internal::Matcher<Stmt> simpleCondition(StringRef BindName,
83 StringRef RefName) {
84 return binaryOperator(
85 anyOf(hasOperatorName("<"), hasOperatorName(">"),
86 hasOperatorName("<="), hasOperatorName(">="),
87 hasOperatorName("!=")),
88 hasEitherOperand(ignoringParenImpCasts(
89 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName)))
90 .bind(RefName))),
91 hasEitherOperand(
92 ignoringParenImpCasts(integerLiteral().bind("boundNum"))))
93 .bind("conditionOperator");
94}
95
96static internal::Matcher<Stmt>
97changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
98 return anyOf(
99 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
100 hasUnaryOperand(ignoringParenImpCasts(
101 declRefExpr(to(varDecl(VarNodeMatcher)))))),
102 binaryOperator(isAssignmentOperator(),
103 hasLHS(ignoringParenImpCasts(
104 declRefExpr(to(varDecl(VarNodeMatcher)))))));
105}
106
107static internal::Matcher<Stmt>
108callByRef(internal::Matcher<Decl> VarNodeMatcher) {
109 return callExpr(forEachArgumentWithParam(
110 declRefExpr(to(varDecl(VarNodeMatcher))),
111 parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
112}
113
114static internal::Matcher<Stmt>
115assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
116 return declStmt(hasDescendant(varDecl(
117 allOf(hasType(referenceType()),
118 hasInitializer(anyOf(
119 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
120 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
121}
122
123static internal::Matcher<Stmt>
124getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
125 return unaryOperator(
126 hasOperatorName("&"),
127 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
128}
129
130static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
131 return hasDescendant(stmt(
132 anyOf(gotoStmt(), switchStmt(), returnStmt(),
133 // Escaping and not known mutation of the loop counter is handled
134 // by exclusion of assigning and address-of operators and
135 // pass-by-ref function calls on the loop counter from the body.
136 changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
137 callByRef(equalsBoundNode(std::string(NodeName))),
138 getAddrTo(equalsBoundNode(std::string(NodeName))),
139 assignedToRef(equalsBoundNode(std::string(NodeName))))));
140}
141
142static internal::Matcher<Stmt> forLoopMatcher() {
143 return forStmt(
144 hasCondition(simpleCondition("initVarName", "initVarRef")),
145 // Initialization should match the form: 'int i = 6' or 'i = 42'.
146 hasLoopInit(
147 anyOf(declStmt(hasSingleDecl(
148 varDecl(allOf(hasInitializer(ignoringParenImpCasts(
149 integerLiteral().bind("initNum"))),
150 equalsBoundNode("initVarName"))))),
151 binaryOperator(hasLHS(declRefExpr(to(varDecl(
152 equalsBoundNode("initVarName"))))),
153 hasRHS(ignoringParenImpCasts(
154 integerLiteral().bind("initNum")))))),
155 // Incrementation should be a simple increment or decrement
156 // operator call.
157 hasIncrement(unaryOperator(
158 anyOf(hasOperatorName("++"), hasOperatorName("--")),
159 hasUnaryOperand(declRefExpr(
160 to(varDecl(allOf(equalsBoundNode("initVarName"),
161 hasType(isInteger())))))))),
162 unless(hasBody(hasSuspiciousStmt("initVarName"))))
163 .bind("forLoop");
164}
165
166static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) {
167
168 // Get the lambda CXXRecordDecl
169 assert(DR->refersToEnclosingVariableOrCapture());
170 const LocationContext *LocCtxt = N->getLocationContext();
171 const Decl *D = LocCtxt->getDecl();
172 const auto *MD = cast<CXXMethodDecl>(D);
173 assert(MD && MD->getParent()->isLambda() &&
174 "Captured variable should only be seen while evaluating a lambda");
175 const CXXRecordDecl *LambdaCXXRec = MD->getParent();
176
177 // Lookup the fields of the lambda
178 llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields;
179 FieldDecl *LambdaThisCaptureField;
180 LambdaCXXRec->getCaptureFields(LambdaCaptureFields, LambdaThisCaptureField);
181
182 // Check if the counter is captured by reference
183 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
184 assert(VD);
185 const FieldDecl *FD = LambdaCaptureFields[VD];
186 assert(FD && "Captured variable without a corresponding field");
187 return FD->getType()->isReferenceType();
188}
189
190// A loop counter is considered escaped if:
191// case 1: It is a global variable.
192// case 2: It is a reference parameter or a reference capture.
193// case 3: It is assigned to a non-const reference variable or parameter.
194// case 4: Has its address taken.
195static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
196 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
197 assert(VD);
198 // Case 1:
199 if (VD->hasGlobalStorage())
200 return true;
201
202 const bool IsRefParamOrCapture =
203 isa<ParmVarDecl>(VD) || DR->refersToEnclosingVariableOrCapture();
204 // Case 2:
205 if ((DR->refersToEnclosingVariableOrCapture() &&
206 isCapturedByReference(N, DR)) ||
207 (IsRefParamOrCapture && VD->getType()->isReferenceType()))
208 return true;
209
210 while (!N->pred_empty()) {
211 // FIXME: getStmtForDiagnostics() does nasty things in order to provide
212 // a valid statement for body farms, do we need this behavior here?
213 const Stmt *S = N->getStmtForDiagnostics();
214 if (!S) {
215 N = N->getFirstPred();
216 continue;
217 }
218
219 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
220 for (const Decl *D : DS->decls()) {
221 // Once we reach the declaration of the VD we can return.
222 if (D->getCanonicalDecl() == VD)
223 return false;
224 }
225 }
226 // Check the usage of the pass-by-ref function calls and adress-of operator
227 // on VD and reference initialized by VD.
228 ASTContext &ASTCtx =
229 N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
230 // Case 3 and 4:
231 auto Match =
232 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
233 assignedToRef(equalsNode(VD)))),
234 *S, ASTCtx);
235 if (!Match.empty())
236 return true;
237
238 N = N->getFirstPred();
239 }
240
241 // Reference parameter and reference capture will not be found.
242 if (IsRefParamOrCapture)
243 return false;
244
245 llvm_unreachable("Reached root without finding the declaration of VD");
246}
247
248bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
249 ExplodedNode *Pred, unsigned &maxStep) {
250
251 if (!isLoopStmt(LoopStmt))
252 return false;
253
254 // TODO: Match the cases where the bound is not a concrete literal but an
255 // integer with known value
256 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
257 if (Matches.empty())
258 return false;
259
260 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>("initVarRef");
261 llvm::APInt BoundNum =
262 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
263 llvm::APInt InitNum =
264 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
265 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
266 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
267 InitNum = InitNum.zext(BoundNum.getBitWidth());
268 BoundNum = BoundNum.zext(InitNum.getBitWidth());
269 }
270
271 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
272 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
273 else
274 maxStep = (BoundNum - InitNum).abs().getZExtValue();
275
276 // Check if the counter of the loop is not escaped before.
277 return !isPossiblyEscaped(Pred, CounterVarRef);
278}
279
280bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
281 const Stmt *S = nullptr;
282 while (!N->pred_empty()) {
283 if (N->succ_size() > 1)
284 return true;
285
286 ProgramPoint P = N->getLocation();
287 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
288 S = BE->getBlock()->getTerminatorStmt();
289
290 if (S == LoopStmt)
291 return false;
292
293 N = N->getFirstPred();
294 }
295
296 llvm_unreachable("Reached root without encountering the previous step");
297}
298
299// updateLoopStack is called on every basic block, therefore it needs to be fast
300ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
301 ExplodedNode *Pred, unsigned maxVisitOnPath) {
302 auto State = Pred->getState();
303 auto LCtx = Pred->getLocationContext();
304
305 if (!isLoopStmt(LoopStmt))
306 return State;
307
308 auto LS = State->get<LoopStack>();
309 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
310 LCtx == LS.getHead().getLocationContext()) {
311 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
312 State = State->set<LoopStack>(LS.getTail());
313 State = State->add<LoopStack>(
314 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
315 }
316 return State;
317 }
318 unsigned maxStep;
319 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
320 State = State->add<LoopStack>(
321 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
322 return State;
323 }
324
325 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
326
327 unsigned innerMaxStep = maxStep * outerStep;
328 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
329 State = State->add<LoopStack>(
330 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
331 else
332 State = State->add<LoopStack>(
333 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
334 return State;
335}
336
337bool isUnrolledState(ProgramStateRef State) {
338 auto LS = State->get<LoopStack>();
339 if (LS.isEmpty() || !LS.getHead().isUnrolled())
340 return false;
341 return true;
342}
343}
344}
345

// Source: clang/lib/StaticAnalyzer/Core/LoopUnrolling.cpp