| 1 | //===- IslAst.cpp - isl code generator interface --------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // The isl code generator interface takes a Scop and generates an isl_ast. This |
| 10 | // isl_ast can either be returned directly or it can be pretty printed to |
| 11 | // stdout. |
| 12 | // |
| 13 | // A typical isl_ast output looks like this: |
| 14 | // |
| 15 | // for (c2 = max(0, ceild(n + m, 2)); c2 <= min(511, floord(5 * n, 3)); c2++) { |
| 16 | // bb2(c2); |
| 17 | // } |
| 18 | // |
| 19 | // An in-depth discussion of our AST generation approach can be found in: |
| 20 | // |
| 21 | // Polyhedral AST generation is more than scanning polyhedra |
| 22 | // Tobias Grosser, Sven Verdoolaege, Albert Cohen |
| 23 | // ACM Transactions on Programming Languages and Systems (TOPLAS), |
| 24 | // 37(4), July 2015 |
| 25 | // http://www.grosser.es/#pub-polyhedral-AST-generation |
| 26 | // |
| 27 | //===----------------------------------------------------------------------===// |
| 28 | |
| 29 | #include "polly/CodeGen/IslAst.h" |
| 30 | #include "polly/CodeGen/CodeGeneration.h" |
| 31 | #include "polly/DependenceInfo.h" |
| 32 | #include "polly/LinkAllPasses.h" |
| 33 | #include "polly/Options.h" |
| 34 | #include "polly/ScopDetection.h" |
| 35 | #include "polly/ScopInfo.h" |
| 36 | #include "polly/ScopPass.h" |
| 37 | #include "polly/Support/GICHelper.h" |
| 38 | #include "llvm/ADT/Statistic.h" |
| 39 | #include "llvm/IR/Function.h" |
| 40 | #include "llvm/Support/Debug.h" |
| 41 | #include "llvm/Support/raw_ostream.h" |
| 42 | #include "isl/aff.h" |
| 43 | #include "isl/ast.h" |
| 44 | #include "isl/ast_build.h" |
| 45 | #include "isl/id.h" |
| 46 | #include "isl/isl-noexceptions.h" |
| 47 | #include "isl/printer.h" |
| 48 | #include "isl/schedule.h" |
| 49 | #include "isl/set.h" |
| 50 | #include "isl/union_map.h" |
| 51 | #include "isl/val.h" |
| 52 | #include <cassert> |
| 53 | #include <cstdlib> |
| 54 | |
| 55 | #include "polly/Support/PollyDebug.h" |
| 56 | #define DEBUG_TYPE "polly-ast" |
| 57 | |
| 58 | using namespace llvm; |
| 59 | using namespace polly; |
| 60 | |
| 61 | using IslAstUserPayload = IslAstInfo::IslAstUserPayload; |
| 62 | |
| 63 | static cl::opt<bool> |
| 64 | PollyParallel("polly-parallel" , |
| 65 | cl::desc("Generate thread parallel code (isl codegen only)" ), |
| 66 | cl::cat(PollyCategory)); |
| 67 | |
| 68 | static cl::opt<bool> PrintAccesses("polly-ast-print-accesses" , |
| 69 | cl::desc("Print memory access functions" ), |
| 70 | cl::cat(PollyCategory)); |
| 71 | |
| 72 | static cl::opt<bool> PollyParallelForce( |
| 73 | "polly-parallel-force" , |
| 74 | cl::desc( |
| 75 | "Force generation of thread parallel code ignoring any cost model" ), |
| 76 | cl::cat(PollyCategory)); |
| 77 | |
| 78 | static cl::opt<bool> UseContext("polly-ast-use-context" , |
| 79 | cl::desc("Use context" ), cl::Hidden, |
| 80 | cl::init(Val: true), cl::cat(PollyCategory)); |
| 81 | |
| 82 | static cl::opt<bool> DetectParallel("polly-ast-detect-parallel" , |
| 83 | cl::desc("Detect parallelism" ), cl::Hidden, |
| 84 | cl::cat(PollyCategory)); |
| 85 | |
| 86 | STATISTIC(ScopsProcessed, "Number of SCoPs processed" ); |
| 87 | STATISTIC(ScopsBeneficial, "Number of beneficial SCoPs" ); |
| 88 | STATISTIC(BeneficialAffineLoops, "Number of beneficial affine loops" ); |
| 89 | STATISTIC(BeneficialBoxedLoops, "Number of beneficial boxed loops" ); |
| 90 | |
| 91 | STATISTIC(NumForLoops, "Number of for-loops" ); |
| 92 | STATISTIC(NumParallel, "Number of parallel for-loops" ); |
| 93 | STATISTIC(NumInnermostParallel, "Number of innermost parallel for-loops" ); |
| 94 | STATISTIC(NumOutermostParallel, "Number of outermost parallel for-loops" ); |
| 95 | STATISTIC(NumReductionParallel, "Number of reduction-parallel for-loops" ); |
| 96 | STATISTIC(NumExecutedInParallel, "Number of for-loops executed in parallel" ); |
| 97 | STATISTIC(NumIfConditions, "Number of if-conditions" ); |
| 98 | |
| 99 | namespace polly { |
| 100 | |
| 101 | /// Temporary information used when building the ast. |
| 102 | struct AstBuildUserInfo { |
| 103 | /// Construct and initialize the helper struct for AST creation. |
| 104 | AstBuildUserInfo() = default; |
| 105 | |
| 106 | /// The dependence information used for the parallelism check. |
| 107 | const Dependences *Deps = nullptr; |
| 108 | |
| 109 | /// Flag to indicate that we are inside a parallel for node. |
| 110 | bool InParallelFor = false; |
| 111 | |
| 112 | /// Flag to indicate that we are inside an SIMD node. |
| 113 | bool InSIMD = false; |
| 114 | |
| 115 | /// The last iterator id created for the current SCoP. |
| 116 | isl_id *LastForNodeId = nullptr; |
| 117 | }; |
| 118 | } // namespace polly |
| 119 | |
| 120 | /// Free an IslAstUserPayload object pointed to by @p Ptr. |
| 121 | static void freeIslAstUserPayload(void *Ptr) { |
| 122 | delete ((IslAstInfo::IslAstUserPayload *)Ptr); |
| 123 | } |
| 124 | |
| 125 | /// Print a string @p str in a single line using @p Printer. |
| 126 | static isl_printer *printLine(__isl_take isl_printer *Printer, |
| 127 | const std::string &str, |
| 128 | __isl_keep isl_pw_aff *PWA = nullptr) { |
| 129 | Printer = isl_printer_start_line(p: Printer); |
| 130 | Printer = isl_printer_print_str(p: Printer, s: str.c_str()); |
| 131 | if (PWA) |
| 132 | Printer = isl_printer_print_pw_aff(p: Printer, pwaff: PWA); |
| 133 | return isl_printer_end_line(p: Printer); |
| 134 | } |
| 135 | |
| 136 | /// Return all broken reductions as a string of clauses (OpenMP style). |
| 137 | static std::string getBrokenReductionsStr(const isl::ast_node &Node) { |
| 138 | IslAstInfo::MemoryAccessSet *BrokenReductions; |
| 139 | std::string str; |
| 140 | |
| 141 | BrokenReductions = IslAstInfo::getBrokenReductions(Node); |
| 142 | if (!BrokenReductions || BrokenReductions->empty()) |
| 143 | return "" ; |
| 144 | |
| 145 | // Map each type of reduction to a comma separated list of the base addresses. |
| 146 | std::map<MemoryAccess::ReductionType, std::string> Clauses; |
| 147 | for (MemoryAccess *MA : *BrokenReductions) |
| 148 | if (MA->isWrite()) |
| 149 | Clauses[MA->getReductionType()] += |
| 150 | ", " + MA->getScopArrayInfo()->getName(); |
| 151 | |
| 152 | // Now print the reductions sorted by type. Each type will cause a clause |
| 153 | // like: reduction (+ : sum0, sum1, sum2) |
| 154 | for (const auto &ReductionClause : Clauses) { |
| 155 | str += " reduction (" ; |
| 156 | str += MemoryAccess::getReductionOperatorStr(RT: ReductionClause.first); |
| 157 | // Remove the first two symbols (", ") to make the output look pretty. |
| 158 | str += " : " + ReductionClause.second.substr(pos: 2) + ")" ; |
| 159 | } |
| 160 | |
| 161 | return str; |
| 162 | } |
| 163 | |
| 164 | /// Callback executed for each for node in the ast in order to print it. |
| 165 | static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, |
| 166 | __isl_take isl_ast_print_options *Options, |
| 167 | __isl_keep isl_ast_node *Node, void *) { |
| 168 | isl::pw_aff DD = |
| 169 | IslAstInfo::getMinimalDependenceDistance(Node: isl::manage_copy(ptr: Node)); |
| 170 | const std::string BrokenReductionsStr = |
| 171 | getBrokenReductionsStr(Node: isl::manage_copy(ptr: Node)); |
| 172 | const std::string KnownParallelStr = "#pragma known-parallel" ; |
| 173 | const std::string DepDisPragmaStr = "#pragma minimal dependence distance: " ; |
| 174 | const std::string SimdPragmaStr = "#pragma simd" ; |
| 175 | const std::string OmpPragmaStr = "#pragma omp parallel for" ; |
| 176 | |
| 177 | if (!DD.is_null()) |
| 178 | Printer = printLine(Printer, str: DepDisPragmaStr, PWA: DD.get()); |
| 179 | |
| 180 | if (IslAstInfo::isInnermostParallel(Node: isl::manage_copy(ptr: Node))) |
| 181 | Printer = printLine(Printer, str: SimdPragmaStr + BrokenReductionsStr); |
| 182 | |
| 183 | if (IslAstInfo::isExecutedInParallel(Node: isl::manage_copy(ptr: Node))) |
| 184 | Printer = printLine(Printer, str: OmpPragmaStr); |
| 185 | else if (IslAstInfo::isOutermostParallel(Node: isl::manage_copy(ptr: Node))) |
| 186 | Printer = printLine(Printer, str: KnownParallelStr + BrokenReductionsStr); |
| 187 | |
| 188 | return isl_ast_node_for_print(node: Node, p: Printer, options: Options); |
| 189 | } |
| 190 | |
| 191 | /// Check if the current scheduling dimension is parallel. |
| 192 | /// |
| 193 | /// In case the dimension is parallel we also check if any reduction |
| 194 | /// dependences is broken when we exploit this parallelism. If so, |
| 195 | /// @p IsReductionParallel will be set to true. The reduction dependences we use |
| 196 | /// to check are actually the union of the transitive closure of the initial |
| 197 | /// reduction dependences together with their reversal. Even though these |
| 198 | /// dependences connect all iterations with each other (thus they are cyclic) |
| 199 | /// we can perform the parallelism check as we are only interested in a zero |
| 200 | /// (or non-zero) dependence distance on the dimension in question. |
| 201 | static bool astScheduleDimIsParallel(const isl::ast_build &Build, |
| 202 | const Dependences *D, |
| 203 | IslAstUserPayload *NodeInfo) { |
| 204 | if (!D->hasValidDependences()) |
| 205 | return false; |
| 206 | |
| 207 | isl::union_map Schedule = Build.get_schedule(); |
| 208 | isl::union_map Dep = D->getDependences( |
| 209 | Kinds: Dependences::TYPE_RAW | Dependences::TYPE_WAW | Dependences::TYPE_WAR); |
| 210 | |
| 211 | if (!D->isParallel(Schedule: Schedule.get(), Deps: Dep.release())) { |
| 212 | isl::union_map DepsAll = |
| 213 | D->getDependences(Kinds: Dependences::TYPE_RAW | Dependences::TYPE_WAW | |
| 214 | Dependences::TYPE_WAR | Dependences::TYPE_TC_RED); |
| 215 | // TODO: We will need to change isParallel to stop the unwrapping |
| 216 | isl_pw_aff *MinimalDependenceDistanceIsl = nullptr; |
| 217 | D->isParallel(Schedule: Schedule.get(), Deps: DepsAll.release(), |
| 218 | MinDistancePtr: &MinimalDependenceDistanceIsl); |
| 219 | NodeInfo->MinimalDependenceDistance = |
| 220 | isl::manage(ptr: MinimalDependenceDistanceIsl); |
| 221 | return false; |
| 222 | } |
| 223 | |
| 224 | isl::union_map RedDeps = D->getDependences(Kinds: Dependences::TYPE_TC_RED); |
| 225 | if (!D->isParallel(Schedule: Schedule.get(), Deps: RedDeps.release())) |
| 226 | NodeInfo->IsReductionParallel = true; |
| 227 | |
| 228 | if (!NodeInfo->IsReductionParallel) |
| 229 | return true; |
| 230 | |
| 231 | for (const auto &MaRedPair : D->getReductionDependences()) { |
| 232 | if (!MaRedPair.second) |
| 233 | continue; |
| 234 | isl::union_map MaRedDeps = isl::manage_copy(ptr: MaRedPair.second); |
| 235 | if (!D->isParallel(Schedule: Schedule.get(), Deps: MaRedDeps.release())) |
| 236 | NodeInfo->BrokenReductions.insert(Ptr: MaRedPair.first); |
| 237 | } |
| 238 | return true; |
| 239 | } |
| 240 | |
| 241 | // This method is executed before the construction of a for node. It creates |
| 242 | // an isl_id that is used to annotate the subsequently generated ast for nodes. |
| 243 | // |
| 244 | // In this function we also run the following analyses: |
| 245 | // |
| 246 | // - Detection of openmp parallel loops |
| 247 | // |
| 248 | static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build, |
| 249 | void *User) { |
| 250 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
| 251 | IslAstUserPayload *Payload = new IslAstUserPayload(); |
| 252 | isl_id *Id = isl_id_alloc(ctx: isl_ast_build_get_ctx(build: Build), name: "" , user: Payload); |
| 253 | Id = isl_id_set_free_user(id: Id, free_user: freeIslAstUserPayload); |
| 254 | BuildInfo->LastForNodeId = Id; |
| 255 | |
| 256 | Payload->IsParallel = astScheduleDimIsParallel(Build: isl::manage_copy(ptr: Build), |
| 257 | D: BuildInfo->Deps, NodeInfo: Payload); |
| 258 | |
| 259 | // Test for parallelism only if we are not already inside a parallel loop |
| 260 | if (!BuildInfo->InParallelFor && !BuildInfo->InSIMD) |
| 261 | BuildInfo->InParallelFor = Payload->IsOutermostParallel = |
| 262 | Payload->IsParallel; |
| 263 | |
| 264 | return Id; |
| 265 | } |
| 266 | |
| 267 | // This method is executed after the construction of a for node. |
| 268 | // |
| 269 | // It performs the following actions: |
| 270 | // |
| 271 | // - Reset the 'InParallelFor' flag, as soon as we leave a for node, |
| 272 | // that is marked as openmp parallel. |
| 273 | // |
| 274 | static __isl_give isl_ast_node * |
| 275 | astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, |
| 276 | void *User) { |
| 277 | isl_id *Id = isl_ast_node_get_annotation(node: Node); |
| 278 | assert(Id && "Post order visit assumes annotated for nodes" ); |
| 279 | IslAstUserPayload *Payload = (IslAstUserPayload *)isl_id_get_user(id: Id); |
| 280 | assert(Payload && "Post order visit assumes annotated for nodes" ); |
| 281 | |
| 282 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
| 283 | assert(Payload->Build.is_null() && "Build environment already set" ); |
| 284 | Payload->Build = isl::manage_copy(ptr: Build); |
| 285 | Payload->IsInnermost = (Id == BuildInfo->LastForNodeId); |
| 286 | |
| 287 | Payload->IsInnermostParallel = |
| 288 | Payload->IsInnermost && (BuildInfo->InSIMD || Payload->IsParallel); |
| 289 | if (Payload->IsOutermostParallel) |
| 290 | BuildInfo->InParallelFor = false; |
| 291 | |
| 292 | isl_id_free(id: Id); |
| 293 | return Node; |
| 294 | } |
| 295 | |
| 296 | static isl_stat (__isl_keep isl_id *MarkId, |
| 297 | __isl_keep isl_ast_build *Build, |
| 298 | void *User) { |
| 299 | if (!MarkId) |
| 300 | return isl_stat_error; |
| 301 | |
| 302 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
| 303 | if (strcmp(s1: isl_id_get_name(id: MarkId), s2: "SIMD" ) == 0) |
| 304 | BuildInfo->InSIMD = true; |
| 305 | |
| 306 | return isl_stat_ok; |
| 307 | } |
| 308 | |
| 309 | static __isl_give isl_ast_node * |
| 310 | astBuildAfterMark(__isl_take isl_ast_node *Node, |
| 311 | __isl_keep isl_ast_build *Build, void *User) { |
| 312 | assert(isl_ast_node_get_type(Node) == isl_ast_node_mark); |
| 313 | AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; |
| 314 | auto *Id = isl_ast_node_mark_get_id(node: Node); |
| 315 | if (strcmp(s1: isl_id_get_name(id: Id), s2: "SIMD" ) == 0) |
| 316 | BuildInfo->InSIMD = false; |
| 317 | isl_id_free(id: Id); |
| 318 | return Node; |
| 319 | } |
| 320 | |
| 321 | static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node, |
| 322 | __isl_keep isl_ast_build *Build, |
| 323 | void *User) { |
| 324 | assert(!isl_ast_node_get_annotation(Node) && "Node already annotated" ); |
| 325 | |
| 326 | IslAstUserPayload *Payload = new IslAstUserPayload(); |
| 327 | isl_id *Id = isl_id_alloc(ctx: isl_ast_build_get_ctx(build: Build), name: "" , user: Payload); |
| 328 | Id = isl_id_set_free_user(id: Id, free_user: freeIslAstUserPayload); |
| 329 | |
| 330 | Payload->Build = isl::manage_copy(ptr: Build); |
| 331 | |
| 332 | return isl_ast_node_set_annotation(node: Node, annotation: Id); |
| 333 | } |
| 334 | |
| 335 | // Build alias check condition given a pair of minimal/maximal access. |
| 336 | static isl::ast_expr buildCondition(Scop &S, isl::ast_build Build, |
| 337 | const Scop::MinMaxAccessTy *It0, |
| 338 | const Scop::MinMaxAccessTy *It1) { |
| 339 | |
| 340 | isl::pw_multi_aff AFirst = It0->first; |
| 341 | isl::pw_multi_aff ASecond = It0->second; |
| 342 | isl::pw_multi_aff BFirst = It1->first; |
| 343 | isl::pw_multi_aff BSecond = It1->second; |
| 344 | |
| 345 | isl::id Left = AFirst.get_tuple_id(type: isl::dim::set); |
| 346 | isl::id Right = BFirst.get_tuple_id(type: isl::dim::set); |
| 347 | |
| 348 | isl::ast_expr True = |
| 349 | isl::ast_expr::from_val(v: isl::val::int_from_ui(ctx: Build.ctx(), u: 1)); |
| 350 | isl::ast_expr False = |
| 351 | isl::ast_expr::from_val(v: isl::val::int_from_ui(ctx: Build.ctx(), u: 0)); |
| 352 | |
| 353 | const ScopArrayInfo *BaseLeft = |
| 354 | ScopArrayInfo::getFromId(Id: Left)->getBasePtrOriginSAI(); |
| 355 | const ScopArrayInfo *BaseRight = |
| 356 | ScopArrayInfo::getFromId(Id: Right)->getBasePtrOriginSAI(); |
| 357 | if (BaseLeft && BaseLeft == BaseRight) |
| 358 | return True; |
| 359 | |
| 360 | isl::set Params = S.getContext(); |
| 361 | |
| 362 | isl::ast_expr NonAliasGroup, MinExpr, MaxExpr; |
| 363 | |
| 364 | // In the following, we first check if any accesses will be empty under |
| 365 | // the execution context of the scop and do not code generate them if this |
| 366 | // is the case as isl will fail to derive valid AST expressions for such |
| 367 | // accesses. |
| 368 | |
| 369 | if (!AFirst.intersect_params(set: Params).domain().is_empty() && |
| 370 | !BSecond.intersect_params(set: Params).domain().is_empty()) { |
| 371 | MinExpr = Build.access_from(pma: AFirst).address_of(); |
| 372 | MaxExpr = Build.access_from(pma: BSecond).address_of(); |
| 373 | NonAliasGroup = MaxExpr.le(expr2: MinExpr); |
| 374 | } |
| 375 | |
| 376 | if (!BFirst.intersect_params(set: Params).domain().is_empty() && |
| 377 | !ASecond.intersect_params(set: Params).domain().is_empty()) { |
| 378 | MinExpr = Build.access_from(pma: BFirst).address_of(); |
| 379 | MaxExpr = Build.access_from(pma: ASecond).address_of(); |
| 380 | |
| 381 | isl::ast_expr Result = MaxExpr.le(expr2: MinExpr); |
| 382 | if (!NonAliasGroup.is_null()) |
| 383 | NonAliasGroup = isl::manage( |
| 384 | ptr: isl_ast_expr_or(expr1: NonAliasGroup.release(), expr2: Result.release())); |
| 385 | else |
| 386 | NonAliasGroup = Result; |
| 387 | } |
| 388 | |
| 389 | if (NonAliasGroup.is_null()) |
| 390 | NonAliasGroup = True; |
| 391 | |
| 392 | return NonAliasGroup; |
| 393 | } |
| 394 | |
| 395 | isl::ast_expr IslAst::buildRunCondition(Scop &S, const isl::ast_build &Build) { |
| 396 | isl::ast_expr RunCondition; |
| 397 | |
| 398 | // The conditions that need to be checked at run-time for this scop are |
| 399 | // available as an isl_set in the runtime check context from which we can |
| 400 | // directly derive a run-time condition. |
| 401 | auto PosCond = Build.expr_from(set: S.getAssumedContext()); |
| 402 | if (S.hasTrivialInvalidContext()) { |
| 403 | RunCondition = std::move(PosCond); |
| 404 | } else { |
| 405 | auto ZeroV = isl::val::zero(ctx: Build.ctx()); |
| 406 | auto NegCond = Build.expr_from(set: S.getInvalidContext()); |
| 407 | auto NotNegCond = |
| 408 | isl::ast_expr::from_val(v: std::move(ZeroV)).eq(expr2: std::move(NegCond)); |
| 409 | RunCondition = |
| 410 | isl::manage(ptr: isl_ast_expr_and(expr1: PosCond.release(), expr2: NotNegCond.release())); |
| 411 | } |
| 412 | |
| 413 | // Create the alias checks from the minimal/maximal accesses in each alias |
| 414 | // group which consists of read only and non read only (read write) accesses. |
| 415 | // This operation is by construction quadratic in the read-write pointers and |
| 416 | // linear in the read only pointers in each alias group. |
| 417 | for (const Scop::MinMaxVectorPairTy &MinMaxAccessPair : S.getAliasGroups()) { |
| 418 | auto &MinMaxReadWrite = MinMaxAccessPair.first; |
| 419 | auto &MinMaxReadOnly = MinMaxAccessPair.second; |
| 420 | auto RWAccEnd = MinMaxReadWrite.end(); |
| 421 | |
| 422 | for (auto RWAccIt0 = MinMaxReadWrite.begin(); RWAccIt0 != RWAccEnd; |
| 423 | ++RWAccIt0) { |
| 424 | for (auto RWAccIt1 = RWAccIt0 + 1; RWAccIt1 != RWAccEnd; ++RWAccIt1) |
| 425 | RunCondition = isl::manage(ptr: isl_ast_expr_and( |
| 426 | expr1: RunCondition.release(), |
| 427 | expr2: buildCondition(S, Build, It0: RWAccIt0, It1: RWAccIt1).release())); |
| 428 | for (const Scop::MinMaxAccessTy &ROAccIt : MinMaxReadOnly) |
| 429 | RunCondition = isl::manage(ptr: isl_ast_expr_and( |
| 430 | expr1: RunCondition.release(), |
| 431 | expr2: buildCondition(S, Build, It0: RWAccIt0, It1: &ROAccIt).release())); |
| 432 | } |
| 433 | } |
| 434 | |
| 435 | return RunCondition; |
| 436 | } |
| 437 | |
| 438 | /// Simple cost analysis for a given SCoP. |
| 439 | /// |
| 440 | /// TODO: Improve this analysis and extract it to make it usable in other |
| 441 | /// places too. |
| 442 | /// In order to improve the cost model we could either keep track of |
| 443 | /// performed optimizations (e.g., tiling) or compute properties on the |
| 444 | /// original as well as optimized SCoP (e.g., #stride-one-accesses). |
| 445 | static bool benefitsFromPolly(Scop &Scop, bool PerformParallelTest) { |
| 446 | if (PollyProcessUnprofitable) |
| 447 | return true; |
| 448 | |
| 449 | // Check if nothing interesting happened. |
| 450 | if (!PerformParallelTest && !Scop.isOptimized() && |
| 451 | Scop.getAliasGroups().empty()) |
| 452 | return false; |
| 453 | |
| 454 | // The default assumption is that Polly improves the code. |
| 455 | return true; |
| 456 | } |
| 457 | |
| 458 | /// Collect statistics for the syntax tree rooted at @p Ast. |
| 459 | static void walkAstForStatistics(const isl::ast_node &Ast) { |
| 460 | assert(!Ast.is_null()); |
| 461 | isl_ast_node_foreach_descendant_top_down( |
| 462 | node: Ast.get(), |
| 463 | fn: [](__isl_keep isl_ast_node *Node, void *User) -> isl_bool { |
| 464 | switch (isl_ast_node_get_type(node: Node)) { |
| 465 | case isl_ast_node_for: |
| 466 | NumForLoops++; |
| 467 | if (IslAstInfo::isParallel(Node: isl::manage_copy(ptr: Node))) |
| 468 | NumParallel++; |
| 469 | if (IslAstInfo::isInnermostParallel(Node: isl::manage_copy(ptr: Node))) |
| 470 | NumInnermostParallel++; |
| 471 | if (IslAstInfo::isOutermostParallel(Node: isl::manage_copy(ptr: Node))) |
| 472 | NumOutermostParallel++; |
| 473 | if (IslAstInfo::isReductionParallel(Node: isl::manage_copy(ptr: Node))) |
| 474 | NumReductionParallel++; |
| 475 | if (IslAstInfo::isExecutedInParallel(Node: isl::manage_copy(ptr: Node))) |
| 476 | NumExecutedInParallel++; |
| 477 | break; |
| 478 | |
| 479 | case isl_ast_node_if: |
| 480 | NumIfConditions++; |
| 481 | break; |
| 482 | |
| 483 | default: |
| 484 | break; |
| 485 | } |
| 486 | |
| 487 | // Continue traversing subtrees. |
| 488 | return isl_bool_true; |
| 489 | }, |
| 490 | user: nullptr); |
| 491 | } |
| 492 | |
| 493 | IslAst::IslAst(Scop &Scop) : S(Scop), Ctx(Scop.getSharedIslCtx()) {} |
| 494 | |
| 495 | IslAst::IslAst(IslAst &&O) |
| 496 | : S(O.S), Ctx(O.Ctx), RunCondition(std::move(O.RunCondition)), |
| 497 | Root(std::move(O.Root)) {} |
| 498 | |
| 499 | void IslAst::init(const Dependences &D) { |
| 500 | bool PerformParallelTest = PollyParallel || DetectParallel || |
| 501 | PollyVectorizerChoice != VECTORIZER_NONE; |
| 502 | auto ScheduleTree = S.getScheduleTree(); |
| 503 | |
| 504 | // Skip AST and code generation if there was no benefit achieved. |
| 505 | if (!benefitsFromPolly(Scop&: S, PerformParallelTest)) |
| 506 | return; |
| 507 | |
| 508 | auto ScopStats = S.getStatistics(); |
| 509 | ScopsBeneficial++; |
| 510 | BeneficialAffineLoops += ScopStats.NumAffineLoops; |
| 511 | BeneficialBoxedLoops += ScopStats.NumBoxedLoops; |
| 512 | |
| 513 | auto Ctx = S.getIslCtx(); |
| 514 | isl_options_set_ast_build_atomic_upper_bound(ctx: Ctx.get(), val: true); |
| 515 | isl_options_set_ast_build_detect_min_max(ctx: Ctx.get(), val: true); |
| 516 | isl_ast_build *Build; |
| 517 | AstBuildUserInfo BuildInfo; |
| 518 | |
| 519 | if (UseContext) |
| 520 | Build = isl_ast_build_from_context(set: S.getContext().release()); |
| 521 | else |
| 522 | Build = isl_ast_build_from_context( |
| 523 | set: isl_set_universe(space: S.getParamSpace().release())); |
| 524 | |
| 525 | Build = isl_ast_build_set_at_each_domain(build: Build, fn: AtEachDomain, user: nullptr); |
| 526 | |
| 527 | if (PerformParallelTest) { |
| 528 | BuildInfo.Deps = &D; |
| 529 | BuildInfo.InParallelFor = false; |
| 530 | BuildInfo.InSIMD = false; |
| 531 | |
| 532 | Build = isl_ast_build_set_before_each_for(build: Build, fn: &astBuildBeforeFor, |
| 533 | user: &BuildInfo); |
| 534 | Build = |
| 535 | isl_ast_build_set_after_each_for(build: Build, fn: &astBuildAfterFor, user: &BuildInfo); |
| 536 | |
| 537 | Build = isl_ast_build_set_before_each_mark(build: Build, fn: &astBuildBeforeMark, |
| 538 | user: &BuildInfo); |
| 539 | |
| 540 | Build = isl_ast_build_set_after_each_mark(build: Build, fn: &astBuildAfterMark, |
| 541 | user: &BuildInfo); |
| 542 | } |
| 543 | |
| 544 | RunCondition = buildRunCondition(S, Build: isl::manage_copy(ptr: Build)); |
| 545 | |
| 546 | Root = isl::manage( |
| 547 | ptr: isl_ast_build_node_from_schedule(build: Build, schedule: S.getScheduleTree().release())); |
| 548 | walkAstForStatistics(Ast: Root); |
| 549 | |
| 550 | isl_ast_build_free(build: Build); |
| 551 | } |
| 552 | |
| 553 | IslAst IslAst::create(Scop &Scop, const Dependences &D) { |
| 554 | IslAst Ast{Scop}; |
| 555 | Ast.init(D); |
| 556 | return Ast; |
| 557 | } |
| 558 | |
| 559 | isl::ast_node IslAst::getAst() { return Root; } |
| 560 | isl::ast_expr IslAst::getRunCondition() { return RunCondition; } |
| 561 | |
| 562 | isl::ast_node IslAstInfo::getAst() { return Ast.getAst(); } |
| 563 | isl::ast_expr IslAstInfo::getRunCondition() { return Ast.getRunCondition(); } |
| 564 | |
| 565 | IslAstUserPayload *IslAstInfo::getNodePayload(const isl::ast_node &Node) { |
| 566 | isl::id Id = Node.get_annotation(); |
| 567 | if (Id.is_null()) |
| 568 | return nullptr; |
| 569 | IslAstUserPayload *Payload = (IslAstUserPayload *)Id.get_user(); |
| 570 | return Payload; |
| 571 | } |
| 572 | |
| 573 | bool IslAstInfo::isInnermost(const isl::ast_node &Node) { |
| 574 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 575 | return Payload && Payload->IsInnermost; |
| 576 | } |
| 577 | |
| 578 | bool IslAstInfo::isParallel(const isl::ast_node &Node) { |
| 579 | return IslAstInfo::isInnermostParallel(Node) || |
| 580 | IslAstInfo::isOutermostParallel(Node); |
| 581 | } |
| 582 | |
| 583 | bool IslAstInfo::isInnermostParallel(const isl::ast_node &Node) { |
| 584 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 585 | return Payload && Payload->IsInnermostParallel; |
| 586 | } |
| 587 | |
| 588 | bool IslAstInfo::isOutermostParallel(const isl::ast_node &Node) { |
| 589 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 590 | return Payload && Payload->IsOutermostParallel; |
| 591 | } |
| 592 | |
| 593 | bool IslAstInfo::isReductionParallel(const isl::ast_node &Node) { |
| 594 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 595 | return Payload && Payload->IsReductionParallel; |
| 596 | } |
| 597 | |
| 598 | bool IslAstInfo::isExecutedInParallel(const isl::ast_node &Node) { |
| 599 | if (!PollyParallel) |
| 600 | return false; |
| 601 | |
| 602 | // Do not parallelize innermost loops. |
| 603 | // |
| 604 | // Parallelizing innermost loops is often not profitable, especially if |
| 605 | // they have a low number of iterations. |
| 606 | // |
| 607 | // TODO: Decide this based on the number of loop iterations that will be |
| 608 | // executed. This can possibly require run-time checks, which again |
| 609 | // raises the question of both run-time check overhead and code size |
| 610 | // costs. |
| 611 | if (!PollyParallelForce && isInnermost(Node)) |
| 612 | return false; |
| 613 | |
| 614 | return isOutermostParallel(Node) && !isReductionParallel(Node); |
| 615 | } |
| 616 | |
| 617 | isl::union_map IslAstInfo::getSchedule(const isl::ast_node &Node) { |
| 618 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 619 | return Payload ? Payload->Build.get_schedule() : isl::union_map(); |
| 620 | } |
| 621 | |
| 622 | isl::pw_aff |
| 623 | IslAstInfo::getMinimalDependenceDistance(const isl::ast_node &Node) { |
| 624 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 625 | return Payload ? Payload->MinimalDependenceDistance : isl::pw_aff(); |
| 626 | } |
| 627 | |
| 628 | IslAstInfo::MemoryAccessSet * |
| 629 | IslAstInfo::getBrokenReductions(const isl::ast_node &Node) { |
| 630 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 631 | return Payload ? &Payload->BrokenReductions : nullptr; |
| 632 | } |
| 633 | |
| 634 | isl::ast_build IslAstInfo::getBuild(const isl::ast_node &Node) { |
| 635 | IslAstUserPayload *Payload = getNodePayload(Node); |
| 636 | return Payload ? Payload->Build : isl::ast_build(); |
| 637 | } |
| 638 | |
| 639 | static std::unique_ptr<IslAstInfo> runIslAst( |
| 640 | Scop &Scop, |
| 641 | function_ref<const Dependences &(Dependences::AnalysisLevel)> GetDeps) { |
| 642 | ScopsProcessed++; |
| 643 | |
| 644 | const Dependences &D = GetDeps(Dependences::AL_Statement); |
| 645 | |
| 646 | if (D.getSharedIslCtx() != Scop.getSharedIslCtx()) { |
| 647 | POLLY_DEBUG( |
| 648 | dbgs() << "Got dependence analysis for different SCoP/isl_ctx\n" ); |
| 649 | return {}; |
| 650 | } |
| 651 | |
| 652 | std::unique_ptr<IslAstInfo> Ast = std::make_unique<IslAstInfo>(args&: Scop, args: D); |
| 653 | |
| 654 | POLLY_DEBUG({ |
| 655 | if (Ast) |
| 656 | Ast->print(dbgs()); |
| 657 | }); |
| 658 | |
| 659 | return Ast; |
| 660 | } |
| 661 | |
| 662 | IslAstInfo IslAstAnalysis::run(Scop &S, ScopAnalysisManager &SAM, |
| 663 | ScopStandardAnalysisResults &SAR) { |
| 664 | auto GetDeps = [&](Dependences::AnalysisLevel Lvl) -> const Dependences & { |
| 665 | return SAM.getResult<DependenceAnalysis>(IR&: S, ExtraArgs&: SAR).getDependences(Level: Lvl); |
| 666 | }; |
| 667 | |
| 668 | return std::move(*runIslAst(Scop&: S, GetDeps)); |
| 669 | } |
| 670 | |
| 671 | static __isl_give isl_printer *cbPrintUser(__isl_take isl_printer *P, |
| 672 | __isl_take isl_ast_print_options *O, |
| 673 | __isl_keep isl_ast_node *Node, |
| 674 | void *User) { |
| 675 | isl::ast_node_user AstNode = isl::manage_copy(ptr: Node).as<isl::ast_node_user>(); |
| 676 | isl::ast_expr NodeExpr = AstNode.expr(); |
| 677 | isl::ast_expr CallExpr = NodeExpr.get_op_arg(pos: 0); |
| 678 | isl::id CallExprId = CallExpr.get_id(); |
| 679 | ScopStmt *AccessStmt = (ScopStmt *)CallExprId.get_user(); |
| 680 | |
| 681 | P = isl_printer_start_line(p: P); |
| 682 | P = isl_printer_print_str(p: P, s: AccessStmt->getBaseName()); |
| 683 | P = isl_printer_print_str(p: P, s: "(" ); |
| 684 | P = isl_printer_end_line(p: P); |
| 685 | P = isl_printer_indent(p: P, indent: 2); |
| 686 | |
| 687 | for (MemoryAccess *MemAcc : *AccessStmt) { |
| 688 | P = isl_printer_start_line(p: P); |
| 689 | |
| 690 | if (MemAcc->isRead()) |
| 691 | P = isl_printer_print_str(p: P, s: "/* read */ &" ); |
| 692 | else |
| 693 | P = isl_printer_print_str(p: P, s: "/* write */ " ); |
| 694 | |
| 695 | isl::ast_build Build = IslAstInfo::getBuild(Node: isl::manage_copy(ptr: Node)); |
| 696 | if (MemAcc->isAffine()) { |
| 697 | isl_pw_multi_aff *PwmaPtr = |
| 698 | MemAcc->applyScheduleToAccessRelation(Schedule: Build.get_schedule()).release(); |
| 699 | isl::pw_multi_aff Pwma = isl::manage(ptr: PwmaPtr); |
| 700 | isl::ast_expr AccessExpr = Build.access_from(pma: Pwma); |
| 701 | P = isl_printer_print_ast_expr(p: P, expr: AccessExpr.get()); |
| 702 | } else { |
| 703 | P = isl_printer_print_str( |
| 704 | p: P, s: MemAcc->getLatestScopArrayInfo()->getName().c_str()); |
| 705 | P = isl_printer_print_str(p: P, s: "[*]" ); |
| 706 | } |
| 707 | P = isl_printer_end_line(p: P); |
| 708 | } |
| 709 | |
| 710 | P = isl_printer_indent(p: P, indent: -2); |
| 711 | P = isl_printer_start_line(p: P); |
| 712 | P = isl_printer_print_str(p: P, s: ");" ); |
| 713 | P = isl_printer_end_line(p: P); |
| 714 | |
| 715 | isl_ast_print_options_free(options: O); |
| 716 | return P; |
| 717 | } |
| 718 | |
| 719 | void IslAstInfo::print(raw_ostream &OS) { |
| 720 | isl_ast_print_options *Options; |
| 721 | isl::ast_node RootNode = Ast.getAst(); |
| 722 | Function &F = S.getFunction(); |
| 723 | |
| 724 | OS << ":: isl ast :: " << F.getName() << " :: " << S.getNameStr() << "\n" ; |
| 725 | |
| 726 | if (RootNode.is_null()) { |
| 727 | OS << ":: isl ast generation and code generation was skipped!\n\n" ; |
| 728 | OS << ":: This is either because no useful optimizations could be applied " |
| 729 | "(use -polly-process-unprofitable to enforce code generation) or " |
| 730 | "because earlier passes such as dependence analysis timed out (use " |
| 731 | "-polly-dependences-computeout=0 to set dependence analysis timeout " |
| 732 | "to infinity)\n\n" ; |
| 733 | return; |
| 734 | } |
| 735 | |
| 736 | isl::ast_expr RunCondition = Ast.getRunCondition(); |
| 737 | char *RtCStr, *AstStr; |
| 738 | |
| 739 | Options = isl_ast_print_options_alloc(ctx: S.getIslCtx().get()); |
| 740 | |
| 741 | if (PrintAccesses) |
| 742 | Options = |
| 743 | isl_ast_print_options_set_print_user(options: Options, print_user: cbPrintUser, user: nullptr); |
| 744 | Options = isl_ast_print_options_set_print_for(options: Options, print_for: cbPrintFor, user: nullptr); |
| 745 | |
| 746 | isl_printer *P = isl_printer_to_str(ctx: S.getIslCtx().get()); |
| 747 | P = isl_printer_set_output_format(p: P, ISL_FORMAT_C); |
| 748 | P = isl_printer_print_ast_expr(p: P, expr: RunCondition.get()); |
| 749 | RtCStr = isl_printer_get_str(printer: P); |
| 750 | P = isl_printer_flush(p: P); |
| 751 | P = isl_printer_indent(p: P, indent: 4); |
| 752 | P = isl_ast_node_print(node: RootNode.get(), p: P, options: Options); |
| 753 | AstStr = isl_printer_get_str(printer: P); |
| 754 | |
| 755 | POLLY_DEBUG({ |
| 756 | dbgs() << S.getContextStr() << "\n" ; |
| 757 | dbgs() << stringFromIslObj(S.getScheduleTree(), "null" ); |
| 758 | }); |
| 759 | OS << "\nif (" << RtCStr << ")\n\n" ; |
| 760 | OS << AstStr << "\n" ; |
| 761 | OS << "else\n" ; |
| 762 | OS << " { /* original code */ }\n\n" ; |
| 763 | |
| 764 | free(ptr: RtCStr); |
| 765 | free(ptr: AstStr); |
| 766 | |
| 767 | isl_printer_free(printer: P); |
| 768 | } |
| 769 | |
| 770 | AnalysisKey IslAstAnalysis::Key; |
| 771 | PreservedAnalyses IslAstPrinterPass::run(Scop &S, ScopAnalysisManager &SAM, |
| 772 | ScopStandardAnalysisResults &SAR, |
| 773 | SPMUpdater &U) { |
| 774 | auto &Ast = SAM.getResult<IslAstAnalysis>(IR&: S, ExtraArgs&: SAR); |
| 775 | Ast.print(OS); |
| 776 | return PreservedAnalyses::all(); |
| 777 | } |
| 778 | |
| 779 | void IslAstInfoWrapperPass::releaseMemory() { Ast.reset(); } |
| 780 | |
| 781 | bool IslAstInfoWrapperPass::runOnScop(Scop &Scop) { |
| 782 | auto GetDeps = [this](Dependences::AnalysisLevel Lvl) -> const Dependences & { |
| 783 | return getAnalysis<DependenceInfo>().getDependences(Level: Lvl); |
| 784 | }; |
| 785 | |
| 786 | Ast = runIslAst(Scop, GetDeps); |
| 787 | |
| 788 | return false; |
| 789 | } |
| 790 | |
| 791 | void IslAstInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { |
| 792 | // Get the Common analysis usage of ScopPasses. |
| 793 | ScopPass::getAnalysisUsage(AU); |
| 794 | AU.addRequiredTransitive<ScopInfoRegionPass>(); |
| 795 | AU.addRequired<DependenceInfo>(); |
| 796 | |
| 797 | AU.addPreserved<DependenceInfo>(); |
| 798 | } |
| 799 | |
| 800 | void IslAstInfoWrapperPass::printScop(raw_ostream &OS, Scop &S) const { |
| 801 | OS << "Printing analysis 'Polly - Generate an AST of the SCoP (isl)'" |
| 802 | << S.getName() << "' in function '" << S.getFunction().getName() << "':\n" ; |
| 803 | if (Ast) |
| 804 | Ast->print(OS); |
| 805 | } |
| 806 | |
| 807 | char IslAstInfoWrapperPass::ID = 0; |
| 808 | |
| 809 | Pass *polly::createIslAstInfoWrapperPassPass() { |
| 810 | return new IslAstInfoWrapperPass(); |
| 811 | } |
| 812 | |
| 813 | INITIALIZE_PASS_BEGIN(IslAstInfoWrapperPass, "polly-ast" , |
| 814 | "Polly - Generate an AST of the SCoP (isl)" , false, |
| 815 | false); |
| 816 | INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass); |
| 817 | INITIALIZE_PASS_DEPENDENCY(DependenceInfo); |
| 818 | INITIALIZE_PASS_END(IslAstInfoWrapperPass, "polly-ast" , |
| 819 | "Polly - Generate an AST from the SCoP (isl)" , false, false) |
| 820 | |
| 821 | //===----------------------------------------------------------------------===// |
| 822 | |
| 823 | namespace { |
| 824 | /// Print result from IslAstInfoWrapperPass. |
| 825 | class IslAstInfoPrinterLegacyPass final : public ScopPass { |
| 826 | public: |
| 827 | static char ID; |
| 828 | |
| 829 | IslAstInfoPrinterLegacyPass() : IslAstInfoPrinterLegacyPass(outs()) {} |
| 830 | explicit IslAstInfoPrinterLegacyPass(llvm::raw_ostream &OS) |
| 831 | : ScopPass(ID), OS(OS) {} |
| 832 | |
| 833 | bool runOnScop(Scop &S) override { |
| 834 | IslAstInfoWrapperPass &P = getAnalysis<IslAstInfoWrapperPass>(); |
| 835 | |
| 836 | OS << "Printing analysis '" << P.getPassName() << "' for region: '" |
| 837 | << S.getRegion().getNameStr() << "' in function '" |
| 838 | << S.getFunction().getName() << "':\n" ; |
| 839 | P.printScop(OS, S); |
| 840 | |
| 841 | return false; |
| 842 | } |
| 843 | |
| 844 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 845 | ScopPass::getAnalysisUsage(AU); |
| 846 | AU.addRequired<IslAstInfoWrapperPass>(); |
| 847 | AU.setPreservesAll(); |
| 848 | } |
| 849 | |
| 850 | private: |
| 851 | llvm::raw_ostream &OS; |
| 852 | }; |
| 853 | |
| 854 | char IslAstInfoPrinterLegacyPass::ID = 0; |
| 855 | } // namespace |
| 856 | |
| 857 | Pass *polly::createIslAstInfoPrinterLegacyPass(raw_ostream &OS) { |
| 858 | return new IslAstInfoPrinterLegacyPass(OS); |
| 859 | } |
| 860 | |
| 861 | INITIALIZE_PASS_BEGIN(IslAstInfoPrinterLegacyPass, "polly-print-ast" , |
| 862 | "Polly - Print the AST from a SCoP (isl)" , false, false); |
| 863 | INITIALIZE_PASS_DEPENDENCY(IslAstInfoWrapperPass); |
| 864 | INITIALIZE_PASS_END(IslAstInfoPrinterLegacyPass, "polly-print-ast" , |
| 865 | "Polly - Print the AST from a SCoP (isl)" , false, false) |
| 866 | |