//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. These functions also manage the stack of loops that is
/// tracked by the ProgramState.
///
//===----------------------------------------------------------------------===//
14
15#include "clang/ASTMatchers/ASTMatchers.h"
16#include "clang/ASTMatchers/ASTMatchFinder.h"
17#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
18#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
19#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
20#include <optional>
21
22using namespace clang;
23using namespace ento;
24using namespace clang::ast_matchers;
25
// Largest total step count (a loop's own step count multiplied by the step
// count recorded for the enclosing tracked loop) for which a loop is still
// completely unrolled.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
27
28namespace {
29struct LoopState {
30private:
31 enum Kind { Normal, Unrolled } K;
32 const Stmt *LoopStmt;
33 const LocationContext *LCtx;
34 unsigned maxStep;
35 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
36 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
37
38public:
39 static LoopState getNormal(const Stmt *S, const LocationContext *L,
40 unsigned N) {
41 return LoopState(Normal, S, L, N);
42 }
43 static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
44 unsigned N) {
45 return LoopState(Unrolled, S, L, N);
46 }
47 bool isUnrolled() const { return K == Unrolled; }
48 unsigned getMaxStep() const { return maxStep; }
49 const Stmt *getLoopStmt() const { return LoopStmt; }
50 const LocationContext *getLocationContext() const { return LCtx; }
51 bool operator==(const LoopState &X) const {
52 return K == X.K && LoopStmt == X.LoopStmt;
53 }
54 void Profile(llvm::FoldingSetNodeID &ID) const {
55 ID.AddInteger(I: K);
56 ID.AddPointer(Ptr: LoopStmt);
57 ID.AddPointer(Ptr: LCtx);
58 ID.AddInteger(I: maxStep);
59 }
60};
61} // namespace
62
63// The tracked stack of loops. The stack indicates that which loops the
64// simulated element contained by. The loops are marked depending if we decided
65// to unroll them.
66// TODO: The loop stack should not need to be in the program state since it is
67// lexical in nature. Instead, the stack of loops should be tracked in the
68// LocationContext.
69REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
70
71namespace clang {
72namespace ento {
73
74static bool isLoopStmt(const Stmt *S) {
75 return isa_and_nonnull<ForStmt, WhileStmt, DoStmt>(Val: S);
76}
77
78ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
79 auto LS = State->get<LoopStack>();
80 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
81 State = State->set<LoopStack>(LS.getTail());
82 return State;
83}
84
85static internal::Matcher<Stmt> simpleCondition(StringRef BindName,
86 StringRef RefName) {
87 return binaryOperator(
88 anyOf(hasOperatorName(Name: "<"), hasOperatorName(Name: ">"),
89 hasOperatorName(Name: "<="), hasOperatorName(Name: ">="),
90 hasOperatorName(Name: "!=")),
91 hasEitherOperand(InnerMatcher: ignoringParenImpCasts(
92 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(hasType(InnerMatcher: isInteger())).bind(ID: BindName)))
93 .bind(ID: RefName))),
94 hasEitherOperand(
95 InnerMatcher: ignoringParenImpCasts(InnerMatcher: integerLiteral().bind(ID: "boundNum"))))
96 .bind(ID: "conditionOperator");
97}
98
99static internal::Matcher<Stmt>
100changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
101 return anyOf(
102 unaryOperator(anyOf(hasOperatorName(Name: "--"), hasOperatorName(Name: "++")),
103 hasUnaryOperand(InnerMatcher: ignoringParenImpCasts(
104 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))),
105 binaryOperator(isAssignmentOperator(),
106 hasLHS(InnerMatcher: ignoringParenImpCasts(
107 InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))));
108}
109
110static internal::Matcher<Stmt>
111callByRef(internal::Matcher<Decl> VarNodeMatcher) {
112 return callExpr(forEachArgumentWithParam(
113 ArgMatcher: declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher))),
114 ParamMatcher: parmVarDecl(hasType(InnerMatcher: references(InnerMatcher: qualType(unless(isConstQualified())))))));
115}
116
117static internal::Matcher<Stmt>
118assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
119 return declStmt(hasDescendant(varDecl(
120 allOf(hasType(InnerMatcher: referenceType()),
121 hasInitializer(InnerMatcher: anyOf(
122 initListExpr(has(declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher))))),
123 declRefExpr(to(InnerMatcher: varDecl(VarNodeMatcher)))))))));
124}
125
126static internal::Matcher<Stmt>
127getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
128 return unaryOperator(
129 hasOperatorName(Name: "&"),
130 hasUnaryOperand(InnerMatcher: declRefExpr(hasDeclaration(InnerMatcher: VarNodeMatcher))));
131}
132
/// Matches a node that has a descendant statement making the loop unsafe to
/// unroll. That is a 'goto', 'switch' or 'return', or a modification or
/// possible escape of the variable bound to \p NodeName.
static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
  auto Counter = equalsBoundNode(std::string(NodeName));
  return hasDescendant(stmt(anyOf(
      gotoStmt(), switchStmt(), returnStmt(),
      // Escaping and not known mutation of the loop counter is handled
      // by exclusion of assigning and address-of operators and
      // pass-by-ref function calls on the loop counter from the body.
      changeIntBoundNode(Counter), callByRef(Counter), getAddrTo(Counter),
      assignedToRef(Counter))));
}
144
145static internal::Matcher<Stmt> forLoopMatcher() {
146 return forStmt(
147 hasCondition(InnerMatcher: simpleCondition(BindName: "initVarName", RefName: "initVarRef")),
148 // Initialization should match the form: 'int i = 6' or 'i = 42'.
149 hasLoopInit(
150 InnerMatcher: anyOf(declStmt(hasSingleDecl(
151 InnerMatcher: varDecl(allOf(hasInitializer(InnerMatcher: ignoringParenImpCasts(
152 InnerMatcher: integerLiteral().bind(ID: "initNum"))),
153 equalsBoundNode(ID: "initVarName"))))),
154 binaryOperator(hasLHS(InnerMatcher: declRefExpr(to(InnerMatcher: varDecl(
155 equalsBoundNode(ID: "initVarName"))))),
156 hasRHS(InnerMatcher: ignoringParenImpCasts(
157 InnerMatcher: integerLiteral().bind(ID: "initNum")))))),
158 // Incrementation should be a simple increment or decrement
159 // operator call.
160 hasIncrement(InnerMatcher: unaryOperator(
161 anyOf(hasOperatorName(Name: "++"), hasOperatorName(Name: "--")),
162 hasUnaryOperand(InnerMatcher: declRefExpr(
163 to(InnerMatcher: varDecl(allOf(equalsBoundNode(ID: "initVarName"),
164 hasType(InnerMatcher: isInteger())))))))),
165 unless(hasBody(InnerMatcher: hasSuspiciousStmt(NodeName: "initVarName"))))
166 .bind(ID: "forLoop");
167}
168
169static bool isCapturedByReference(ExplodedNode *N, const DeclRefExpr *DR) {
170
171 // Get the lambda CXXRecordDecl
172 assert(DR->refersToEnclosingVariableOrCapture());
173 const LocationContext *LocCtxt = N->getLocationContext();
174 const Decl *D = LocCtxt->getDecl();
175 const auto *MD = cast<CXXMethodDecl>(Val: D);
176 assert(MD && MD->getParent()->isLambda() &&
177 "Captured variable should only be seen while evaluating a lambda");
178 const CXXRecordDecl *LambdaCXXRec = MD->getParent();
179
180 // Lookup the fields of the lambda
181 llvm::DenseMap<const ValueDecl *, FieldDecl *> LambdaCaptureFields;
182 FieldDecl *LambdaThisCaptureField;
183 LambdaCXXRec->getCaptureFields(Captures&: LambdaCaptureFields, ThisCapture&: LambdaThisCaptureField);
184
185 // Check if the counter is captured by reference
186 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
187 assert(VD);
188 const FieldDecl *FD = LambdaCaptureFields[VD];
189 assert(FD && "Captured variable without a corresponding field");
190 return FD->getType()->isReferenceType();
191}
192
193// A loop counter is considered escaped if:
194// case 1: It is a global variable.
195// case 2: It is a reference parameter or a reference capture.
196// case 3: It is assigned to a non-const reference variable or parameter.
197// case 4: Has its address taken.
198static bool isPossiblyEscaped(ExplodedNode *N, const DeclRefExpr *DR) {
199 const VarDecl *VD = cast<VarDecl>(DR->getDecl()->getCanonicalDecl());
200 assert(VD);
201 // Case 1:
202 if (VD->hasGlobalStorage())
203 return true;
204
205 const bool IsRefParamOrCapture =
206 isa<ParmVarDecl>(Val: VD) || DR->refersToEnclosingVariableOrCapture();
207 // Case 2:
208 if ((DR->refersToEnclosingVariableOrCapture() &&
209 isCapturedByReference(N, DR)) ||
210 (IsRefParamOrCapture && VD->getType()->isReferenceType()))
211 return true;
212
213 while (!N->pred_empty()) {
214 // FIXME: getStmtForDiagnostics() does nasty things in order to provide
215 // a valid statement for body farms, do we need this behavior here?
216 const Stmt *S = N->getStmtForDiagnostics();
217 if (!S) {
218 N = N->getFirstPred();
219 continue;
220 }
221
222 if (const DeclStmt *DS = dyn_cast<DeclStmt>(Val: S)) {
223 for (const Decl *D : DS->decls()) {
224 // Once we reach the declaration of the VD we can return.
225 if (D->getCanonicalDecl() == VD)
226 return false;
227 }
228 }
229 // Check the usage of the pass-by-ref function calls and adress-of operator
230 // on VD and reference initialized by VD.
231 ASTContext &ASTCtx =
232 N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
233 // Case 3 and 4:
234 auto Match =
235 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
236 assignedToRef(equalsNode(VD)))),
237 *S, ASTCtx);
238 if (!Match.empty())
239 return true;
240
241 N = N->getFirstPred();
242 }
243
244 // Reference parameter and reference capture will not be found.
245 if (IsRefParamOrCapture)
246 return false;
247
248 llvm_unreachable("Reached root without finding the declaration of VD");
249}
250
251bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
252 ExplodedNode *Pred, unsigned &maxStep) {
253
254 if (!isLoopStmt(S: LoopStmt))
255 return false;
256
257 // TODO: Match the cases where the bound is not a concrete literal but an
258 // integer with known value
259 auto Matches = match(Matcher: forLoopMatcher(), Node: *LoopStmt, Context&: ASTCtx);
260 if (Matches.empty())
261 return false;
262
263 const auto *CounterVarRef = Matches[0].getNodeAs<DeclRefExpr>(ID: "initVarRef");
264 llvm::APInt BoundNum =
265 Matches[0].getNodeAs<IntegerLiteral>(ID: "boundNum")->getValue();
266 llvm::APInt InitNum =
267 Matches[0].getNodeAs<IntegerLiteral>(ID: "initNum")->getValue();
268 auto CondOp = Matches[0].getNodeAs<BinaryOperator>(ID: "conditionOperator");
269 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
270 InitNum = InitNum.zext(width: BoundNum.getBitWidth());
271 BoundNum = BoundNum.zext(width: InitNum.getBitWidth());
272 }
273
274 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
275 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
276 else
277 maxStep = (BoundNum - InitNum).abs().getZExtValue();
278
279 // Check if the counter of the loop is not escaped before.
280 return !isPossiblyEscaped(N: Pred, DR: CounterVarRef);
281}
282
283bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
284 const Stmt *S = nullptr;
285 while (!N->pred_empty()) {
286 if (N->succ_size() > 1)
287 return true;
288
289 ProgramPoint P = N->getLocation();
290 if (std::optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
291 S = BE->getBlock()->getTerminatorStmt();
292
293 if (S == LoopStmt)
294 return false;
295
296 N = N->getFirstPred();
297 }
298
299 llvm_unreachable("Reached root without encountering the previous step");
300}
301
302// updateLoopStack is called on every basic block, therefore it needs to be fast
303ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
304 ExplodedNode *Pred, unsigned maxVisitOnPath) {
305 auto State = Pred->getState();
306 auto LCtx = Pred->getLocationContext();
307
308 if (!isLoopStmt(S: LoopStmt))
309 return State;
310
311 auto LS = State->get<LoopStack>();
312 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
313 LCtx == LS.getHead().getLocationContext()) {
314 if (LS.getHead().isUnrolled() && madeNewBranch(N: Pred, LoopStmt)) {
315 State = State->set<LoopStack>(LS.getTail());
316 State = State->add<LoopStack>(
317 K: LoopState::getNormal(S: LoopStmt, L: LCtx, N: maxVisitOnPath));
318 }
319 return State;
320 }
321 unsigned maxStep;
322 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
323 State = State->add<LoopStack>(
324 K: LoopState::getNormal(S: LoopStmt, L: LCtx, N: maxVisitOnPath));
325 return State;
326 }
327
328 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
329
330 unsigned innerMaxStep = maxStep * outerStep;
331 if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
332 State = State->add<LoopStack>(
333 K: LoopState::getNormal(S: LoopStmt, L: LCtx, N: maxVisitOnPath));
334 else
335 State = State->add<LoopStack>(
336 K: LoopState::getUnrolled(S: LoopStmt, L: LCtx, N: innerMaxStep));
337 return State;
338}
339
340bool isUnrolledState(ProgramStateRef State) {
341 auto LS = State->get<LoopStack>();
342 if (LS.isEmpty() || !LS.getHead().isUnrolled())
343 return false;
344 return true;
345}
346}
347}
348

source code of clang/lib/StaticAnalyzer/Core/LoopUnrolling.cpp