| 1 | //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Make changes to isl's schedule tree data structure. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "polly/ScheduleTreeTransform.h" |
| 14 | #include "polly/Support/GICHelper.h" |
| 15 | #include "polly/Support/ISLTools.h" |
| 16 | #include "polly/Support/ScopHelper.h" |
| 17 | #include "llvm/ADT/ArrayRef.h" |
| 18 | #include "llvm/ADT/Sequence.h" |
| 19 | #include "llvm/ADT/SmallVector.h" |
| 20 | #include "llvm/IR/Constants.h" |
| 21 | #include "llvm/IR/Metadata.h" |
| 22 | #include "llvm/Transforms/Utils/UnrollLoop.h" |
| 23 | |
| 24 | #include "polly/Support/PollyDebug.h" |
| 25 | #define DEBUG_TYPE "polly-opt-isl" |
| 26 | |
| 27 | using namespace polly; |
| 28 | using namespace llvm; |
| 29 | |
| 30 | namespace { |
| 31 | |
| 32 | /// Copy the band member attributes (coincidence, loop type, isolate ast loop |
| 33 | /// type) from one band to another. |
| 34 | static isl::schedule_node_band |
| 35 | applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx, |
| 36 | const isl::schedule_node_band &Source, |
| 37 | int SourceIdx) { |
| 38 | bool Coincident = Source.member_get_coincident(pos: SourceIdx).release(); |
| 39 | Target = Target.member_set_coincident(pos: TargetIdx, coincident: Coincident); |
| 40 | |
| 41 | isl_ast_loop_type LoopType = |
| 42 | isl_schedule_node_band_member_get_ast_loop_type(node: Source.get(), pos: SourceIdx); |
| 43 | Target = isl::manage(ptr: isl_schedule_node_band_member_set_ast_loop_type( |
| 44 | node: Target.release(), pos: TargetIdx, type: LoopType)) |
| 45 | .as<isl::schedule_node_band>(); |
| 46 | |
| 47 | isl_ast_loop_type IsolateType = |
| 48 | isl_schedule_node_band_member_get_isolate_ast_loop_type(node: Source.get(), |
| 49 | pos: SourceIdx); |
| 50 | Target = isl::manage(ptr: isl_schedule_node_band_member_set_isolate_ast_loop_type( |
| 51 | node: Target.release(), pos: TargetIdx, type: IsolateType)) |
| 52 | .as<isl::schedule_node_band>(); |
| 53 | |
| 54 | return Target; |
| 55 | } |
| 56 | |
| 57 | /// Create a new band by copying members from another @p Band. @p IncludeCb |
| 58 | /// decides which band indices are copied to the result. |
| 59 | template <typename CbTy> |
| 60 | static isl::schedule rebuildBand(isl::schedule_node_band OldBand, |
| 61 | isl::schedule Body, CbTy IncludeCb) { |
| 62 | int NumBandDims = unsignedFromIslSize(Size: OldBand.n_member()); |
| 63 | |
| 64 | bool ExcludeAny = false; |
| 65 | bool IncludeAny = false; |
| 66 | for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) { |
| 67 | if (IncludeCb(OldIdx)) |
| 68 | IncludeAny = true; |
| 69 | else |
| 70 | ExcludeAny = true; |
| 71 | } |
| 72 | |
| 73 | // Instead of creating a zero-member band, don't create a band at all. |
| 74 | if (!IncludeAny) |
| 75 | return Body; |
| 76 | |
| 77 | isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule(); |
| 78 | isl::multi_union_pw_aff NewPartialSched; |
| 79 | if (ExcludeAny) { |
| 80 | // Select the included partial scatter functions. |
| 81 | isl::union_pw_aff_list List = PartialSched.list(); |
| 82 | int NewIdx = 0; |
| 83 | for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) { |
| 84 | if (IncludeCb(OldIdx)) |
| 85 | NewIdx += 1; |
| 86 | else |
| 87 | List = List.drop(first: NewIdx, n: 1); |
| 88 | } |
| 89 | isl::space ParamSpace = PartialSched.get_space().params(); |
| 90 | isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(dim: NewIdx); |
| 91 | NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List); |
| 92 | } else { |
| 93 | // Just reuse original scatter function of copying all of them. |
| 94 | NewPartialSched = PartialSched; |
| 95 | } |
| 96 | |
| 97 | // Create the new band node. |
| 98 | isl::schedule_node_band NewBand = |
| 99 | Body.insert_partial_schedule(partial: NewPartialSched) |
| 100 | .get_root() |
| 101 | .child(pos: 0) |
| 102 | .as<isl::schedule_node_band>(); |
| 103 | |
| 104 | // If OldBand was permutable, so is the new one, even if some dimensions are |
| 105 | // missing. |
| 106 | bool IsPermutable = OldBand.permutable().release(); |
| 107 | NewBand = NewBand.set_permutable(IsPermutable); |
| 108 | |
| 109 | // Reapply member attributes. |
| 110 | int NewIdx = 0; |
| 111 | for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) { |
| 112 | if (!IncludeCb(OldIdx)) |
| 113 | continue; |
| 114 | NewBand = |
| 115 | applyBandMemberAttributes(Target: std::move(NewBand), TargetIdx: NewIdx, Source: OldBand, SourceIdx: OldIdx); |
| 116 | NewIdx += 1; |
| 117 | } |
| 118 | |
| 119 | return NewBand.get_schedule(); |
| 120 | } |
| 121 | |
| 122 | /// Rewrite a schedule tree by reconstructing it bottom-up. |
| 123 | /// |
| 124 | /// By default, the original schedule tree is reconstructed. To build a |
| 125 | /// different tree, redefine visitor methods in a derived class (CRTP). |
| 126 | /// |
| 127 | /// Note that AST build options are not applied; Setting the isolate[] option |
| 128 | /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence, |
| 129 | /// AST build options must be set after the tree has been constructed. |
| 130 | template <typename Derived, typename... Args> |
| 131 | struct ScheduleTreeRewriter |
| 132 | : RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> { |
| 133 | Derived &getDerived() { return *static_cast<Derived *>(this); } |
| 134 | const Derived &getDerived() const { |
| 135 | return *static_cast<const Derived *>(this); |
| 136 | } |
| 137 | |
| 138 | isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) { |
| 139 | // Every schedule_tree already has a domain node, no need to add one. |
| 140 | return getDerived().visit(Node.first_child(), std::forward<Args>(args)...); |
| 141 | } |
| 142 | |
| 143 | isl::schedule visitBand(isl::schedule_node_band Band, Args... args) { |
| 144 | isl::schedule NewChild = |
| 145 | getDerived().visit(Band.child(pos: 0), std::forward<Args>(args)...); |
| 146 | return rebuildBand(Band, NewChild, [](int) { return true; }); |
| 147 | } |
| 148 | |
| 149 | isl::schedule visitSequence(isl::schedule_node_sequence Sequence, |
| 150 | Args... args) { |
| 151 | int NumChildren = isl_schedule_node_n_children(node: Sequence.get()); |
| 152 | isl::schedule Result = |
| 153 | getDerived().visit(Sequence.child(pos: 0), std::forward<Args>(args)...); |
| 154 | for (int i = 1; i < NumChildren; i += 1) |
| 155 | Result = Result.sequence( |
| 156 | schedule2: getDerived().visit(Sequence.child(pos: i), std::forward<Args>(args)...)); |
| 157 | return Result; |
| 158 | } |
| 159 | |
| 160 | isl::schedule visitSet(isl::schedule_node_set Set, Args... args) { |
| 161 | int NumChildren = isl_schedule_node_n_children(node: Set.get()); |
| 162 | isl::schedule Result = |
| 163 | getDerived().visit(Set.child(pos: 0), std::forward<Args>(args)...); |
| 164 | for (int i = 1; i < NumChildren; i += 1) |
| 165 | Result = isl::manage( |
| 166 | isl_schedule_set(Result.release(), |
| 167 | getDerived() |
| 168 | .visit(Set.child(pos: i), std::forward<Args>(args)...) |
| 169 | .release())); |
| 170 | return Result; |
| 171 | } |
| 172 | |
| 173 | isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { |
| 174 | return isl::schedule::from_domain(domain: Leaf.get_domain()); |
| 175 | } |
| 176 | |
| 177 | isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) { |
| 178 | |
| 179 | isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id(); |
| 180 | isl::schedule_node NewChild = |
| 181 | getDerived() |
| 182 | .visit(Mark.first_child(), std::forward<Args>(args)...) |
| 183 | .get_root() |
| 184 | .first_child(); |
| 185 | return NewChild.insert_mark(mark: TheMark).get_schedule(); |
| 186 | } |
| 187 | |
| 188 | isl::schedule visitExtension(isl::schedule_node_extension Extension, |
| 189 | Args... args) { |
| 190 | isl::union_map TheExtension = |
| 191 | Extension.as<isl::schedule_node_extension>().get_extension(); |
| 192 | isl::schedule_node NewChild = getDerived() |
| 193 | .visit(Extension.child(pos: 0), args...) |
| 194 | .get_root() |
| 195 | .first_child(); |
| 196 | isl::schedule_node NewExtension = |
| 197 | isl::schedule_node::from_extension(extension: TheExtension); |
| 198 | return NewChild.graft_before(graft: NewExtension).get_schedule(); |
| 199 | } |
| 200 | |
| 201 | isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) { |
| 202 | isl::union_set FilterDomain = |
| 203 | Filter.as<isl::schedule_node_filter>().get_filter(); |
| 204 | isl::schedule NewSchedule = |
| 205 | getDerived().visit(Filter.child(pos: 0), std::forward<Args>(args)...); |
| 206 | return NewSchedule.intersect_domain(domain: FilterDomain); |
| 207 | } |
| 208 | |
| 209 | isl::schedule visitNode(isl::schedule_node Node, Args... args) { |
| 210 | llvm_unreachable("Not implemented" ); |
| 211 | } |
| 212 | }; |
| 213 | |
| 214 | /// Rewrite the schedule tree without any changes. Useful to copy a subtree into |
| 215 | /// a new schedule, discarding everything but. |
| 216 | struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {}; |
| 217 | |
| 218 | /// Rewrite a schedule tree to an equivalent one without extension nodes. |
| 219 | /// |
| 220 | /// Each visit method takes two additional arguments: |
| 221 | /// |
| 222 | /// * The new domain the node, which is the inherited domain plus any domains |
| 223 | /// added by extension nodes. |
| 224 | /// |
| 225 | /// * A map of extension domains of all children is returned; it is required by |
| 226 | /// band nodes to schedule the additional domains at the same position as the |
| 227 | /// extension node would. |
| 228 | /// |
| 229 | struct ExtensionNodeRewriter final |
| 230 | : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &, |
| 231 | isl::union_map &> { |
| 232 | using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter, |
| 233 | const isl::union_set &, isl::union_map &>; |
| 234 | BaseTy &getBase() { return *this; } |
| 235 | const BaseTy &getBase() const { return *this; } |
| 236 | |
| 237 | isl::schedule visitSchedule(isl::schedule Schedule) { |
| 238 | isl::union_map Extensions; |
| 239 | isl::schedule Result = |
| 240 | visit(Node: Schedule.get_root(), args: Schedule.get_domain(), args&: Extensions); |
| 241 | assert(!Extensions.is_null() && Extensions.is_empty()); |
| 242 | return Result; |
| 243 | } |
| 244 | |
| 245 | isl::schedule visitSequence(isl::schedule_node_sequence Sequence, |
| 246 | const isl::union_set &Domain, |
| 247 | isl::union_map &Extensions) { |
| 248 | int NumChildren = isl_schedule_node_n_children(node: Sequence.get()); |
| 249 | isl::schedule NewNode = visit(Node: Sequence.first_child(), args: Domain, args&: Extensions); |
| 250 | for (int i = 1; i < NumChildren; i += 1) { |
| 251 | isl::schedule_node OldChild = Sequence.child(pos: i); |
| 252 | isl::union_map NewChildExtensions; |
| 253 | isl::schedule NewChildNode = visit(Node: OldChild, args: Domain, args&: NewChildExtensions); |
| 254 | NewNode = NewNode.sequence(schedule2: NewChildNode); |
| 255 | Extensions = Extensions.unite(umap2: NewChildExtensions); |
| 256 | } |
| 257 | return NewNode; |
| 258 | } |
| 259 | |
| 260 | isl::schedule visitSet(isl::schedule_node_set Set, |
| 261 | const isl::union_set &Domain, |
| 262 | isl::union_map &Extensions) { |
| 263 | int NumChildren = isl_schedule_node_n_children(node: Set.get()); |
| 264 | isl::schedule NewNode = visit(Node: Set.first_child(), args: Domain, args&: Extensions); |
| 265 | for (int i = 1; i < NumChildren; i += 1) { |
| 266 | isl::schedule_node OldChild = Set.child(pos: i); |
| 267 | isl::union_map NewChildExtensions; |
| 268 | isl::schedule NewChildNode = visit(Node: OldChild, args: Domain, args&: NewChildExtensions); |
| 269 | NewNode = isl::manage( |
| 270 | ptr: isl_schedule_set(schedule1: NewNode.release(), schedule2: NewChildNode.release())); |
| 271 | Extensions = Extensions.unite(umap2: NewChildExtensions); |
| 272 | } |
| 273 | return NewNode; |
| 274 | } |
| 275 | |
| 276 | isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, |
| 277 | const isl::union_set &Domain, |
| 278 | isl::union_map &Extensions) { |
| 279 | Extensions = isl::union_map::empty(ctx: Leaf.ctx()); |
| 280 | return isl::schedule::from_domain(domain: Domain); |
| 281 | } |
| 282 | |
| 283 | isl::schedule visitBand(isl::schedule_node_band OldNode, |
| 284 | const isl::union_set &Domain, |
| 285 | isl::union_map &OuterExtensions) { |
| 286 | isl::schedule_node OldChild = OldNode.first_child(); |
| 287 | isl::multi_union_pw_aff PartialSched = |
| 288 | isl::manage(ptr: isl_schedule_node_band_get_partial_schedule(node: OldNode.get())); |
| 289 | |
| 290 | isl::union_map NewChildExtensions; |
| 291 | isl::schedule NewChild = visit(Node: OldChild, args: Domain, args&: NewChildExtensions); |
| 292 | |
| 293 | // Add the extensions to the partial schedule. |
| 294 | OuterExtensions = isl::union_map::empty(ctx: NewChildExtensions.ctx()); |
| 295 | isl::union_map NewPartialSchedMap = isl::union_map::from(mupa: PartialSched); |
| 296 | unsigned BandDims = isl_schedule_node_band_n_member(node: OldNode.get()); |
| 297 | for (isl::map Ext : NewChildExtensions.get_map_list()) { |
| 298 | unsigned ExtDims = unsignedFromIslSize(Size: Ext.domain_tuple_dim()); |
| 299 | assert(ExtDims >= BandDims); |
| 300 | unsigned OuterDims = ExtDims - BandDims; |
| 301 | |
| 302 | isl::map BandSched = |
| 303 | Ext.project_out(type: isl::dim::in, first: 0, n: OuterDims).reverse(); |
| 304 | NewPartialSchedMap = NewPartialSchedMap.unite(umap2: BandSched); |
| 305 | |
| 306 | // There might be more outer bands that have to schedule the extensions. |
| 307 | if (OuterDims > 0) { |
| 308 | isl::map OuterSched = |
| 309 | Ext.project_out(type: isl::dim::in, first: OuterDims, n: BandDims); |
| 310 | OuterExtensions = OuterExtensions.unite(umap2: OuterSched); |
| 311 | } |
| 312 | } |
| 313 | isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff = |
| 314 | isl::multi_union_pw_aff::from_union_map(umap: NewPartialSchedMap); |
| 315 | isl::schedule_node NewNode = |
| 316 | NewChild.insert_partial_schedule(partial: NewPartialSchedAsAsMultiUnionPwAff) |
| 317 | .get_root() |
| 318 | .child(pos: 0); |
| 319 | |
| 320 | // Reapply permutability and coincidence attributes. |
| 321 | NewNode = isl::manage(ptr: isl_schedule_node_band_set_permutable( |
| 322 | node: NewNode.release(), |
| 323 | permutable: isl_schedule_node_band_get_permutable(node: OldNode.get()))); |
| 324 | for (unsigned i = 0; i < BandDims; i += 1) |
| 325 | NewNode = applyBandMemberAttributes(Target: NewNode.as<isl::schedule_node_band>(), |
| 326 | TargetIdx: i, Source: OldNode, SourceIdx: i); |
| 327 | |
| 328 | return NewNode.get_schedule(); |
| 329 | } |
| 330 | |
| 331 | isl::schedule visitFilter(isl::schedule_node_filter Filter, |
| 332 | const isl::union_set &Domain, |
| 333 | isl::union_map &Extensions) { |
| 334 | isl::union_set FilterDomain = |
| 335 | Filter.as<isl::schedule_node_filter>().get_filter(); |
| 336 | isl::union_set NewDomain = Domain.intersect(uset2: FilterDomain); |
| 337 | |
| 338 | // A filter is added implicitly if necessary when joining schedule trees. |
| 339 | return visit(Node: Filter.first_child(), args: NewDomain, args&: Extensions); |
| 340 | } |
| 341 | |
| 342 | isl::schedule visitExtension(isl::schedule_node_extension Extension, |
| 343 | const isl::union_set &Domain, |
| 344 | isl::union_map &Extensions) { |
| 345 | isl::union_map ExtDomain = |
| 346 | Extension.as<isl::schedule_node_extension>().get_extension(); |
| 347 | isl::union_set NewDomain = Domain.unite(uset2: ExtDomain.range()); |
| 348 | isl::union_map ChildExtensions; |
| 349 | isl::schedule NewChild = |
| 350 | visit(Node: Extension.first_child(), args: NewDomain, args&: ChildExtensions); |
| 351 | Extensions = ChildExtensions.unite(umap2: ExtDomain); |
| 352 | return NewChild; |
| 353 | } |
| 354 | }; |
| 355 | |
| 356 | /// Collect all AST build options in any schedule tree band. |
| 357 | /// |
| 358 | /// ScheduleTreeRewriter cannot apply the schedule tree options. This class |
| 359 | /// collects these options to apply them later. |
| 360 | struct CollectASTBuildOptions final |
| 361 | : RecursiveScheduleTreeVisitor<CollectASTBuildOptions> { |
| 362 | using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>; |
| 363 | BaseTy &getBase() { return *this; } |
| 364 | const BaseTy &getBase() const { return *this; } |
| 365 | |
| 366 | llvm::SmallVector<isl::union_set, 8> ASTBuildOptions; |
| 367 | |
| 368 | void visitBand(isl::schedule_node_band Band) { |
| 369 | ASTBuildOptions.push_back( |
| 370 | Elt: isl::manage(ptr: isl_schedule_node_band_get_ast_build_options(node: Band.get()))); |
| 371 | return getBase().visitBand(Band); |
| 372 | } |
| 373 | }; |
| 374 | |
| 375 | /// Apply AST build options to the bands in a schedule tree. |
| 376 | /// |
| 377 | /// This rewrites a schedule tree with the AST build options applied. We assume |
| 378 | /// that the band nodes are visited in the same order as they were when the |
| 379 | /// build options were collected, typically by CollectASTBuildOptions. |
| 380 | struct ApplyASTBuildOptions final : ScheduleNodeRewriter<ApplyASTBuildOptions> { |
| 381 | using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; |
| 382 | BaseTy &getBase() { return *this; } |
| 383 | const BaseTy &getBase() const { return *this; } |
| 384 | |
| 385 | size_t Pos; |
| 386 | llvm::ArrayRef<isl::union_set> ASTBuildOptions; |
| 387 | |
| 388 | ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) |
| 389 | : ASTBuildOptions(ASTBuildOptions) {} |
| 390 | |
| 391 | isl::schedule visitSchedule(isl::schedule Schedule) { |
| 392 | Pos = 0; |
| 393 | isl::schedule Result = visit(Schedule).get_schedule(); |
| 394 | assert(Pos == ASTBuildOptions.size() && |
| 395 | "AST build options must match to band nodes" ); |
| 396 | return Result; |
| 397 | } |
| 398 | |
| 399 | isl::schedule_node visitBand(isl::schedule_node_band Band) { |
| 400 | isl::schedule_node_band Result = |
| 401 | Band.set_ast_build_options(ASTBuildOptions[Pos]); |
| 402 | Pos += 1; |
| 403 | return getBase().visitBand(Band: Result); |
| 404 | } |
| 405 | }; |
| 406 | |
| 407 | /// Return whether the schedule contains an extension node. |
| 408 | static bool containsExtensionNode(isl::schedule Schedule) { |
| 409 | assert(!Schedule.is_null()); |
| 410 | |
| 411 | auto Callback = [](__isl_keep isl_schedule_node *Node, |
| 412 | void *User) -> isl_bool { |
| 413 | if (isl_schedule_node_get_type(node: Node) == isl_schedule_node_extension) { |
| 414 | // Stop walking the schedule tree. |
| 415 | return isl_bool_error; |
| 416 | } |
| 417 | |
| 418 | // Continue searching the subtree. |
| 419 | return isl_bool_true; |
| 420 | }; |
| 421 | isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( |
| 422 | sched: Schedule.get(), fn: Callback, user: nullptr); |
| 423 | |
| 424 | // We assume that the traversal itself does not fail, i.e. the only reason to |
| 425 | // return isl_stat_error is that an extension node was found. |
| 426 | return RetVal == isl_stat_error; |
| 427 | } |
| 428 | |
| 429 | /// Find a named MDNode property in a LoopID. |
| 430 | static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { |
| 431 | return dyn_cast_or_null<MDNode>( |
| 432 | Val: findMetadataOperand(LoopMD, Name).value_or(u: nullptr)); |
| 433 | } |
| 434 | |
| 435 | /// Is this node of type mark? |
| 436 | static bool isMark(const isl::schedule_node &Node) { |
| 437 | return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_mark; |
| 438 | } |
| 439 | |
| 440 | /// Is this node of type band? |
| 441 | static bool isBand(const isl::schedule_node &Node) { |
| 442 | return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_band; |
| 443 | } |
| 444 | |
| 445 | #ifndef NDEBUG |
| 446 | /// Is this node a band of a single dimension (i.e. could represent a loop)? |
| 447 | static bool isBandWithSingleLoop(const isl::schedule_node &Node) { |
| 448 | return isBand(Node) && isl_schedule_node_band_n_member(node: Node.get()) == 1; |
| 449 | } |
| 450 | #endif |
| 451 | |
| 452 | static bool isLeaf(const isl::schedule_node &Node) { |
| 453 | return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_leaf; |
| 454 | } |
| 455 | |
| 456 | /// Create an isl::id representing the output loop after a transformation. |
| 457 | static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { |
| 458 | // Don't need to id the followup. |
| 459 | // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by |
| 460 | // user followup-MD |
| 461 | if (!FollowupLoopMD) |
| 462 | return {}; |
| 463 | |
| 464 | BandAttr *Attr = new BandAttr(); |
| 465 | Attr->Metadata = FollowupLoopMD; |
| 466 | return getIslLoopAttr(Ctx, Attr); |
| 467 | } |
| 468 | |
| 469 | /// A loop consists of a band and an optional marker that wraps it. Return the |
| 470 | /// outermost of the two. |
| 471 | |
| 472 | /// That is, either the mark or, if there is not mark, the loop itself. Can |
| 473 | /// start with either the mark or the band. |
| 474 | static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { |
| 475 | if (isBandMark(Node: BandOrMark)) { |
| 476 | assert(isBandWithSingleLoop(BandOrMark.child(0))); |
| 477 | return BandOrMark; |
| 478 | } |
| 479 | assert(isBandWithSingleLoop(BandOrMark)); |
| 480 | |
| 481 | isl::schedule_node Mark = BandOrMark.parent(); |
| 482 | if (isBandMark(Node: Mark)) |
| 483 | return Mark; |
| 484 | |
| 485 | // Band has no loop marker. |
| 486 | return BandOrMark; |
| 487 | } |
| 488 | |
| 489 | static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, |
| 490 | BandAttr *&Attr) { |
| 491 | MarkOrBand = moveToBandMark(BandOrMark: MarkOrBand); |
| 492 | |
| 493 | isl::schedule_node Band; |
| 494 | if (isMark(Node: MarkOrBand)) { |
| 495 | Attr = getLoopAttr(Id: MarkOrBand.as<isl::schedule_node_mark>().get_id()); |
| 496 | Band = isl::manage(ptr: isl_schedule_node_delete(node: MarkOrBand.release())); |
| 497 | } else { |
| 498 | Attr = nullptr; |
| 499 | Band = MarkOrBand; |
| 500 | } |
| 501 | |
| 502 | assert(isBandWithSingleLoop(Band)); |
| 503 | return Band; |
| 504 | } |
| 505 | |
| 506 | /// Remove the mark that wraps a loop. Return the band representing the loop. |
| 507 | static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { |
| 508 | BandAttr *Attr; |
| 509 | return removeMark(MarkOrBand, Attr); |
| 510 | } |
| 511 | |
| 512 | static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { |
| 513 | assert(isBand(Band)); |
| 514 | assert(moveToBandMark(Band).is_equal(Band) && |
| 515 | "Don't add a two marks for a band" ); |
| 516 | |
| 517 | return Band.insert_mark(mark: Mark).child(pos: 0); |
| 518 | } |
| 519 | |
| 520 | /// Return the (one-dimensional) set of numbers that are divisible by @p Factor |
| 521 | /// with remainder @p Offset. |
| 522 | /// |
| 523 | /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 } |
| 524 | /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 } |
| 525 | /// |
| 526 | static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, |
| 527 | long Offset) { |
| 528 | isl::val ValFactor{Ctx, Factor}; |
| 529 | isl::val ValOffset{Ctx, Offset}; |
| 530 | |
| 531 | isl::space Unispace{Ctx, 0, 1}; |
| 532 | isl::local_space LUnispace{Unispace}; |
| 533 | isl::aff AffFactor{LUnispace, ValFactor}; |
| 534 | isl::aff AffOffset{LUnispace, ValOffset}; |
| 535 | |
| 536 | isl::aff Id = isl::aff::var_on_domain(ls: LUnispace, type: isl::dim::out, pos: 0); |
| 537 | isl::aff DivMul = Id.mod(mod: ValFactor); |
| 538 | isl::basic_map Divisible = isl::basic_map::from_aff(aff: DivMul); |
| 539 | isl::basic_map Modulo = Divisible.fix_val(type: isl::dim::out, pos: 0, v: ValOffset); |
| 540 | return Modulo.domain(); |
| 541 | } |
| 542 | |
| 543 | /// Make the last dimension of Set to take values from 0 to VectorWidth - 1. |
| 544 | /// |
| 545 | /// @param Set A set, which should be modified. |
| 546 | /// @param VectorWidth A parameter, which determines the constraint. |
| 547 | static isl::set addExtentConstraints(isl::set Set, int VectorWidth) { |
| 548 | unsigned Dims = unsignedFromIslSize(Size: Set.tuple_dim()); |
| 549 | assert(Dims >= 1); |
| 550 | isl::space Space = Set.get_space(); |
| 551 | isl::local_space LocalSpace = isl::local_space(Space); |
| 552 | isl::constraint ExtConstr = isl::constraint::alloc_inequality(ls: LocalSpace); |
| 553 | ExtConstr = ExtConstr.set_constant_si(0); |
| 554 | ExtConstr = ExtConstr.set_coefficient_si(type: isl::dim::set, pos: Dims - 1, v: 1); |
| 555 | Set = Set.add_constraint(constraint: ExtConstr); |
| 556 | ExtConstr = isl::constraint::alloc_inequality(ls: LocalSpace); |
| 557 | ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1); |
| 558 | ExtConstr = ExtConstr.set_coefficient_si(type: isl::dim::set, pos: Dims - 1, v: -1); |
| 559 | return Set.add_constraint(constraint: ExtConstr); |
| 560 | } |
| 561 | |
| 562 | /// Collapse perfectly nested bands into a single band. |
| 563 | class BandCollapseRewriter final |
| 564 | : public ScheduleTreeRewriter<BandCollapseRewriter> { |
| 565 | private: |
| 566 | using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>; |
| 567 | BaseTy &getBase() { return *this; } |
| 568 | const BaseTy &getBase() const { return *this; } |
| 569 | |
| 570 | public: |
| 571 | isl::schedule visitBand(isl::schedule_node_band RootBand) { |
| 572 | isl::schedule_node_band Band = RootBand; |
| 573 | isl::ctx Ctx = Band.ctx(); |
| 574 | |
| 575 | // Do not merge permutable band to avoid losing the permutability property. |
| 576 | // Cannot collapse even two permutable loops, they might be permutable |
| 577 | // individually, but not necassarily across. |
| 578 | if (unsignedFromIslSize(Size: Band.n_member()) > 1u && Band.permutable()) |
| 579 | return getBase().visitBand(Band); |
| 580 | |
| 581 | // Find collapsible bands. |
| 582 | SmallVector<isl::schedule_node_band> Nest; |
| 583 | int NumTotalLoops = 0; |
| 584 | isl::schedule_node Body; |
| 585 | while (true) { |
| 586 | Nest.push_back(Elt: Band); |
| 587 | NumTotalLoops += unsignedFromIslSize(Size: Band.n_member()); |
| 588 | Body = Band.first_child(); |
| 589 | if (!Body.isa<isl::schedule_node_band>()) |
| 590 | break; |
| 591 | Band = Body.as<isl::schedule_node_band>(); |
| 592 | |
| 593 | // Do not include next band if it is permutable to not lose its |
| 594 | // permutability property. |
| 595 | if (unsignedFromIslSize(Size: Band.n_member()) > 1u && Band.permutable()) |
| 596 | break; |
| 597 | } |
| 598 | |
| 599 | // Nothing to collapse, preserve permutability. |
| 600 | if (Nest.size() <= 1) |
| 601 | return getBase().visitBand(Band); |
| 602 | |
| 603 | POLLY_DEBUG({ |
| 604 | dbgs() << "Found loops to collapse between\n" ; |
| 605 | dumpIslObj(RootBand, dbgs()); |
| 606 | dbgs() << "and\n" ; |
| 607 | dumpIslObj(Body, dbgs()); |
| 608 | dbgs() << "\n" ; |
| 609 | }); |
| 610 | |
| 611 | isl::schedule NewBody = visit(Node: Body); |
| 612 | |
| 613 | // Collect partial schedules from all members. |
| 614 | isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops}; |
| 615 | for (isl::schedule_node_band Band : Nest) { |
| 616 | int NumLoops = unsignedFromIslSize(Size: Band.n_member()); |
| 617 | isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule(); |
| 618 | for (auto j : seq<int>(Begin: 0, End: NumLoops)) |
| 619 | PartScheds = PartScheds.add(el: BandScheds.at(pos: j)); |
| 620 | } |
| 621 | isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops); |
| 622 | isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds}; |
| 623 | |
| 624 | isl::schedule_node_band CollapsedBand = |
| 625 | NewBody.insert_partial_schedule(partial: PartSchedsMulti) |
| 626 | .get_root() |
| 627 | .first_child() |
| 628 | .as<isl::schedule_node_band>(); |
| 629 | |
| 630 | // Copy over loop attributes form original bands. |
| 631 | int LoopIdx = 0; |
| 632 | for (isl::schedule_node_band Band : Nest) { |
| 633 | int NumLoops = unsignedFromIslSize(Size: Band.n_member()); |
| 634 | for (int i : seq<int>(Begin: 0, End: NumLoops)) { |
| 635 | CollapsedBand = applyBandMemberAttributes(Target: std::move(CollapsedBand), |
| 636 | TargetIdx: LoopIdx, Source: Band, SourceIdx: i); |
| 637 | LoopIdx += 1; |
| 638 | } |
| 639 | } |
| 640 | assert(LoopIdx == NumTotalLoops && |
| 641 | "Expect the same number of loops to add up again" ); |
| 642 | |
| 643 | return CollapsedBand.get_schedule(); |
| 644 | } |
| 645 | }; |
| 646 | |
| 647 | static isl::schedule collapseBands(isl::schedule Sched) { |
| 648 | POLLY_DEBUG(dbgs() << "Collapse bands in schedule\n" ); |
| 649 | BandCollapseRewriter Rewriter; |
| 650 | return Rewriter.visit(Schedule: Sched); |
| 651 | } |
| 652 | |
| 653 | /// Collect sequentially executed bands (or anything else), even if nested in a |
| 654 | /// mark or other nodes whose child is executed just once. If we can |
| 655 | /// successfully fuse the bands, we allow them to be removed. |
| 656 | static void collectPotentiallyFusableBands( |
| 657 | isl::schedule_node Node, |
| 658 | SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>> |
| 659 | &ScheduleBands, |
| 660 | const isl::schedule_node &DirectChild) { |
| 661 | switch (isl_schedule_node_get_type(node: Node.get())) { |
| 662 | case isl_schedule_node_sequence: |
| 663 | case isl_schedule_node_set: |
| 664 | case isl_schedule_node_mark: |
| 665 | case isl_schedule_node_domain: |
| 666 | case isl_schedule_node_filter: |
| 667 | if (Node.has_children()) { |
| 668 | isl::schedule_node C = Node.first_child(); |
| 669 | while (true) { |
| 670 | collectPotentiallyFusableBands(Node: C, ScheduleBands, DirectChild); |
| 671 | if (!C.has_next_sibling()) |
| 672 | break; |
| 673 | C = C.next_sibling(); |
| 674 | } |
| 675 | } |
| 676 | break; |
| 677 | |
| 678 | default: |
| 679 | // Something that does not execute suquentially (e.g. a band) |
| 680 | ScheduleBands.push_back(Elt: {Node, DirectChild}); |
| 681 | break; |
| 682 | } |
| 683 | } |
| 684 | |
| 685 | /// Remove dependencies that are resolved by @p PartSched. That is, remove |
| 686 | /// everything that we already know is executed in-order. |
| 687 | static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched, |
| 688 | isl::union_map Deps) { |
| 689 | unsigned NumDims = getNumScatterDims(Schedule: PartSched); |
| 690 | auto ParamSpace = PartSched.get_space().params(); |
| 691 | |
| 692 | // { Scatter[] } |
| 693 | isl::space ScatterSpace = |
| 694 | ParamSpace.set_from_params().add_dims(type: isl::dim::set, n: NumDims); |
| 695 | |
| 696 | // { Scatter[] -> Domain[] } |
| 697 | isl::union_map PartSchedRev = PartSched.reverse(); |
| 698 | |
| 699 | // { Scatter[] -> Scatter[] } |
| 700 | isl::map MaybeBefore = isl::map::lex_le(set_space: ScatterSpace); |
| 701 | |
| 702 | // { Domain[] -> Domain[] } |
| 703 | isl::union_map DomMaybeBefore = |
| 704 | MaybeBefore.apply_domain(umap2: PartSchedRev).apply_range(umap2: PartSchedRev); |
| 705 | |
| 706 | // { Domain[] -> Domain[] } |
| 707 | isl::union_map ChildRemainingDeps = Deps.intersect(umap2: DomMaybeBefore); |
| 708 | |
| 709 | return ChildRemainingDeps; |
| 710 | } |
| 711 | |
| 712 | /// Remove dependencies that are resolved by executing them in the order |
| 713 | /// specified by @p Domains; |
| 714 | static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains, |
| 715 | isl::union_map Deps) { |
| 716 | isl::ctx Ctx = Deps.ctx(); |
| 717 | isl::space ParamSpace = Deps.get_space().params(); |
| 718 | |
| 719 | // Create a partial schedule mapping to constants that reflect the execution |
| 720 | // order. |
| 721 | isl::union_map PartialSchedules = isl::union_map::empty(ctx: Ctx); |
| 722 | for (auto P : enumerate(First&: Domains)) { |
| 723 | isl::val ExecTime = isl::val(Ctx, P.index()); |
| 724 | isl::union_pw_aff DomSched{P.value(), ExecTime}; |
| 725 | PartialSchedules = PartialSchedules.unite(umap2: DomSched.as_union_map()); |
| 726 | } |
| 727 | |
| 728 | return remainingDepsFromPartialSchedule(PartSched: PartialSchedules, Deps); |
| 729 | } |
| 730 | |
| 731 | /// Determine whether the outermost loop of to bands can be fused while |
| 732 | /// respecting validity dependencies. |
| 733 | static bool canFuseOutermost(const isl::schedule_node_band &LHS, |
| 734 | const isl::schedule_node_band &RHS, |
| 735 | const isl::union_map &Deps) { |
| 736 | // { LHSDomain[] -> Scatter[] } |
| 737 | isl::union_map LHSPartSched = |
| 738 | LHS.get_partial_schedule().get_at(pos: 0).as_union_map(); |
| 739 | |
| 740 | // { Domain[] -> Scatter[] } |
| 741 | isl::union_map RHSPartSched = |
| 742 | RHS.get_partial_schedule().get_at(pos: 0).as_union_map(); |
| 743 | |
| 744 | // Dependencies that are already resolved because LHS executes before RHS, but |
| 745 | // will not be anymore after fusion. { DefDomain[] -> UseDomain[] } |
| 746 | isl::union_map OrderedBySequence = |
| 747 | Deps.intersect_domain(uset: LHSPartSched.domain()) |
| 748 | .intersect_range(uset: RHSPartSched.domain()); |
| 749 | |
| 750 | isl::space ParamSpace = OrderedBySequence.get_space().params(); |
| 751 | isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(dim: 1); |
| 752 | |
| 753 | // { Scatter[] -> Scatter[] } |
| 754 | isl::map After = isl::map::lex_gt(set_space: NewScatterSpace); |
| 755 | |
| 756 | // After fusion, instances with smaller (or equal, which means they will be |
| 757 | // executed in the same iteration, but the LHS instance is still sequenced |
| 758 | // before RHS) scatter value will still be executed before. This are the |
| 759 | // orderings where this is not necessarily the case. |
| 760 | // { LHSDomain[] -> RHSDomain[] } |
| 761 | isl::union_map MightBeAfterDoms = After.apply_domain(umap2: LHSPartSched.reverse()) |
| 762 | .apply_range(umap2: RHSPartSched.reverse()); |
| 763 | |
| 764 | // Dependencies that are not resolved by the new execution order. |
| 765 | isl::union_map WithBefore = OrderedBySequence.intersect(umap2: MightBeAfterDoms); |
| 766 | |
| 767 | return WithBefore.is_empty(); |
| 768 | } |
| 769 | |
| 770 | /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies. |
| 771 | static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS, |
| 772 | isl::schedule_node_band RHS, |
| 773 | const isl::union_map &Deps) { |
| 774 | if (!canFuseOutermost(LHS, RHS, Deps)) |
| 775 | return {}; |
| 776 | |
| 777 | POLLY_DEBUG({ |
| 778 | dbgs() << "Found loops for greedy fusion:\n" ; |
| 779 | dumpIslObj(LHS, dbgs()); |
| 780 | dbgs() << "and\n" ; |
| 781 | dumpIslObj(RHS, dbgs()); |
| 782 | dbgs() << "\n" ; |
| 783 | }); |
| 784 | |
| 785 | // The partial schedule of the bands outermost loop that we need to combine |
| 786 | // for the fusion. |
| 787 | isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(pos: 0); |
| 788 | isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(pos: 0); |
| 789 | |
| 790 | // Isolate band bodies as roots of their own schedule trees. |
| 791 | IdentityRewriter Rewriter; |
| 792 | isl::schedule LHSBody = Rewriter.visit(Node: LHS.first_child()); |
| 793 | isl::schedule RHSBody = Rewriter.visit(Node: RHS.first_child()); |
| 794 | |
| 795 | // Reconstruct the non-outermost (not going to be fused) loops from both |
| 796 | // bands. |
| 797 | // TODO: Maybe it is possibly to transfer the 'permutability' property from |
| 798 | // LHS+RHS. At minimum we need merge multiple band members at once, otherwise |
| 799 | // permutability has no meaning. |
| 800 | isl::schedule LHSNewBody = |
| 801 | rebuildBand(OldBand: LHS, Body: LHSBody, IncludeCb: [](int i) { return i > 0; }); |
| 802 | isl::schedule RHSNewBody = |
| 803 | rebuildBand(OldBand: RHS, Body: RHSBody, IncludeCb: [](int i) { return i > 0; }); |
| 804 | |
| 805 | // The loop body of the fused loop. |
| 806 | isl::schedule NewCommonBody = LHSNewBody.sequence(schedule2: RHSNewBody); |
| 807 | |
| 808 | // Combine the partial schedules of both loops to a new one. Instances with |
| 809 | // the same scatter value are put together. |
| 810 | isl::union_map NewCommonPartialSched = |
| 811 | LHSPartOuterSched.as_union_map().unite(umap2: RHSPartOuterSched.as_union_map()); |
| 812 | isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule( |
| 813 | partial: NewCommonPartialSched.as_multi_union_pw_aff()); |
| 814 | |
| 815 | return NewCommonSchedule; |
| 816 | } |
| 817 | |
| 818 | static isl::schedule tryGreedyFuse(isl::schedule_node LHS, |
| 819 | isl::schedule_node RHS, |
| 820 | const isl::union_map &Deps) { |
| 821 | // TODO: Non-bands could be interpreted as a band with just as single |
| 822 | // iteration. However, this is only useful if both ends of a fused loop were |
| 823 | // originally loops themselves. |
| 824 | if (!LHS.isa<isl::schedule_node_band>()) |
| 825 | return {}; |
| 826 | if (!RHS.isa<isl::schedule_node_band>()) |
| 827 | return {}; |
| 828 | |
| 829 | return tryGreedyFuse(LHS: LHS.as<isl::schedule_node_band>(), |
| 830 | RHS: RHS.as<isl::schedule_node_band>(), Deps); |
| 831 | } |
| 832 | |
| 833 | /// Fuse all fusable loop top-down in a schedule tree. |
| 834 | /// |
| 835 | /// The isl::union_map parameters is the set of validity dependencies that have |
| 836 | /// not been resolved/carried by a parent schedule node. |
| 837 | class GreedyFusionRewriter final |
| 838 | : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> { |
| 839 | private: |
| 840 | using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>; |
| 841 | BaseTy &getBase() { return *this; } |
| 842 | const BaseTy &getBase() const { return *this; } |
| 843 | |
| 844 | public: |
| 845 | /// Is set to true if anything has been fused. |
| 846 | bool AnyChange = false; |
| 847 | |
| 848 | isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) { |
| 849 | // { Domain[] -> Scatter[] } |
| 850 | isl::union_map PartSched = |
| 851 | isl::union_map::from(mupa: Band.get_partial_schedule()); |
| 852 | assert(getNumScatterDims(PartSched) == |
| 853 | unsignedFromIslSize(Band.n_member())); |
| 854 | isl::space ParamSpace = PartSched.get_space().params(); |
| 855 | |
| 856 | // { Scatter[] -> Domain[] } |
| 857 | isl::union_map PartSchedRev = PartSched.reverse(); |
| 858 | |
| 859 | // Possible within the same iteration. Dependencies with smaller scatter |
| 860 | // value are carried by this loop and therefore have been resolved by the |
| 861 | // in-order execution if the loop iteration. A dependency with small scatter |
| 862 | // value would be a dependency violation that we assume did not happen. { |
| 863 | // Domain[] -> Domain[] } |
| 864 | isl::union_map Unsequenced = PartSchedRev.apply_domain(umap2: PartSchedRev); |
| 865 | |
| 866 | // Actual dependencies within the same iteration. |
| 867 | // { DefDomain[] -> UseDomain[] } |
| 868 | isl::union_map RemDeps = Deps.intersect(umap2: Unsequenced); |
| 869 | |
| 870 | return getBase().visitBand(Band, args: RemDeps); |
| 871 | } |
| 872 | |
| 873 | isl::schedule visitSequence(isl::schedule_node_sequence Sequence, |
| 874 | isl::union_map Deps) { |
| 875 | int NumChildren = isl_schedule_node_n_children(node: Sequence.get()); |
| 876 | |
| 877 | // List of fusion candidates. The first element is the fusion candidate, the |
| 878 | // second is candidate's ancestor that is the sequence's direct child. It is |
| 879 | // preferable to use the direct child if not if its non-direct children is |
| 880 | // fused to preserve its structure such as mark nodes. |
| 881 | SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands; |
| 882 | for (auto i : seq<int>(Begin: 0, End: NumChildren)) { |
| 883 | isl::schedule_node Child = Sequence.child(pos: i); |
| 884 | collectPotentiallyFusableBands(Node: Child, ScheduleBands&: Bands, DirectChild: Child); |
| 885 | } |
| 886 | |
| 887 | // Direct children that had at least one of its descendants fused. |
| 888 | SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren; |
| 889 | |
| 890 | // Fuse neighboring bands until reaching the end of candidates. |
| 891 | int i = 0; |
| 892 | while (i + 1 < (int)Bands.size()) { |
| 893 | isl::schedule Fused = |
| 894 | tryGreedyFuse(LHS: Bands[i].first, RHS: Bands[i + 1].first, Deps); |
| 895 | if (Fused.is_null()) { |
| 896 | // Cannot merge this node with the next; look at next pair. |
| 897 | i += 1; |
| 898 | continue; |
| 899 | } |
| 900 | |
| 901 | // Mark the direct children as (partially) fused. |
| 902 | if (!Bands[i].second.is_null()) |
| 903 | ChangedDirectChildren.insert(V: Bands[i].second.get()); |
| 904 | if (!Bands[i + 1].second.is_null()) |
| 905 | ChangedDirectChildren.insert(V: Bands[i + 1].second.get()); |
| 906 | |
| 907 | // Collapse the neigbros to a single new candidate that could be fused |
| 908 | // with the next candidate. |
| 909 | Bands[i] = {Fused.get_root(), {}}; |
| 910 | Bands.erase(CI: Bands.begin() + i + 1); |
| 911 | |
| 912 | AnyChange = true; |
| 913 | } |
| 914 | |
| 915 | // By construction equal if done with collectPotentiallyFusableBands's |
| 916 | // output. |
| 917 | SmallVector<isl::union_set> SubDomains; |
| 918 | SubDomains.reserve(N: NumChildren); |
| 919 | for (int i = 0; i < NumChildren; i += 1) |
| 920 | SubDomains.push_back(Elt: Sequence.child(pos: i).domain()); |
| 921 | auto SubRemainingDeps = remainigDepsFromSequence(Domains: SubDomains, Deps); |
| 922 | |
| 923 | // We may iterate over direct children multiple times, be sure to add each |
| 924 | // at most once. |
| 925 | SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded; |
| 926 | |
| 927 | isl::schedule Result; |
| 928 | for (auto &P : Bands) { |
| 929 | isl::schedule_node MaybeFused = P.first; |
| 930 | isl::schedule_node DirectChild = P.second; |
| 931 | |
| 932 | // If not modified, use the direct child. |
| 933 | if (!DirectChild.is_null() && |
| 934 | !ChangedDirectChildren.count(V: DirectChild.get())) { |
| 935 | if (AlreadyAdded.count(V: DirectChild.get())) |
| 936 | continue; |
| 937 | AlreadyAdded.insert(V: DirectChild.get()); |
| 938 | MaybeFused = DirectChild; |
| 939 | } else { |
| 940 | assert(AnyChange && |
| 941 | "Need changed flag for be consistent with actual change" ); |
| 942 | } |
| 943 | |
| 944 | // Top-down recursion: If the outermost loop has been fused, their nested |
| 945 | // bands might be fusable now as well. |
| 946 | isl::schedule InnerFused = visit(Node: MaybeFused, args: SubRemainingDeps); |
| 947 | |
| 948 | // Reconstruct the sequence, with some of the children fused. |
| 949 | if (Result.is_null()) |
| 950 | Result = InnerFused; |
| 951 | else |
| 952 | Result = Result.sequence(schedule2: InnerFused); |
| 953 | } |
| 954 | |
| 955 | return Result; |
| 956 | } |
| 957 | }; |
| 958 | |
| 959 | } // namespace |
| 960 | |
| 961 | bool polly::isBandMark(const isl::schedule_node &Node) { |
| 962 | return isMark(Node) && |
| 963 | isLoopAttr(Id: Node.as<isl::schedule_node_mark>().get_id()); |
| 964 | } |
| 965 | |
| 966 | BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { |
| 967 | MarkOrBand = moveToBandMark(BandOrMark: MarkOrBand); |
| 968 | if (!isMark(Node: MarkOrBand)) |
| 969 | return nullptr; |
| 970 | |
| 971 | return getLoopAttr(Id: MarkOrBand.as<isl::schedule_node_mark>().get_id()); |
| 972 | } |
| 973 | |
| 974 | isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { |
| 975 | // If there is no extension node in the first place, return the original |
| 976 | // schedule tree. |
| 977 | if (!containsExtensionNode(Schedule: Sched)) |
| 978 | return Sched; |
| 979 | |
| 980 | // Build options can anchor schedule nodes, such that the schedule tree cannot |
| 981 | // be modified anymore. Therefore, apply build options after the tree has been |
| 982 | // created. |
| 983 | CollectASTBuildOptions Collector; |
| 984 | Collector.visit(Schedule: Sched); |
| 985 | |
| 986 | // Rewrite the schedule tree without extension nodes. |
| 987 | ExtensionNodeRewriter Rewriter; |
| 988 | isl::schedule NewSched = Rewriter.visitSchedule(Schedule: Sched); |
| 989 | |
| 990 | // Reapply the AST build options. The rewriter must not change the iteration |
| 991 | // order of bands. Any other node type is ignored. |
| 992 | ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); |
| 993 | NewSched = Applicator.visitSchedule(Schedule: NewSched); |
| 994 | |
| 995 | return NewSched; |
| 996 | } |
| 997 | |
| 998 | isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { |
| 999 | isl::ctx Ctx = BandToUnroll.ctx(); |
| 1000 | |
| 1001 | // Remove the loop's mark, the loop will disappear anyway. |
| 1002 | BandToUnroll = removeMark(MarkOrBand: BandToUnroll); |
| 1003 | assert(isBandWithSingleLoop(BandToUnroll)); |
| 1004 | |
| 1005 | isl::multi_union_pw_aff PartialSched = isl::manage( |
| 1006 | ptr: isl_schedule_node_band_get_partial_schedule(node: BandToUnroll.get())); |
| 1007 | assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u && |
| 1008 | "Can only unroll a single dimension" ); |
| 1009 | isl::union_pw_aff PartialSchedUAff = PartialSched.at(pos: 0); |
| 1010 | |
| 1011 | isl::union_set Domain = BandToUnroll.get_domain(); |
| 1012 | PartialSchedUAff = PartialSchedUAff.intersect_domain(uset: Domain); |
| 1013 | isl::union_map PartialSchedUMap = |
| 1014 | isl::union_map::from(upma: isl::union_pw_multi_aff(PartialSchedUAff)); |
| 1015 | |
| 1016 | // Enumerator only the scatter elements. |
| 1017 | isl::union_set ScatterList = PartialSchedUMap.range(); |
| 1018 | |
| 1019 | // Enumerate all loop iterations. |
| 1020 | // TODO: Diagnose if not enumerable or depends on a parameter. |
| 1021 | SmallVector<isl::point, 16> Elts; |
| 1022 | ScatterList.foreach_point(fn: [&Elts](isl::point P) -> isl::stat { |
| 1023 | Elts.push_back(Elt: P); |
| 1024 | return isl::stat::ok(); |
| 1025 | }); |
| 1026 | |
| 1027 | // Don't assume that foreach_point returns in execution order. |
| 1028 | llvm::sort(C&: Elts, Comp: [](isl::point P1, isl::point P2) -> bool { |
| 1029 | isl::val C1 = P1.get_coordinate_val(type: isl::dim::set, pos: 0); |
| 1030 | isl::val C2 = P2.get_coordinate_val(type: isl::dim::set, pos: 0); |
| 1031 | return C1.lt(v2: C2); |
| 1032 | }); |
| 1033 | |
| 1034 | // Convert the points to a sequence of filters. |
| 1035 | isl::union_set_list List = isl::union_set_list(Ctx, Elts.size()); |
| 1036 | for (isl::point P : Elts) { |
| 1037 | // Determine the domains that map this scatter element. |
| 1038 | isl::union_set DomainFilter = PartialSchedUMap.intersect_range(uset: P).domain(); |
| 1039 | |
| 1040 | List = List.add(el: DomainFilter); |
| 1041 | } |
| 1042 | |
| 1043 | // Replace original band with unrolled sequence. |
| 1044 | isl::schedule_node Body = |
| 1045 | isl::manage(ptr: isl_schedule_node_delete(node: BandToUnroll.release())); |
| 1046 | Body = Body.insert_sequence(filters: List); |
| 1047 | return Body.get_schedule(); |
| 1048 | } |
| 1049 | |
| 1050 | isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, |
| 1051 | int Factor) { |
| 1052 | assert(Factor > 0 && "Positive unroll factor required" ); |
| 1053 | isl::ctx Ctx = BandToUnroll.ctx(); |
| 1054 | |
| 1055 | // Remove the mark, save the attribute for later use. |
| 1056 | BandAttr *Attr; |
| 1057 | BandToUnroll = removeMark(MarkOrBand: BandToUnroll, Attr); |
| 1058 | assert(isBandWithSingleLoop(BandToUnroll)); |
| 1059 | |
| 1060 | isl::multi_union_pw_aff PartialSched = isl::manage( |
| 1061 | ptr: isl_schedule_node_band_get_partial_schedule(node: BandToUnroll.get())); |
| 1062 | |
| 1063 | // { Stmt[] -> [x] } |
| 1064 | isl::union_pw_aff PartialSchedUAff = PartialSched.at(pos: 0); |
| 1065 | |
| 1066 | // Here we assume the schedule stride is one and starts with 0, which is not |
| 1067 | // necessarily the case. |
| 1068 | isl::union_pw_aff StridedPartialSchedUAff = |
| 1069 | isl::union_pw_aff::empty(space: PartialSchedUAff.get_space()); |
| 1070 | isl::val ValFactor{Ctx, Factor}; |
| 1071 | PartialSchedUAff.foreach_pw_aff(fn: [&StridedPartialSchedUAff, |
| 1072 | &ValFactor](isl::pw_aff PwAff) -> isl::stat { |
| 1073 | isl::space Space = PwAff.get_space(); |
| 1074 | isl::set Universe = isl::set::universe(space: Space.domain()); |
| 1075 | isl::pw_aff AffFactor{Universe, ValFactor}; |
| 1076 | isl::pw_aff DivSchedAff = PwAff.div(pa2: AffFactor).floor().mul(pwaff2: AffFactor); |
| 1077 | StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(upa2: DivSchedAff); |
| 1078 | return isl::stat::ok(); |
| 1079 | }); |
| 1080 | |
| 1081 | isl::union_set_list List = isl::union_set_list(Ctx, Factor); |
| 1082 | for (auto i : seq<int>(Begin: 0, End: Factor)) { |
| 1083 | // { Stmt[] -> [x] } |
| 1084 | isl::union_map UMap = |
| 1085 | isl::union_map::from(upma: isl::union_pw_multi_aff(PartialSchedUAff)); |
| 1086 | |
| 1087 | // { [x] } |
| 1088 | isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, Offset: i); |
| 1089 | |
| 1090 | // { Stmt[] } |
| 1091 | isl::union_set UnrolledDomain = UMap.intersect_range(uset: Divisible).domain(); |
| 1092 | |
| 1093 | List = List.add(el: UnrolledDomain); |
| 1094 | } |
| 1095 | |
| 1096 | isl::schedule_node Body = |
| 1097 | isl::manage(ptr: isl_schedule_node_delete(node: BandToUnroll.copy())); |
| 1098 | Body = Body.insert_sequence(filters: List); |
| 1099 | isl::schedule_node NewLoop = |
| 1100 | Body.insert_partial_schedule(schedule: StridedPartialSchedUAff); |
| 1101 | |
| 1102 | MDNode *FollowupMD = nullptr; |
| 1103 | if (Attr && Attr->Metadata) |
| 1104 | FollowupMD = |
| 1105 | findOptionalNodeOperand(LoopMD: Attr->Metadata, Name: LLVMLoopUnrollFollowupUnrolled); |
| 1106 | |
| 1107 | isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupLoopMD: FollowupMD); |
| 1108 | if (!NewBandId.is_null()) |
| 1109 | NewLoop = insertMark(Band: NewLoop, Mark: NewBandId); |
| 1110 | |
| 1111 | return NewLoop.get_schedule(); |
| 1112 | } |
| 1113 | |
| 1114 | isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange, |
| 1115 | int VectorWidth) { |
| 1116 | unsigned Dims = unsignedFromIslSize(Size: ScheduleRange.tuple_dim()); |
| 1117 | assert(Dims >= 1); |
| 1118 | isl::set LoopPrefixes = |
| 1119 | ScheduleRange.drop_constraints_involving_dims(type: isl::dim::set, first: Dims - 1, n: 1); |
| 1120 | auto ExtentPrefixes = addExtentConstraints(Set: LoopPrefixes, VectorWidth); |
| 1121 | isl::set BadPrefixes = ExtentPrefixes.subtract(set2: ScheduleRange); |
| 1122 | BadPrefixes = BadPrefixes.project_out(type: isl::dim::set, first: Dims - 1, n: 1); |
| 1123 | LoopPrefixes = LoopPrefixes.project_out(type: isl::dim::set, first: Dims - 1, n: 1); |
| 1124 | return LoopPrefixes.subtract(set2: BadPrefixes); |
| 1125 | } |
| 1126 | |
| 1127 | isl::union_set polly::getIsolateOptions(isl::set IsolateDomain, |
| 1128 | unsigned OutDimsNum) { |
| 1129 | unsigned Dims = unsignedFromIslSize(Size: IsolateDomain.tuple_dim()); |
| 1130 | assert(OutDimsNum <= Dims && |
| 1131 | "The isl::set IsolateDomain is used to describe the range of schedule " |
| 1132 | "dimensions values, which should be isolated. Consequently, the " |
| 1133 | "number of its dimensions should be greater than or equal to the " |
| 1134 | "number of the schedule dimensions." ); |
| 1135 | isl::map IsolateRelation = isl::map::from_domain(set: IsolateDomain); |
| 1136 | IsolateRelation = IsolateRelation.move_dims(dst_type: isl::dim::out, dst_pos: 0, src_type: isl::dim::in, |
| 1137 | src_pos: Dims - OutDimsNum, n: OutDimsNum); |
| 1138 | isl::set IsolateOption = IsolateRelation.wrap(); |
| 1139 | isl::id Id = isl::id::alloc(ctx: IsolateOption.ctx(), name: "isolate" , user: nullptr); |
| 1140 | IsolateOption = IsolateOption.set_tuple_id(Id); |
| 1141 | return isl::union_set(IsolateOption); |
| 1142 | } |
| 1143 | |
| 1144 | isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) { |
| 1145 | isl::space Space(Ctx, 0, 1); |
| 1146 | auto DimOption = isl::set::universe(space: Space); |
| 1147 | auto Id = isl::id::alloc(ctx: Ctx, name: Option, user: nullptr); |
| 1148 | DimOption = DimOption.set_tuple_id(Id); |
| 1149 | return isl::union_set(DimOption); |
| 1150 | } |
| 1151 | |
| 1152 | isl::schedule_node polly::tileNode(isl::schedule_node Node, |
| 1153 | const char *Identifier, |
| 1154 | ArrayRef<int> TileSizes, |
| 1155 | int DefaultTileSize) { |
| 1156 | auto Space = isl::manage(ptr: isl_schedule_node_band_get_space(node: Node.get())); |
| 1157 | auto Dims = Space.dim(type: isl::dim::set); |
| 1158 | auto Sizes = isl::multi_val::zero(space: Space); |
| 1159 | std::string IdentifierString(Identifier); |
| 1160 | for (unsigned i : rangeIslSize(Begin: 0, End: Dims)) { |
| 1161 | unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize; |
| 1162 | Sizes = Sizes.set_val(pos: i, el: isl::val(Node.ctx(), tileSize)); |
| 1163 | } |
| 1164 | auto TileLoopMarkerStr = IdentifierString + " - Tiles" ; |
| 1165 | auto TileLoopMarker = isl::id::alloc(ctx: Node.ctx(), name: TileLoopMarkerStr, user: nullptr); |
| 1166 | Node = Node.insert_mark(mark: TileLoopMarker); |
| 1167 | Node = Node.child(pos: 0); |
| 1168 | Node = |
| 1169 | isl::manage(ptr: isl_schedule_node_band_tile(node: Node.release(), sizes: Sizes.release())); |
| 1170 | Node = Node.child(pos: 0); |
| 1171 | auto PointLoopMarkerStr = IdentifierString + " - Points" ; |
| 1172 | auto PointLoopMarker = |
| 1173 | isl::id::alloc(ctx: Node.ctx(), name: PointLoopMarkerStr, user: nullptr); |
| 1174 | Node = Node.insert_mark(mark: PointLoopMarker); |
| 1175 | return Node.child(pos: 0); |
| 1176 | } |
| 1177 | |
| 1178 | isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node, |
| 1179 | ArrayRef<int> TileSizes, |
| 1180 | int DefaultTileSize) { |
| 1181 | Node = tileNode(Node, Identifier: "Register tiling" , TileSizes, DefaultTileSize); |
| 1182 | auto Ctx = Node.ctx(); |
| 1183 | return Node.as<isl::schedule_node_band>().set_ast_build_options( |
| 1184 | isl::union_set(Ctx, "{unroll[x]}" )); |
| 1185 | } |
| 1186 | |
| 1187 | /// Find statements and sub-loops in (possibly nested) sequences. |
| 1188 | static void |
| 1189 | collectFissionableStmts(isl::schedule_node Node, |
| 1190 | SmallVectorImpl<isl::schedule_node> &ScheduleStmts) { |
| 1191 | if (isBand(Node) || isLeaf(Node)) { |
| 1192 | ScheduleStmts.push_back(Elt: Node); |
| 1193 | return; |
| 1194 | } |
| 1195 | |
| 1196 | if (Node.has_children()) { |
| 1197 | isl::schedule_node C = Node.first_child(); |
| 1198 | while (true) { |
| 1199 | collectFissionableStmts(Node: C, ScheduleStmts); |
| 1200 | if (!C.has_next_sibling()) |
| 1201 | break; |
| 1202 | C = C.next_sibling(); |
| 1203 | } |
| 1204 | } |
| 1205 | } |
| 1206 | |
| 1207 | isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) { |
| 1208 | isl::ctx Ctx = BandToFission.ctx(); |
| 1209 | BandToFission = removeMark(MarkOrBand: BandToFission); |
| 1210 | isl::schedule_node BandBody = BandToFission.child(pos: 0); |
| 1211 | |
| 1212 | SmallVector<isl::schedule_node> FissionableStmts; |
| 1213 | collectFissionableStmts(Node: BandBody, ScheduleStmts&: FissionableStmts); |
| 1214 | size_t N = FissionableStmts.size(); |
| 1215 | |
| 1216 | // Collect the domain for each of the statements that will get their own loop. |
| 1217 | isl::union_set_list DomList = isl::union_set_list(Ctx, N); |
| 1218 | for (size_t i = 0; i < N; ++i) { |
| 1219 | isl::schedule_node BodyPart = FissionableStmts[i]; |
| 1220 | DomList = DomList.add(el: BodyPart.get_domain()); |
| 1221 | } |
| 1222 | |
| 1223 | // Apply the fission by copying the entire loop, but inserting a filter for |
| 1224 | // the statement domains for each fissioned loop. |
| 1225 | isl::schedule_node Fissioned = BandToFission.insert_sequence(filters: DomList); |
| 1226 | |
| 1227 | return Fissioned.get_schedule(); |
| 1228 | } |
| 1229 | |
| 1230 | isl::schedule polly::applyGreedyFusion(isl::schedule Sched, |
| 1231 | const isl::union_map &Deps) { |
| 1232 | POLLY_DEBUG(dbgs() << "Greedy loop fusion\n" ); |
| 1233 | |
| 1234 | GreedyFusionRewriter Rewriter; |
| 1235 | isl::schedule Result = Rewriter.visit(Schedule: Sched, args: Deps); |
| 1236 | if (!Rewriter.AnyChange) { |
| 1237 | POLLY_DEBUG(dbgs() << "Found nothing to fuse\n" ); |
| 1238 | return Sched; |
| 1239 | } |
| 1240 | |
| 1241 | // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops |
| 1242 | // may have been split into multiple bands. |
| 1243 | return collapseBands(Sched: Result); |
| 1244 | } |
| 1245 | |