1 | //===- IslAst.cpp - isl code generator interface --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // The isl code generator interface takes a Scop and generates an isl_ast. This |
10 | // isl_ast can either be returned directly or it can be pretty printed to |
11 | // stdout. |
12 | // |
13 | // A typical isl_ast output looks like this: |
14 | // |
15 | // for (c2 = max(0, ceild(n + m, 2)); c2 <= min(511, floord(5 * n, 3)); c2++) { |
16 | // bb2(c2); |
17 | // } |
18 | // |
19 | // An in-depth discussion of our AST generation approach can be found in: |
20 | // |
21 | // Polyhedral AST generation is more than scanning polyhedra |
22 | // Tobias Grosser, Sven Verdoolaege, Albert Cohen |
23 | // ACM Transactions on Programming Languages and Systems (TOPLAS), |
24 | // 37(4), July 2015 |
25 | // http://www.grosser.es/#pub-polyhedral-AST-generation |
26 | // |
27 | //===----------------------------------------------------------------------===// |
28 | |
29 | #include "polly/CodeGen/IslAst.h" |
30 | #include "polly/CodeGen/CodeGeneration.h" |
31 | #include "polly/DependenceInfo.h" |
32 | #include "polly/LinkAllPasses.h" |
33 | #include "polly/Options.h" |
34 | #include "polly/ScopDetection.h" |
35 | #include "polly/ScopInfo.h" |
36 | #include "polly/ScopPass.h" |
37 | #include "polly/Support/GICHelper.h" |
38 | #include "llvm/ADT/Statistic.h" |
39 | #include "llvm/IR/Function.h" |
40 | #include "llvm/Support/Debug.h" |
41 | #include "llvm/Support/raw_ostream.h" |
42 | #include "isl/aff.h" |
43 | #include "isl/ast.h" |
44 | #include "isl/ast_build.h" |
45 | #include "isl/id.h" |
46 | #include "isl/isl-noexceptions.h" |
47 | #include "isl/printer.h" |
48 | #include "isl/schedule.h" |
49 | #include "isl/set.h" |
50 | #include "isl/union_map.h" |
51 | #include "isl/val.h" |
52 | #include <cassert> |
53 | #include <cstdlib> |
54 | |
55 | #include "polly/Support/PollyDebug.h" |
56 | #define DEBUG_TYPE "polly-ast" |
57 | |
58 | using namespace llvm; |
59 | using namespace polly; |
60 | |
61 | using IslAstUserPayload = IslAstInfo::IslAstUserPayload; |
62 | |
63 | static cl::opt<bool> |
64 | PollyParallel("polly-parallel" , |
65 | cl::desc("Generate thread parallel code (isl codegen only)" ), |
66 | cl::cat(PollyCategory)); |
67 | |
68 | static cl::opt<bool> PrintAccesses("polly-ast-print-accesses" , |
69 | cl::desc("Print memory access functions" ), |
70 | cl::cat(PollyCategory)); |
71 | |
72 | static cl::opt<bool> PollyParallelForce( |
73 | "polly-parallel-force" , |
74 | cl::desc( |
75 | "Force generation of thread parallel code ignoring any cost model" ), |
76 | cl::cat(PollyCategory)); |
77 | |
78 | static cl::opt<bool> UseContext("polly-ast-use-context" , |
79 | cl::desc("Use context" ), cl::Hidden, |
80 | cl::init(Val: true), cl::cat(PollyCategory)); |
81 | |
82 | static cl::opt<bool> DetectParallel("polly-ast-detect-parallel" , |
83 | cl::desc("Detect parallelism" ), cl::Hidden, |
84 | cl::cat(PollyCategory)); |
85 | |
86 | STATISTIC(ScopsProcessed, "Number of SCoPs processed" ); |
87 | STATISTIC(ScopsBeneficial, "Number of beneficial SCoPs" ); |
88 | STATISTIC(BeneficialAffineLoops, "Number of beneficial affine loops" ); |
89 | STATISTIC(BeneficialBoxedLoops, "Number of beneficial boxed loops" ); |
90 | |
91 | STATISTIC(NumForLoops, "Number of for-loops" ); |
92 | STATISTIC(NumParallel, "Number of parallel for-loops" ); |
93 | STATISTIC(NumInnermostParallel, "Number of innermost parallel for-loops" ); |
94 | STATISTIC(NumOutermostParallel, "Number of outermost parallel for-loops" ); |
95 | STATISTIC(NumReductionParallel, "Number of reduction-parallel for-loops" ); |
96 | STATISTIC(NumExecutedInParallel, "Number of for-loops executed in parallel" ); |
97 | STATISTIC(NumIfConditions, "Number of if-conditions" ); |
98 | |
99 | namespace polly { |
100 | |
101 | /// Temporary information used when building the ast. |
102 | struct AstBuildUserInfo { |
103 | /// Construct and initialize the helper struct for AST creation. |
104 | AstBuildUserInfo() = default; |
105 | |
106 | /// The dependence information used for the parallelism check. |
107 | const Dependences *Deps = nullptr; |
108 | |
109 | /// Flag to indicate that we are inside a parallel for node. |
110 | bool InParallelFor = false; |
111 | |
112 | /// Flag to indicate that we are inside an SIMD node. |
113 | bool InSIMD = false; |
114 | |
115 | /// The last iterator id created for the current SCoP. |
116 | isl_id *LastForNodeId = nullptr; |
117 | }; |
118 | } // namespace polly |
119 | |
120 | /// Free an IslAstUserPayload object pointed to by @p Ptr. |
121 | static void freeIslAstUserPayload(void *Ptr) { |
122 | delete ((IslAstInfo::IslAstUserPayload *)Ptr); |
123 | } |
124 | |
125 | /// Print a string @p str in a single line using @p Printer. |
126 | static isl_printer *printLine(__isl_take isl_printer *Printer, |
127 | const std::string &str, |
128 | __isl_keep isl_pw_aff *PWA = nullptr) { |
129 | Printer = isl_printer_start_line(p: Printer); |
130 | Printer = isl_printer_print_str(p: Printer, s: str.c_str()); |
131 | if (PWA) |
132 | Printer = isl_printer_print_pw_aff(p: Printer, pwaff: PWA); |
133 | return isl_printer_end_line(p: Printer); |
134 | } |
135 | |
136 | /// Return all broken reductions as a string of clauses (OpenMP style). |
137 | static const std::string getBrokenReductionsStr(const isl::ast_node &Node) { |
138 | IslAstInfo::MemoryAccessSet *BrokenReductions; |
139 | std::string str; |
140 | |
141 | BrokenReductions = IslAstInfo::getBrokenReductions(Node); |
142 | if (!BrokenReductions || BrokenReductions->empty()) |
143 | return "" ; |
144 | |
145 | // Map each type of reduction to a comma separated list of the base addresses. |
146 | std::map<MemoryAccess::ReductionType, std::string> Clauses; |
147 | for (MemoryAccess *MA : *BrokenReductions) |
148 | if (MA->isWrite()) |
149 | Clauses[MA->getReductionType()] += |
150 | ", " + MA->getScopArrayInfo()->getName(); |
151 | |
152 | // Now print the reductions sorted by type. Each type will cause a clause |
153 | // like: reduction (+ : sum0, sum1, sum2) |
154 | for (const auto &ReductionClause : Clauses) { |
155 | str += " reduction (" ; |
156 | str += MemoryAccess::getReductionOperatorStr(RT: ReductionClause.first); |
157 | // Remove the first two symbols (", ") to make the output look pretty. |
158 | str += " : " + ReductionClause.second.substr(pos: 2) + ")" ; |
159 | } |
160 | |
161 | return str; |
162 | } |
163 | |
164 | /// Callback executed for each for node in the ast in order to print it. |
165 | static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, |
166 | __isl_take isl_ast_print_options *Options, |
167 | __isl_keep isl_ast_node *Node, void *) { |
168 | isl::pw_aff DD = |
169 | IslAstInfo::getMinimalDependenceDistance(Node: isl::manage_copy(ptr: Node)); |
170 | const std::string BrokenReductionsStr = |
171 | getBrokenReductionsStr(Node: isl::manage_copy(ptr: Node)); |
172 | const std::string KnownParallelStr = "#pragma known-parallel" ; |
173 | const std::string DepDisPragmaStr = "#pragma minimal dependence distance: " ; |
174 | const std::string SimdPragmaStr = "#pragma simd" ; |
175 | const std::string OmpPragmaStr = "#pragma omp parallel for" ; |
176 | |
177 | if (!DD.is_null()) |
178 | Printer = printLine(Printer, str: DepDisPragmaStr, PWA: DD.get()); |
179 | |
180 | if (IslAstInfo::isInnermostParallel(Node: isl::manage_copy(ptr: Node))) |
181 | Printer = printLine(Printer, str: SimdPragmaStr + BrokenReductionsStr); |
182 | |
183 | if (IslAstInfo::isExecutedInParallel(Node: isl::manage_copy(ptr: Node))) |
184 | Printer = printLine(Printer, str: OmpPragmaStr); |
185 | else if (IslAstInfo::isOutermostParallel(Node: isl::manage_copy(ptr: Node))) |
186 | Printer = printLine(Printer, str: KnownParallelStr + BrokenReductionsStr); |
187 | |
188 | return isl_ast_node_for_print(node: Node, p: Printer, options: Options); |
189 | } |
190 | |
191 | /// Check if the current scheduling dimension is parallel. |
192 | /// |
193 | /// In case the dimension is parallel we also check if any reduction |
194 | /// dependences is broken when we exploit this parallelism. If so, |
195 | /// @p IsReductionParallel will be set to true. The reduction dependences we use |
196 | /// to check are actually the union of the transitive closure of the initial |
197 | /// reduction dependences together with their reversal. Even though these |
198 | /// dependences connect all iterations with each other (thus they are cyclic) |
199 | /// we can perform the parallelism check as we are only interested in a zero |
200 | /// (or non-zero) dependence distance on the dimension in question. |
201 | static bool astScheduleDimIsParallel(const isl::ast_build &Build, |
202 | const Dependences *D, |
203 | IslAstUserPayload *NodeInfo) { |
204 | if (!D->hasValidDependences()) |
205 | return false; |
206 | |
207 | isl::union_map Schedule = Build.get_schedule(); |
208 | isl::union_map Dep = D->getDependences( |
209 | Kinds: Dependences::TYPE_RAW | Dependences::TYPE_WAW | Dependences::TYPE_WAR); |
210 | |
211 | if (!D->isParallel(Schedule: Schedule.get(), Deps: Dep.release())) { |
212 | isl::union_map DepsAll = |
213 | D->getDependences(Kinds: Dependences::TYPE_RAW | Dependences::TYPE_WAW | |
214 | Dependences::TYPE_WAR | Dependences::TYPE_TC_RED); |
215 | // TODO: We will need to change isParallel to stop the unwrapping |
216 | isl_pw_aff *MinimalDependenceDistanceIsl = nullptr; |
217 | D->isParallel(Schedule: Schedule.get(), Deps: DepsAll.release(), |
218 | MinDistancePtr: &MinimalDependenceDistanceIsl); |
219 | NodeInfo->MinimalDependenceDistance = |
220 | isl::manage(ptr: MinimalDependenceDistanceIsl); |
221 | return false; |
222 | } |
223 | |
224 | isl::union_map RedDeps = D->getDependences(Kinds: Dependences::TYPE_TC_RED); |
225 | if (!D->isParallel(Schedule: Schedule.get(), Deps: RedDeps.release())) |
226 | NodeInfo->IsReductionParallel = true; |
227 | |
228 | if (!NodeInfo->IsReductionParallel) |
229 | return true; |
230 | |
231 | for (const auto &MaRedPair : D->getReductionDependences()) { |
232 | if (!MaRedPair.second) |
233 | continue; |
234 | isl::union_map MaRedDeps = isl::manage_copy(ptr: MaRedPair.second); |
235 | if (!D->isParallel(Schedule: Schedule.get(), Deps: MaRedDeps.release())) |
236 | NodeInfo->BrokenReductions.insert(Ptr: MaRedPair.first); |
237 | } |
238 | return true; |
239 | } |
240 | |
241 | // This method is executed before the construction of a for node. It creates |
242 | // an isl_id that is used to annotate the subsequently generated ast for nodes. |
243 | // |
244 | // In this function we also run the following analyses: |
245 | // |
246 | // - Detection of openmp parallel loops |
247 | // |
248 | static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, |
249 | void *User) { |
250 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
251 | IslAstUserPayload *Payload = new IslAstUserPayload(); |
252 | isl_id *Id = isl_id_alloc(ctx: isl_ast_build_get_ctx(build: Build), name: "" , user: Payload); |
253 | Id = isl_id_set_free_user(id: Id, free_user: freeIslAstUserPayload); |
254 | BuildInfo->LastForNodeId = Id; |
255 | |
256 | Payload->IsParallel = astScheduleDimIsParallel(Build: isl::manage_copy(ptr: Build), |
257 | D: BuildInfo->Deps, NodeInfo: Payload); |
258 | |
259 | // Test for parallelism only if we are not already inside a parallel loop |
260 | if (!BuildInfo->InParallelFor && !BuildInfo->InSIMD) |
261 | BuildInfo->InParallelFor = Payload->IsOutermostParallel = |
262 | Payload->IsParallel; |
263 | |
264 | return Id; |
265 | } |
266 | |
267 | // This method is executed after the construction of a for node. |
268 | // |
269 | // It performs the following actions: |
270 | // |
271 | // - Reset the 'InParallelFor' flag, as soon as we leave a for node, |
272 | // that is marked as openmp parallel. |
273 | // |
274 | static __isl_give isl_ast_node * |
275 | astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, |
276 | void *User) { |
277 | isl_id *Id = isl_ast_node_get_annotation(node: Node); |
278 | assert(Id && "Post order visit assumes annotated for nodes" ); |
279 | IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(id: Id); |
280 | assert(Payload && "Post order visit assumes annotated for nodes" ); |
281 | |
282 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
283 | assert(Payload->Build.is_null() && "Build environment already set" ); |
284 | Payload->Build = isl::manage_copy(ptr: Build); |
285 | Payload->IsInnermost = (Id == BuildInfo->LastForNodeId); |
286 | |
287 | Payload->IsInnermostParallel = |
288 | Payload->IsInnermost && (BuildInfo->InSIMD || Payload->IsParallel); |
289 | if (Payload->IsOutermostParallel) |
290 | BuildInfo->InParallelFor = false; |
291 | |
292 | isl_id_free(id: Id); |
293 | return Node; |
294 | } |
295 | |
296 | static isl_stat (__isl_keep isl_id *MarkId, |
297 | __isl_keep isl_ast_build *Build, |
298 | void *User) { |
299 | if (!MarkId) |
300 | return isl_stat_error; |
301 | |
302 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
303 | if (strcmp(s1: isl_id_get_name(id: MarkId), s2: "SIMD" ) == 0) |
304 | BuildInfo->InSIMD = true; |
305 | |
306 | return isl_stat_ok; |
307 | } |
308 | |
309 | static __isl_give isl_ast_node * |
310 | astBuildAfterMark(__isl_take isl_ast_node *Node, |
311 | __isl_keep isl_ast_build *Build, void *User) { |
312 | assert(isl_ast_node_get_type(Node) == isl_ast_node_mark); |
313 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
314 | auto *Id = isl_ast_node_mark_get_id(node: Node); |
315 | if (strcmp(s1: isl_id_get_name(id: Id), s2: "SIMD" ) == 0) |
316 | BuildInfo->InSIMD = false; |
317 | isl_id_free(id: Id); |
318 | return Node; |
319 | } |
320 | |
321 | static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, |
322 | __isl_keep isl_ast_build *Build, |
323 | void *User) { |
324 | assert(!isl_ast_node_get_annotation(Node) && "Node already annotated" ); |
325 | |
326 | IslAstUserPayload *Payload = new IslAstUserPayload(); |
327 | isl_id *Id = isl_id_alloc(ctx: isl_ast_build_get_ctx(build: Build), name: "" , user: Payload); |
328 | Id = isl_id_set_free_user(id: Id, free_user: freeIslAstUserPayload); |
329 | |
330 | Payload->Build = isl::manage_copy(ptr: Build); |
331 | |
332 | return isl_ast_node_set_annotation(node: Node, annotation: Id); |
333 | } |
334 | |
335 | // Build alias check condition given a pair of minimal/maximal access. |
336 | static isl::ast_expr buildCondition(Scop &S, isl::ast_build Build, |
337 | const Scop::MinMaxAccessTy *It0, |
338 | const Scop::MinMaxAccessTy *It1) { |
339 | |
340 | isl::pw_multi_aff AFirst = It0->first; |
341 | isl::pw_multi_aff ASecond = It0->second; |
342 | isl::pw_multi_aff BFirst = It1->first; |
343 | isl::pw_multi_aff BSecond = It1->second; |
344 | |
345 | isl::id Left = AFirst.get_tuple_id(type: isl::dim::set); |
346 | isl::id Right = BFirst.get_tuple_id(type: isl::dim::set); |
347 | |
348 | isl::ast_expr True = |
349 | isl::ast_expr::from_val(v: isl::val::int_from_ui(ctx: Build.ctx(), u: 1)); |
350 | isl::ast_expr False = |
351 | isl::ast_expr::from_val(v: isl::val::int_from_ui(ctx: Build.ctx(), u: 0)); |
352 | |
353 | const ScopArrayInfo *BaseLeft = |
354 | ScopArrayInfo::getFromId(Id: Left)->getBasePtrOriginSAI(); |
355 | const ScopArrayInfo *BaseRight = |
356 | ScopArrayInfo::getFromId(Id: Right)->getBasePtrOriginSAI(); |
357 | if (BaseLeft && BaseLeft == BaseRight) |
358 | return True; |
359 | |
360 | isl::set Params = S.getContext(); |
361 | |
362 | isl::ast_expr NonAliasGroup, MinExpr, MaxExpr; |
363 | |
364 | // In the following, we first check if any accesses will be empty under |
365 | // the execution context of the scop and do not code generate them if this |
366 | // is the case as isl will fail to derive valid AST expressions for such |
367 | // accesses. |
368 | |
369 | if (!AFirst.intersect_params(set: Params).domain().is_empty() && |
370 | !BSecond.intersect_params(set: Params).domain().is_empty()) { |
371 | MinExpr = Build.access_from(pma: AFirst).address_of(); |
372 | MaxExpr = Build.access_from(pma: BSecond).address_of(); |
373 | NonAliasGroup = MaxExpr.le(expr2: MinExpr); |
374 | } |
375 | |
376 | if (!BFirst.intersect_params(set: Params).domain().is_empty() && |
377 | !ASecond.intersect_params(set: Params).domain().is_empty()) { |
378 | MinExpr = Build.access_from(pma: BFirst).address_of(); |
379 | MaxExpr = Build.access_from(pma: ASecond).address_of(); |
380 | |
381 | isl::ast_expr Result = MaxExpr.le(expr2: MinExpr); |
382 | if (!NonAliasGroup.is_null()) |
383 | NonAliasGroup = isl::manage( |
384 | ptr: isl_ast_expr_or(expr1: NonAliasGroup.release(), expr2: Result.release())); |
385 | else |
386 | NonAliasGroup = Result; |
387 | } |
388 | |
389 | if (NonAliasGroup.is_null()) |
390 | NonAliasGroup = True; |
391 | |
392 | return NonAliasGroup; |
393 | } |
394 | |
395 | isl::ast_expr IslAst::buildRunCondition(Scop &S, const isl::ast_build &Build) { |
396 | isl::ast_expr RunCondition; |
397 | |
398 | // The conditions that need to be checked at run-time for this scop are |
399 | // available as an isl_set in the runtime check context from which we can |
400 | // directly derive a run-time condition. |
401 | auto PosCond = Build.expr_from(set: S.getAssumedContext()); |
402 | if (S.hasTrivialInvalidContext()) { |
403 | RunCondition = std::move(PosCond); |
404 | } else { |
405 | auto ZeroV = isl::val::zero(ctx: Build.ctx()); |
406 | auto NegCond = Build.expr_from(set: S.getInvalidContext()); |
407 | auto NotNegCond = |
408 | isl::ast_expr::from_val(v: std::move(ZeroV)).eq(expr2: std::move(NegCond)); |
409 | RunCondition = |
410 | isl::manage(ptr: isl_ast_expr_and(expr1: PosCond.release(), expr2: NotNegCond.release())); |
411 | } |
412 | |
413 | // Create the alias checks from the minimal/maximal accesses in each alias |
414 | // group which consists of read only and non read only (read write) accesses. |
415 | // This operation is by construction quadratic in the read-write pointers and |
416 | // linear in the read only pointers in each alias group. |
417 | for (const Scop::MinMaxVectorPairTy &MinMaxAccessPair : S.getAliasGroups()) { |
418 | auto &MinMaxReadWrite = MinMaxAccessPair.first; |
419 | auto &MinMaxReadOnly = MinMaxAccessPair.second; |
420 | auto RWAccEnd = MinMaxReadWrite.end(); |
421 | |
422 | for (auto RWAccIt0 = MinMaxReadWrite.begin(); RWAccIt0 != RWAccEnd; |
423 | ++RWAccIt0) { |
424 | for (auto RWAccIt1 = RWAccIt0 + 1; RWAccIt1 != RWAccEnd; ++RWAccIt1) |
425 | RunCondition = isl::manage(ptr: isl_ast_expr_and( |
426 | expr1: RunCondition.release(), |
427 | expr2: buildCondition(S, Build, It0: RWAccIt0, It1: RWAccIt1).release())); |
428 | for (const Scop::MinMaxAccessTy &ROAccIt : MinMaxReadOnly) |
429 | RunCondition = isl::manage(ptr: isl_ast_expr_and( |
430 | expr1: RunCondition.release(), |
431 | expr2: buildCondition(S, Build, It0: RWAccIt0, It1: &ROAccIt).release())); |
432 | } |
433 | } |
434 | |
435 | return RunCondition; |
436 | } |
437 | |
438 | /// Simple cost analysis for a given SCoP. |
439 | /// |
440 | /// TODO: Improve this analysis and extract it to make it usable in other |
441 | /// places too. |
442 | /// In order to improve the cost model we could either keep track of |
443 | /// performed optimizations (e.g., tiling) or compute properties on the |
444 | /// original as well as optimized SCoP (e.g., #stride-one-accesses). |
445 | static bool benefitsFromPolly(Scop &Scop, bool PerformParallelTest) { |
446 | if (PollyProcessUnprofitable) |
447 | return true; |
448 | |
449 | // Check if nothing interesting happened. |
450 | if (!PerformParallelTest && !Scop.isOptimized() && |
451 | Scop.getAliasGroups().empty()) |
452 | return false; |
453 | |
454 | // The default assumption is that Polly improves the code. |
455 | return true; |
456 | } |
457 | |
458 | /// Collect statistics for the syntax tree rooted at @p Ast. |
459 | static void walkAstForStatistics(const isl::ast_node &Ast) { |
460 | assert(!Ast.is_null()); |
461 | isl_ast_node_foreach_descendant_top_down( |
462 | node: Ast.get(), |
463 | fn: [](__isl_keep isl_ast_node *Node, void *User) -> isl_bool { |
464 | switch (isl_ast_node_get_type(node: Node)) { |
465 | case isl_ast_node_for: |
466 | NumForLoops++; |
467 | if (IslAstInfo::isParallel(Node: isl::manage_copy(ptr: Node))) |
468 | NumParallel++; |
469 | if (IslAstInfo::isInnermostParallel(Node: isl::manage_copy(ptr: Node))) |
470 | NumInnermostParallel++; |
471 | if (IslAstInfo::isOutermostParallel(Node: isl::manage_copy(ptr: Node))) |
472 | NumOutermostParallel++; |
473 | if (IslAstInfo::isReductionParallel(Node: isl::manage_copy(ptr: Node))) |
474 | NumReductionParallel++; |
475 | if (IslAstInfo::isExecutedInParallel(Node: isl::manage_copy(ptr: Node))) |
476 | NumExecutedInParallel++; |
477 | break; |
478 | |
479 | case isl_ast_node_if: |
480 | NumIfConditions++; |
481 | break; |
482 | |
483 | default: |
484 | break; |
485 | } |
486 | |
487 | // Continue traversing subtrees. |
488 | return isl_bool_true; |
489 | }, |
490 | user: nullptr); |
491 | } |
492 | |
493 | IslAst::IslAst(Scop &Scop) : S(Scop), Ctx(Scop.getSharedIslCtx()) {} |
494 | |
495 | IslAst::IslAst(IslAst &&O) |
496 | : S(O.S), Ctx(O.Ctx), RunCondition(std::move(O.RunCondition)), |
497 | Root(std::move(O.Root)) {} |
498 | |
499 | void IslAst::init(const Dependences &D) { |
500 | bool PerformParallelTest = PollyParallel || DetectParallel || |
501 | PollyVectorizerChoice != VECTORIZER_NONE; |
502 | auto ScheduleTree = S.getScheduleTree(); |
503 | |
504 | // Skip AST and code generation if there was no benefit achieved. |
505 | if (!benefitsFromPolly(Scop&: S, PerformParallelTest)) |
506 | return; |
507 | |
508 | auto ScopStats = S.getStatistics(); |
509 | ScopsBeneficial++; |
510 | BeneficialAffineLoops += ScopStats.NumAffineLoops; |
511 | BeneficialBoxedLoops += ScopStats.NumBoxedLoops; |
512 | |
513 | auto Ctx = S.getIslCtx(); |
514 | isl_options_set_ast_build_atomic_upper_bound(ctx: Ctx.get(), val: true); |
515 | isl_options_set_ast_build_detect_min_max(ctx: Ctx.get(), val: true); |
516 | isl_ast_build *Build; |
517 | AstBuildUserInfo BuildInfo; |
518 | |
519 | if (UseContext) |
520 | Build = isl_ast_build_from_context(set: S.getContext().release()); |
521 | else |
522 | Build = isl_ast_build_from_context( |
523 | set: isl_set_universe(space: S.getParamSpace().release())); |
524 | |
525 | Build = isl_ast_build_set_at_each_domain(build: Build, fn: AtEachDomain, user: nullptr); |
526 | |
527 | if (PerformParallelTest) { |
528 | BuildInfo.Deps = &D; |
529 | BuildInfo.InParallelFor = false; |
530 | BuildInfo.InSIMD = false; |
531 | |
532 | Build = isl_ast_build_set_before_each_for(build: Build, fn: &astBuildBeforeFor, |
533 | user: &BuildInfo); |
534 | Build = |
535 | isl_ast_build_set_after_each_for(build: Build, fn: &astBuildAfterFor, user: &BuildInfo); |
536 | |
537 | Build = isl_ast_build_set_before_each_mark(build: Build, fn: &astBuildBeforeMark, |
538 | user: &BuildInfo); |
539 | |
540 | Build = isl_ast_build_set_after_each_mark(build: Build, fn: &astBuildAfterMark, |
541 | user: &BuildInfo); |
542 | } |
543 | |
544 | RunCondition = buildRunCondition(S, Build: isl::manage_copy(ptr: Build)); |
545 | |
546 | Root = isl::manage( |
547 | ptr: isl_ast_build_node_from_schedule(build: Build, schedule: S.getScheduleTree().release())); |
548 | walkAstForStatistics(Ast: Root); |
549 | |
550 | isl_ast_build_free(build: Build); |
551 | } |
552 | |
553 | IslAst IslAst::create(Scop &Scop, const Dependences &D) { |
554 | IslAst Ast{Scop}; |
555 | Ast.init(D); |
556 | return Ast; |
557 | } |
558 | |
559 | isl::ast_node IslAst::getAst() { return Root; } |
560 | isl::ast_expr IslAst::getRunCondition() { return RunCondition; } |
561 | |
562 | isl::ast_node IslAstInfo::getAst() { return Ast.getAst(); } |
563 | isl::ast_expr IslAstInfo::getRunCondition() { return Ast.getRunCondition(); } |
564 | |
565 | IslAstUserPayload *IslAstInfo::getNodePayload(const isl::ast_node &Node) { |
566 | isl::id Id = Node.get_annotation(); |
567 | if (Id.is_null()) |
568 | return nullptr; |
569 | IslAstUserPayload *Payload = (IslAstUserPayload *)Id.get_user(); |
570 | return Payload; |
571 | } |
572 | |
573 | bool IslAstInfo::isInnermost(const isl::ast_node &Node) { |
574 | IslAstUserPayload *Payload = getNodePayload(Node); |
575 | return Payload && Payload->IsInnermost; |
576 | } |
577 | |
578 | bool IslAstInfo::isParallel(const isl::ast_node &Node) { |
579 | return IslAstInfo::isInnermostParallel(Node) || |
580 | IslAstInfo::isOutermostParallel(Node); |
581 | } |
582 | |
583 | bool IslAstInfo::isInnermostParallel(const isl::ast_node &Node) { |
584 | IslAstUserPayload *Payload = getNodePayload(Node); |
585 | return Payload && Payload->IsInnermostParallel; |
586 | } |
587 | |
588 | bool IslAstInfo::isOutermostParallel(const isl::ast_node &Node) { |
589 | IslAstUserPayload *Payload = getNodePayload(Node); |
590 | return Payload && Payload->IsOutermostParallel; |
591 | } |
592 | |
593 | bool IslAstInfo::isReductionParallel(const isl::ast_node &Node) { |
594 | IslAstUserPayload *Payload = getNodePayload(Node); |
595 | return Payload && Payload->IsReductionParallel; |
596 | } |
597 | |
598 | bool IslAstInfo::isExecutedInParallel(const isl::ast_node &Node) { |
599 | if (!PollyParallel) |
600 | return false; |
601 | |
602 | // Do not parallelize innermost loops. |
603 | // |
604 | // Parallelizing innermost loops is often not profitable, especially if |
605 | // they have a low number of iterations. |
606 | // |
607 | // TODO: Decide this based on the number of loop iterations that will be |
608 | // executed. This can possibly require run-time checks, which again |
609 | // raises the question of both run-time check overhead and code size |
610 | // costs. |
611 | if (!PollyParallelForce && isInnermost(Node)) |
612 | return false; |
613 | |
614 | return isOutermostParallel(Node) && !isReductionParallel(Node); |
615 | } |
616 | |
617 | isl::union_map IslAstInfo::getSchedule(const isl::ast_node &Node) { |
618 | IslAstUserPayload *Payload = getNodePayload(Node); |
619 | return Payload ? Payload->Build.get_schedule() : isl::union_map(); |
620 | } |
621 | |
622 | isl::pw_aff |
623 | IslAstInfo::getMinimalDependenceDistance(const isl::ast_node &Node) { |
624 | IslAstUserPayload *Payload = getNodePayload(Node); |
625 | return Payload ? Payload->MinimalDependenceDistance : isl::pw_aff(); |
626 | } |
627 | |
628 | IslAstInfo::MemoryAccessSet * |
629 | IslAstInfo::getBrokenReductions(const isl::ast_node &Node) { |
630 | IslAstUserPayload *Payload = getNodePayload(Node); |
631 | return Payload ? &Payload->BrokenReductions : nullptr; |
632 | } |
633 | |
634 | isl::ast_build IslAstInfo::getBuild(const isl::ast_node &Node) { |
635 | IslAstUserPayload *Payload = getNodePayload(Node); |
636 | return Payload ? Payload->Build : isl::ast_build(); |
637 | } |
638 | |
639 | static std::unique_ptr<IslAstInfo> runIslAst( |
640 | Scop &Scop, |
641 | function_ref<const Dependences &(Dependences::AnalysisLevel)> GetDeps) { |
642 | ScopsProcessed++; |
643 | |
644 | const Dependences &D = GetDeps(Dependences::AL_Statement); |
645 | |
646 | if (D.getSharedIslCtx() != Scop.getSharedIslCtx()) { |
647 | POLLY_DEBUG( |
648 | dbgs() << "Got dependence analysis for different SCoP/isl_ctx\n" ); |
649 | return {}; |
650 | } |
651 | |
652 | std::unique_ptr<IslAstInfo> Ast = std::make_unique<IslAstInfo>(args&: Scop, args: D); |
653 | |
654 | POLLY_DEBUG({ |
655 | if (Ast) |
656 | Ast->print(dbgs()); |
657 | }); |
658 | |
659 | return Ast; |
660 | } |
661 | |
662 | IslAstInfo IslAstAnalysis::run(Scop &S, ScopAnalysisManager &SAM, |
663 | ScopStandardAnalysisResults &SAR) { |
664 | auto GetDeps = [&](Dependences::AnalysisLevel Lvl) -> const Dependences & { |
665 | return SAM.getResult<DependenceAnalysis>(IR&: S, ExtraArgs&: SAR).getDependences(Level: Lvl); |
666 | }; |
667 | |
668 | return std::move(*runIslAst(Scop&: S, GetDeps)); |
669 | } |
670 | |
671 | static __isl_give isl_printer *cbPrintUser(__isl_take isl_printer *P, |
672 | __isl_take isl_ast_print_options *O, |
673 | __isl_keep isl_ast_node *Node, |
674 | void *User) { |
675 | isl::ast_node_user AstNode = isl::manage_copy(ptr: Node).as<isl::ast_node_user>(); |
676 | isl::ast_expr NodeExpr = AstNode.expr(); |
677 | isl::ast_expr CallExpr = NodeExpr.get_op_arg(pos: 0); |
678 | isl::id CallExprId = CallExpr.get_id(); |
679 | ScopStmt *AccessStmt = (ScopStmt *)CallExprId.get_user(); |
680 | |
681 | P = isl_printer_start_line(p: P); |
682 | P = isl_printer_print_str(p: P, s: AccessStmt->getBaseName()); |
683 | P = isl_printer_print_str(p: P, s: "(" ); |
684 | P = isl_printer_end_line(p: P); |
685 | P = isl_printer_indent(p: P, indent: 2); |
686 | |
687 | for (MemoryAccess *MemAcc : *AccessStmt) { |
688 | P = isl_printer_start_line(p: P); |
689 | |
690 | if (MemAcc->isRead()) |
691 | P = isl_printer_print_str(p: P, s: "/* read */ &" ); |
692 | else |
693 | P = isl_printer_print_str(p: P, s: "/* write */ " ); |
694 | |
695 | isl::ast_build Build = IslAstInfo::getBuild(Node: isl::manage_copy(ptr: Node)); |
696 | if (MemAcc->isAffine()) { |
697 | isl_pw_multi_aff *PwmaPtr = |
698 | MemAcc->applyScheduleToAccessRelation(Schedule: Build.get_schedule()).release(); |
699 | isl::pw_multi_aff Pwma = isl::manage(ptr: PwmaPtr); |
700 | isl::ast_expr AccessExpr = Build.access_from(pma: Pwma); |
701 | P = isl_printer_print_ast_expr(p: P, expr: AccessExpr.get()); |
702 | } else { |
703 | P = isl_printer_print_str( |
704 | p: P, s: MemAcc->getLatestScopArrayInfo()->getName().c_str()); |
705 | P = isl_printer_print_str(p: P, s: "[*]" ); |
706 | } |
707 | P = isl_printer_end_line(p: P); |
708 | } |
709 | |
710 | P = isl_printer_indent(p: P, indent: -2); |
711 | P = isl_printer_start_line(p: P); |
712 | P = isl_printer_print_str(p: P, s: ");" ); |
713 | P = isl_printer_end_line(p: P); |
714 | |
715 | isl_ast_print_options_free(options: O); |
716 | return P; |
717 | } |
718 | |
719 | void IslAstInfo::print(raw_ostream &OS) { |
720 | isl_ast_print_options *Options; |
721 | isl::ast_node RootNode = Ast.getAst(); |
722 | Function &F = S.getFunction(); |
723 | |
724 | OS << ":: isl ast :: " << F.getName() << " :: " << S.getNameStr() << "\n" ; |
725 | |
726 | if (RootNode.is_null()) { |
727 | OS << ":: isl ast generation and code generation was skipped!\n\n" ; |
728 | OS << ":: This is either because no useful optimizations could be applied " |
729 | "(use -polly-process-unprofitable to enforce code generation) or " |
730 | "because earlier passes such as dependence analysis timed out (use " |
731 | "-polly-dependences-computeout=0 to set dependence analysis timeout " |
732 | "to infinity)\n\n" ; |
733 | return; |
734 | } |
735 | |
736 | isl::ast_expr RunCondition = Ast.getRunCondition(); |
737 | char *RtCStr, *AstStr; |
738 | |
739 | Options = isl_ast_print_options_alloc(ctx: S.getIslCtx().get()); |
740 | |
741 | if (PrintAccesses) |
742 | Options = |
743 | isl_ast_print_options_set_print_user(options: Options, print_user: cbPrintUser, user: nullptr); |
744 | Options = isl_ast_print_options_set_print_for(options: Options, print_for: cbPrintFor, user: nullptr); |
745 | |
746 | isl_printer *P = isl_printer_to_str(ctx: S.getIslCtx().get()); |
747 | P = isl_printer_set_output_format(p: P, ISL_FORMAT_C); |
748 | P = isl_printer_print_ast_expr(p: P, expr: RunCondition.get()); |
749 | RtCStr = isl_printer_get_str(printer: P); |
750 | P = isl_printer_flush(p: P); |
751 | P = isl_printer_indent(p: P, indent: 4); |
752 | P = isl_ast_node_print(node: RootNode.get(), p: P, options: Options); |
753 | AstStr = isl_printer_get_str(printer: P); |
754 | |
755 | POLLY_DEBUG({ |
756 | dbgs() << S.getContextStr() << "\n" ; |
757 | dbgs() << stringFromIslObj(S.getScheduleTree(), "null" ); |
758 | }); |
759 | OS << "\nif (" << RtCStr << ")\n\n" ; |
760 | OS << AstStr << "\n" ; |
761 | OS << "else\n" ; |
762 | OS << " { /* original code */ }\n\n" ; |
763 | |
764 | free(ptr: RtCStr); |
765 | free(ptr: AstStr); |
766 | |
767 | isl_printer_free(printer: P); |
768 | } |
769 | |
770 | AnalysisKey IslAstAnalysis::Key; |
771 | PreservedAnalyses IslAstPrinterPass::run(Scop &S, ScopAnalysisManager &SAM, |
772 | ScopStandardAnalysisResults &SAR, |
773 | SPMUpdater &U) { |
774 | auto &Ast = SAM.getResult<IslAstAnalysis>(IR&: S, ExtraArgs&: SAR); |
775 | Ast.print(OS); |
776 | return PreservedAnalyses::all(); |
777 | } |
778 | |
779 | void IslAstInfoWrapperPass::releaseMemory() { Ast.reset(); } |
780 | |
781 | bool IslAstInfoWrapperPass::runOnScop(Scop &Scop) { |
782 | auto GetDeps = [this](Dependences::AnalysisLevel Lvl) -> const Dependences & { |
783 | return getAnalysis<DependenceInfo>().getDependences(Level: Lvl); |
784 | }; |
785 | |
786 | Ast = runIslAst(Scop, GetDeps); |
787 | |
788 | return false; |
789 | } |
790 | |
791 | void IslAstInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { |
792 | // Get the Common analysis usage of ScopPasses. |
793 | ScopPass::getAnalysisUsage(AU); |
794 | AU.addRequiredTransitive<ScopInfoRegionPass>(); |
795 | AU.addRequired<DependenceInfo>(); |
796 | |
797 | AU.addPreserved<DependenceInfo>(); |
798 | } |
799 | |
800 | void IslAstInfoWrapperPass::printScop(raw_ostream &OS, Scop &S) const { |
801 | OS << "Printing analysis 'Polly - Generate an AST of the SCoP (isl)'" |
802 | << S.getName() << "' in function '" << S.getFunction().getName() << "':\n" ; |
803 | if (Ast) |
804 | Ast->print(OS); |
805 | } |
806 | |
807 | char IslAstInfoWrapperPass::ID = 0; |
808 | |
809 | Pass *polly::createIslAstInfoWrapperPassPass() { |
810 | return new IslAstInfoWrapperPass(); |
811 | } |
812 | |
813 | INITIALIZE_PASS_BEGIN(IslAstInfoWrapperPass, "polly-ast" , |
814 | "Polly - Generate an AST of the SCoP (isl)" , false, |
815 | false); |
816 | INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass); |
817 | INITIALIZE_PASS_DEPENDENCY(DependenceInfo); |
818 | INITIALIZE_PASS_END(IslAstInfoWrapperPass, "polly-ast" , |
819 | "Polly - Generate an AST from the SCoP (isl)" , false, false) |
820 | |
821 | //===----------------------------------------------------------------------===// |
822 | |
823 | namespace { |
824 | /// Print result from IslAstInfoWrapperPass. |
825 | class IslAstInfoPrinterLegacyPass final : public ScopPass { |
826 | public: |
827 | static char ID; |
828 | |
829 | IslAstInfoPrinterLegacyPass() : IslAstInfoPrinterLegacyPass(outs()) {} |
830 | explicit IslAstInfoPrinterLegacyPass(llvm::raw_ostream &OS) |
831 | : ScopPass(ID), OS(OS) {} |
832 | |
833 | bool runOnScop(Scop &S) override { |
834 | IslAstInfoWrapperPass &P = getAnalysis<IslAstInfoWrapperPass>(); |
835 | |
836 | OS << "Printing analysis '" << P.getPassName() << "' for region: '" |
837 | << S.getRegion().getNameStr() << "' in function '" |
838 | << S.getFunction().getName() << "':\n" ; |
839 | P.printScop(OS, S); |
840 | |
841 | return false; |
842 | } |
843 | |
844 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
845 | ScopPass::getAnalysisUsage(AU); |
846 | AU.addRequired<IslAstInfoWrapperPass>(); |
847 | AU.setPreservesAll(); |
848 | } |
849 | |
850 | private: |
851 | llvm::raw_ostream &OS; |
852 | }; |
853 | |
854 | char IslAstInfoPrinterLegacyPass::ID = 0; |
855 | } // namespace |
856 | |
857 | Pass *polly::createIslAstInfoPrinterLegacyPass(raw_ostream &OS) { |
858 | return new IslAstInfoPrinterLegacyPass(OS); |
859 | } |
860 | |
861 | INITIALIZE_PASS_BEGIN(IslAstInfoPrinterLegacyPass, "polly-print-ast" , |
862 | "Polly - Print the AST from a SCoP (isl)" , false, false); |
863 | INITIALIZE_PASS_DEPENDENCY(IslAstInfoWrapperPass); |
864 | INITIALIZE_PASS_END(IslAstInfoPrinterLegacyPass, "polly-print-ast" , |
865 | "Polly - Print the AST from a SCoP (isl)" , false, false) |
866 | |