1//===--- UnrollLoopsCheck.cpp - clang-tidy --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "UnrollLoopsCheck.h"
#include "clang/AST/APValue.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/ParentMapContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include <cmath>
#include <limits>
17
18using namespace clang::ast_matchers;
19
20namespace clang::tidy::altera {
21
22UnrollLoopsCheck::UnrollLoopsCheck(StringRef Name, ClangTidyContext *Context)
23 : ClangTidyCheck(Name, Context),
24 MaxLoopIterations(Options.get(LocalName: "MaxLoopIterations", Default: 100U)) {}
25
26void UnrollLoopsCheck::registerMatchers(MatchFinder *Finder) {
27 const auto HasLoopBound = hasDescendant(
28 varDecl(matchesName(RegExp: "__end*"),
29 hasDescendant(integerLiteral().bind(ID: "cxx_loop_bound"))));
30 const auto CXXForRangeLoop =
31 cxxForRangeStmt(anyOf(HasLoopBound, unless(HasLoopBound)));
32 const auto AnyLoop = anyOf(forStmt(), whileStmt(), doStmt(), CXXForRangeLoop);
33 Finder->addMatcher(
34 NodeMatch: stmt(AnyLoop, unless(hasDescendant(stmt(AnyLoop)))).bind(ID: "loop"), Action: this);
35}
36
37void UnrollLoopsCheck::check(const MatchFinder::MatchResult &Result) {
38 const auto *Loop = Result.Nodes.getNodeAs<Stmt>(ID: "loop");
39 const auto *CXXLoopBound =
40 Result.Nodes.getNodeAs<IntegerLiteral>(ID: "cxx_loop_bound");
41 const ASTContext *Context = Result.Context;
42 switch (unrollType(Statement: Loop, Context: Result.Context)) {
43 case NotUnrolled:
44 diag(Loc: Loop->getBeginLoc(),
45 Description: "kernel performance could be improved by unrolling this loop with a "
46 "'#pragma unroll' directive");
47 break;
48 case PartiallyUnrolled:
49 // Loop already partially unrolled, do nothing.
50 break;
51 case FullyUnrolled:
52 if (hasKnownBounds(Statement: Loop, CXXLoopBound, Context)) {
53 if (hasLargeNumIterations(Statement: Loop, CXXLoopBound, Context)) {
54 diag(Loc: Loop->getBeginLoc(),
55 Description: "loop likely has a large number of iterations and thus "
56 "cannot be fully unrolled; to partially unroll this loop, use "
57 "the '#pragma unroll <num>' directive");
58 return;
59 }
60 return;
61 }
62 if (isa<WhileStmt, DoStmt>(Val: Loop)) {
63 diag(Loc: Loop->getBeginLoc(),
64 Description: "full unrolling requested, but loop bounds may not be known; to "
65 "partially unroll this loop, use the '#pragma unroll <num>' "
66 "directive",
67 Level: DiagnosticIDs::Note);
68 break;
69 }
70 diag(Loc: Loop->getBeginLoc(),
71 Description: "full unrolling requested, but loop bounds are not known; to "
72 "partially unroll this loop, use the '#pragma unroll <num>' "
73 "directive");
74 break;
75 }
76}
77
78enum UnrollLoopsCheck::UnrollType
79UnrollLoopsCheck::unrollType(const Stmt *Statement, ASTContext *Context) {
80 const DynTypedNodeList Parents = Context->getParents<Stmt>(Node: *Statement);
81 for (const DynTypedNode &Parent : Parents) {
82 const auto *ParentStmt = Parent.get<AttributedStmt>();
83 if (!ParentStmt)
84 continue;
85 for (const Attr *Attribute : ParentStmt->getAttrs()) {
86 const auto *LoopHint = dyn_cast<LoopHintAttr>(Attribute);
87 if (!LoopHint)
88 continue;
89 switch (LoopHint->getState()) {
90 case LoopHintAttr::Numeric:
91 return PartiallyUnrolled;
92 case LoopHintAttr::Disable:
93 return NotUnrolled;
94 case LoopHintAttr::Full:
95 return FullyUnrolled;
96 case LoopHintAttr::Enable:
97 return FullyUnrolled;
98 case LoopHintAttr::AssumeSafety:
99 return NotUnrolled;
100 case LoopHintAttr::FixedWidth:
101 return NotUnrolled;
102 case LoopHintAttr::ScalableWidth:
103 return NotUnrolled;
104 }
105 }
106 }
107 return NotUnrolled;
108}
109
110bool UnrollLoopsCheck::hasKnownBounds(const Stmt *Statement,
111 const IntegerLiteral *CXXLoopBound,
112 const ASTContext *Context) {
113 if (isa<CXXForRangeStmt>(Val: Statement))
114 return CXXLoopBound != nullptr;
115 // Too many possibilities in a while statement, so always recommend partial
116 // unrolling for these.
117 if (isa<WhileStmt, DoStmt>(Val: Statement))
118 return false;
119 // The last loop type is a for loop.
120 const auto *ForLoop = cast<ForStmt>(Val: Statement);
121 const Stmt *Initializer = ForLoop->getInit();
122 const Expr *Conditional = ForLoop->getCond();
123 const Expr *Increment = ForLoop->getInc();
124 if (!Initializer || !Conditional || !Increment)
125 return false;
126 // If the loop variable value isn't known, loop bounds are unknown.
127 if (const auto *InitDeclStatement = dyn_cast<DeclStmt>(Val: Initializer)) {
128 if (const auto *VariableDecl =
129 dyn_cast<VarDecl>(Val: InitDeclStatement->getSingleDecl())) {
130 APValue *Evaluation = VariableDecl->evaluateValue();
131 if (!Evaluation || !Evaluation->hasValue())
132 return false;
133 }
134 }
135 // If increment is unary and not one of ++ and --, loop bounds are unknown.
136 if (const auto *Op = dyn_cast<UnaryOperator>(Val: Increment))
137 if (!Op->isIncrementDecrementOp())
138 return false;
139
140 if (const auto *BinaryOp = dyn_cast<BinaryOperator>(Val: Conditional)) {
141 const Expr *LHS = BinaryOp->getLHS();
142 const Expr *RHS = BinaryOp->getRHS();
143 // If both sides are value dependent or constant, loop bounds are unknown.
144 return LHS->isEvaluatable(Ctx: *Context) != RHS->isEvaluatable(Ctx: *Context);
145 }
146 return false; // If it's not a binary operator, loop bounds are unknown.
147}
148
149const Expr *UnrollLoopsCheck::getCondExpr(const Stmt *Statement) {
150 if (const auto *ForLoop = dyn_cast<ForStmt>(Val: Statement))
151 return ForLoop->getCond();
152 if (const auto *WhileLoop = dyn_cast<WhileStmt>(Val: Statement))
153 return WhileLoop->getCond();
154 if (const auto *DoWhileLoop = dyn_cast<DoStmt>(Val: Statement))
155 return DoWhileLoop->getCond();
156 if (const auto *CXXRangeLoop = dyn_cast<CXXForRangeStmt>(Val: Statement))
157 return CXXRangeLoop->getCond();
158 llvm_unreachable("Unknown loop");
159}
160
161bool UnrollLoopsCheck::hasLargeNumIterations(const Stmt *Statement,
162 const IntegerLiteral *CXXLoopBound,
163 const ASTContext *Context) {
164 // Because hasKnownBounds is called before this, if this is true, then
165 // CXXLoopBound is also matched.
166 if (isa<CXXForRangeStmt>(Val: Statement)) {
167 assert(CXXLoopBound && "CXX ranged for loop has no loop bound");
168 return exprHasLargeNumIterations(CXXLoopBound, Context);
169 }
170 const auto *ForLoop = cast<ForStmt>(Val: Statement);
171 const Stmt *Initializer = ForLoop->getInit();
172 const Expr *Conditional = ForLoop->getCond();
173 const Expr *Increment = ForLoop->getInc();
174 int InitValue = 0;
175 // If the loop variable value isn't known, we can't know the loop bounds.
176 if (const auto *InitDeclStatement = dyn_cast<DeclStmt>(Val: Initializer)) {
177 if (const auto *VariableDecl =
178 dyn_cast<VarDecl>(Val: InitDeclStatement->getSingleDecl())) {
179 APValue *Evaluation = VariableDecl->evaluateValue();
180 if (!Evaluation || !Evaluation->isInt())
181 return true;
182 InitValue = Evaluation->getInt().getExtValue();
183 }
184 }
185
186 int EndValue = 0;
187 const auto *BinaryOp = cast<BinaryOperator>(Val: Conditional);
188 if (!extractValue(Value&: EndValue, Op: BinaryOp, Context))
189 return true;
190
191 double Iterations = 0.0;
192
193 // If increment is unary and not one of ++, --, we can't know the loop bounds.
194 if (const auto *Op = dyn_cast<UnaryOperator>(Val: Increment)) {
195 if (Op->isIncrementOp())
196 Iterations = EndValue - InitValue;
197 else if (Op->isDecrementOp())
198 Iterations = InitValue - EndValue;
199 else
200 llvm_unreachable("Unary operator neither increment nor decrement");
201 }
202
203 // If increment is binary and not one of +, -, *, /, we can't know the loop
204 // bounds.
205 if (const auto *Op = dyn_cast<BinaryOperator>(Val: Increment)) {
206 int ConstantValue = 0;
207 if (!extractValue(Value&: ConstantValue, Op, Context))
208 return true;
209 switch (Op->getOpcode()) {
210 case (BO_AddAssign):
211 Iterations = ceil(x: float(EndValue - InitValue) / ConstantValue);
212 break;
213 case (BO_SubAssign):
214 Iterations = ceil(x: float(InitValue - EndValue) / ConstantValue);
215 break;
216 case (BO_MulAssign):
217 Iterations = 1 + (log(x: (double)EndValue) - log(x: (double)InitValue)) /
218 log(x: (double)ConstantValue);
219 break;
220 case (BO_DivAssign):
221 Iterations = 1 + (log(x: (double)InitValue) - log(x: (double)EndValue)) /
222 log(x: (double)ConstantValue);
223 break;
224 default:
225 // All other operators are not handled; assume large bounds.
226 return true;
227 }
228 }
229 return Iterations > MaxLoopIterations;
230}
231
232bool UnrollLoopsCheck::extractValue(int &Value, const BinaryOperator *Op,
233 const ASTContext *Context) {
234 const Expr *LHS = Op->getLHS();
235 const Expr *RHS = Op->getRHS();
236 Expr::EvalResult Result;
237 if (LHS->isEvaluatable(Ctx: *Context))
238 LHS->EvaluateAsRValue(Result, Ctx: *Context);
239 else if (RHS->isEvaluatable(Ctx: *Context))
240 RHS->EvaluateAsRValue(Result, Ctx: *Context);
241 else
242 return false; // Cannot evaluate either side.
243 if (!Result.Val.isInt())
244 return false; // Cannot check number of iterations, return false to be
245 // safe.
246 Value = Result.Val.getInt().getExtValue();
247 return true;
248}
249
250bool UnrollLoopsCheck::exprHasLargeNumIterations(const Expr *Expression,
251 const ASTContext *Context) const {
252 Expr::EvalResult Result;
253 if (Expression->EvaluateAsRValue(Result, Ctx: *Context)) {
254 if (!Result.Val.isInt())
255 return false; // Cannot check number of iterations, return false to be
256 // safe.
257 // The following assumes values go from 0 to Val in increments of 1.
258 return Result.Val.getInt() > MaxLoopIterations;
259 }
260 // Cannot evaluate Expression as an r-value, so cannot check number of
261 // iterations.
262 return false;
263}
264
265void UnrollLoopsCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
266 Options.store(Options&: Opts, LocalName: "MaxLoopIterations", Value: MaxLoopIterations);
267}
268
269} // namespace clang::tidy::altera
270

source code of clang-tools-extra/clang-tidy/altera/UnrollLoopsCheck.cpp