1 | //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Make changes to isl's schedule tree data structure. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "polly/ScheduleTreeTransform.h" |
14 | #include "polly/Support/GICHelper.h" |
15 | #include "polly/Support/ISLTools.h" |
16 | #include "polly/Support/ScopHelper.h" |
17 | #include "llvm/ADT/ArrayRef.h" |
18 | #include "llvm/ADT/Sequence.h" |
19 | #include "llvm/ADT/SmallVector.h" |
20 | #include "llvm/IR/Constants.h" |
21 | #include "llvm/IR/Metadata.h" |
22 | #include "llvm/Transforms/Utils/UnrollLoop.h" |
23 | |
24 | #include "polly/Support/PollyDebug.h" |
25 | #define DEBUG_TYPE "polly-opt-isl" |
26 | |
27 | using namespace polly; |
28 | using namespace llvm; |
29 | |
30 | namespace { |
31 | |
32 | /// Copy the band member attributes (coincidence, loop type, isolate ast loop |
33 | /// type) from one band to another. |
34 | static isl::schedule_node_band |
35 | applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx, |
36 | const isl::schedule_node_band &Source, |
37 | int SourceIdx) { |
38 | bool Coincident = Source.member_get_coincident(pos: SourceIdx).release(); |
39 | Target = Target.member_set_coincident(pos: TargetIdx, coincident: Coincident); |
40 | |
41 | isl_ast_loop_type LoopType = |
42 | isl_schedule_node_band_member_get_ast_loop_type(node: Source.get(), pos: SourceIdx); |
43 | Target = isl::manage(ptr: isl_schedule_node_band_member_set_ast_loop_type( |
44 | node: Target.release(), pos: TargetIdx, type: LoopType)) |
45 | .as<isl::schedule_node_band>(); |
46 | |
47 | isl_ast_loop_type IsolateType = |
48 | isl_schedule_node_band_member_get_isolate_ast_loop_type(node: Source.get(), |
49 | pos: SourceIdx); |
50 | Target = isl::manage(ptr: isl_schedule_node_band_member_set_isolate_ast_loop_type( |
51 | node: Target.release(), pos: TargetIdx, type: IsolateType)) |
52 | .as<isl::schedule_node_band>(); |
53 | |
54 | return Target; |
55 | } |
56 | |
57 | /// Create a new band by copying members from another @p Band. @p IncludeCb |
58 | /// decides which band indices are copied to the result. |
59 | template <typename CbTy> |
60 | static isl::schedule rebuildBand(isl::schedule_node_band OldBand, |
61 | isl::schedule Body, CbTy IncludeCb) { |
62 | int NumBandDims = unsignedFromIslSize(Size: OldBand.n_member()); |
63 | |
64 | bool ExcludeAny = false; |
65 | bool IncludeAny = false; |
66 | for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) { |
67 | if (IncludeCb(OldIdx)) |
68 | IncludeAny = true; |
69 | else |
70 | ExcludeAny = true; |
71 | } |
72 | |
73 | // Instead of creating a zero-member band, don't create a band at all. |
74 | if (!IncludeAny) |
75 | return Body; |
76 | |
77 | isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule(); |
78 | isl::multi_union_pw_aff NewPartialSched; |
79 | if (ExcludeAny) { |
80 | // Select the included partial scatter functions. |
81 | isl::union_pw_aff_list List = PartialSched.list(); |
82 | int NewIdx = 0; |
83 | for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) { |
84 | if (IncludeCb(OldIdx)) |
85 | NewIdx += 1; |
86 | else |
87 | List = List.drop(first: NewIdx, n: 1); |
88 | } |
89 | isl::space ParamSpace = PartialSched.get_space().params(); |
90 | isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(dim: NewIdx); |
91 | NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List); |
92 | } else { |
93 | // Just reuse original scatter function of copying all of them. |
94 | NewPartialSched = PartialSched; |
95 | } |
96 | |
97 | // Create the new band node. |
98 | isl::schedule_node_band NewBand = |
99 | Body.insert_partial_schedule(partial: NewPartialSched) |
100 | .get_root() |
101 | .child(pos: 0) |
102 | .as<isl::schedule_node_band>(); |
103 | |
104 | // If OldBand was permutable, so is the new one, even if some dimensions are |
105 | // missing. |
106 | bool IsPermutable = OldBand.permutable().release(); |
107 | NewBand = NewBand.set_permutable(IsPermutable); |
108 | |
109 | // Reapply member attributes. |
110 | int NewIdx = 0; |
111 | for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) { |
112 | if (!IncludeCb(OldIdx)) |
113 | continue; |
114 | NewBand = |
115 | applyBandMemberAttributes(Target: std::move(NewBand), TargetIdx: NewIdx, Source: OldBand, SourceIdx: OldIdx); |
116 | NewIdx += 1; |
117 | } |
118 | |
119 | return NewBand.get_schedule(); |
120 | } |
121 | |
122 | /// Rewrite a schedule tree by reconstructing it bottom-up. |
123 | /// |
124 | /// By default, the original schedule tree is reconstructed. To build a |
125 | /// different tree, redefine visitor methods in a derived class (CRTP). |
126 | /// |
127 | /// Note that AST build options are not applied; Setting the isolate[] option |
128 | /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence, |
129 | /// AST build options must be set after the tree has been constructed. |
130 | template <typename Derived, typename... Args> |
131 | struct ScheduleTreeRewriter |
132 | : RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> { |
133 | Derived &getDerived() { return *static_cast<Derived *>(this); } |
134 | const Derived &getDerived() const { |
135 | return *static_cast<const Derived *>(this); |
136 | } |
137 | |
138 | isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) { |
139 | // Every schedule_tree already has a domain node, no need to add one. |
140 | return getDerived().visit(Node.first_child(), std::forward<Args>(args)...); |
141 | } |
142 | |
143 | isl::schedule visitBand(isl::schedule_node_band Band, Args... args) { |
144 | isl::schedule NewChild = |
145 | getDerived().visit(Band.child(pos: 0), std::forward<Args>(args)...); |
146 | return rebuildBand(Band, NewChild, [](int) { return true; }); |
147 | } |
148 | |
149 | isl::schedule visitSequence(isl::schedule_node_sequence Sequence, |
150 | Args... args) { |
151 | int NumChildren = isl_schedule_node_n_children(node: Sequence.get()); |
152 | isl::schedule Result = |
153 | getDerived().visit(Sequence.child(pos: 0), std::forward<Args>(args)...); |
154 | for (int i = 1; i < NumChildren; i += 1) |
155 | Result = Result.sequence( |
156 | schedule2: getDerived().visit(Sequence.child(pos: i), std::forward<Args>(args)...)); |
157 | return Result; |
158 | } |
159 | |
160 | isl::schedule visitSet(isl::schedule_node_set Set, Args... args) { |
161 | int NumChildren = isl_schedule_node_n_children(node: Set.get()); |
162 | isl::schedule Result = |
163 | getDerived().visit(Set.child(pos: 0), std::forward<Args>(args)...); |
164 | for (int i = 1; i < NumChildren; i += 1) |
165 | Result = isl::manage( |
166 | isl_schedule_set(Result.release(), |
167 | getDerived() |
168 | .visit(Set.child(pos: i), std::forward<Args>(args)...) |
169 | .release())); |
170 | return Result; |
171 | } |
172 | |
173 | isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { |
174 | return isl::schedule::from_domain(domain: Leaf.get_domain()); |
175 | } |
176 | |
177 | isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) { |
178 | |
179 | isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id(); |
180 | isl::schedule_node NewChild = |
181 | getDerived() |
182 | .visit(Mark.first_child(), std::forward<Args>(args)...) |
183 | .get_root() |
184 | .first_child(); |
185 | return NewChild.insert_mark(mark: TheMark).get_schedule(); |
186 | } |
187 | |
188 | isl::schedule visitExtension(isl::schedule_node_extension Extension, |
189 | Args... args) { |
190 | isl::union_map TheExtension = |
191 | Extension.as<isl::schedule_node_extension>().get_extension(); |
192 | isl::schedule_node NewChild = getDerived() |
193 | .visit(Extension.child(pos: 0), args...) |
194 | .get_root() |
195 | .first_child(); |
196 | isl::schedule_node NewExtension = |
197 | isl::schedule_node::from_extension(extension: TheExtension); |
198 | return NewChild.graft_before(graft: NewExtension).get_schedule(); |
199 | } |
200 | |
201 | isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) { |
202 | isl::union_set FilterDomain = |
203 | Filter.as<isl::schedule_node_filter>().get_filter(); |
204 | isl::schedule NewSchedule = |
205 | getDerived().visit(Filter.child(pos: 0), std::forward<Args>(args)...); |
206 | return NewSchedule.intersect_domain(domain: FilterDomain); |
207 | } |
208 | |
209 | isl::schedule visitNode(isl::schedule_node Node, Args... args) { |
210 | llvm_unreachable("Not implemented" ); |
211 | } |
212 | }; |
213 | |
214 | /// Rewrite the schedule tree without any changes. Useful to copy a subtree into |
215 | /// a new schedule, discarding everything but. |
216 | struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {}; |
217 | |
218 | /// Rewrite a schedule tree to an equivalent one without extension nodes. |
219 | /// |
220 | /// Each visit method takes two additional arguments: |
221 | /// |
222 | /// * The new domain the node, which is the inherited domain plus any domains |
223 | /// added by extension nodes. |
224 | /// |
225 | /// * A map of extension domains of all children is returned; it is required by |
226 | /// band nodes to schedule the additional domains at the same position as the |
227 | /// extension node would. |
228 | /// |
229 | struct ExtensionNodeRewriter final |
230 | : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &, |
231 | isl::union_map &> { |
232 | using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter, |
233 | const isl::union_set &, isl::union_map &>; |
234 | BaseTy &getBase() { return *this; } |
235 | const BaseTy &getBase() const { return *this; } |
236 | |
237 | isl::schedule visitSchedule(isl::schedule Schedule) { |
238 | isl::union_map Extensions; |
239 | isl::schedule Result = |
240 | visit(Node: Schedule.get_root(), args: Schedule.get_domain(), args&: Extensions); |
241 | assert(!Extensions.is_null() && Extensions.is_empty()); |
242 | return Result; |
243 | } |
244 | |
245 | isl::schedule visitSequence(isl::schedule_node_sequence Sequence, |
246 | const isl::union_set &Domain, |
247 | isl::union_map &Extensions) { |
248 | int NumChildren = isl_schedule_node_n_children(node: Sequence.get()); |
249 | isl::schedule NewNode = visit(Node: Sequence.first_child(), args: Domain, args&: Extensions); |
250 | for (int i = 1; i < NumChildren; i += 1) { |
251 | isl::schedule_node OldChild = Sequence.child(pos: i); |
252 | isl::union_map NewChildExtensions; |
253 | isl::schedule NewChildNode = visit(Node: OldChild, args: Domain, args&: NewChildExtensions); |
254 | NewNode = NewNode.sequence(schedule2: NewChildNode); |
255 | Extensions = Extensions.unite(umap2: NewChildExtensions); |
256 | } |
257 | return NewNode; |
258 | } |
259 | |
260 | isl::schedule visitSet(isl::schedule_node_set Set, |
261 | const isl::union_set &Domain, |
262 | isl::union_map &Extensions) { |
263 | int NumChildren = isl_schedule_node_n_children(node: Set.get()); |
264 | isl::schedule NewNode = visit(Node: Set.first_child(), args: Domain, args&: Extensions); |
265 | for (int i = 1; i < NumChildren; i += 1) { |
266 | isl::schedule_node OldChild = Set.child(pos: i); |
267 | isl::union_map NewChildExtensions; |
268 | isl::schedule NewChildNode = visit(Node: OldChild, args: Domain, args&: NewChildExtensions); |
269 | NewNode = isl::manage( |
270 | ptr: isl_schedule_set(schedule1: NewNode.release(), schedule2: NewChildNode.release())); |
271 | Extensions = Extensions.unite(umap2: NewChildExtensions); |
272 | } |
273 | return NewNode; |
274 | } |
275 | |
276 | isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, |
277 | const isl::union_set &Domain, |
278 | isl::union_map &Extensions) { |
279 | Extensions = isl::union_map::empty(ctx: Leaf.ctx()); |
280 | return isl::schedule::from_domain(domain: Domain); |
281 | } |
282 | |
283 | isl::schedule visitBand(isl::schedule_node_band OldNode, |
284 | const isl::union_set &Domain, |
285 | isl::union_map &OuterExtensions) { |
286 | isl::schedule_node OldChild = OldNode.first_child(); |
287 | isl::multi_union_pw_aff PartialSched = |
288 | isl::manage(ptr: isl_schedule_node_band_get_partial_schedule(node: OldNode.get())); |
289 | |
290 | isl::union_map NewChildExtensions; |
291 | isl::schedule NewChild = visit(Node: OldChild, args: Domain, args&: NewChildExtensions); |
292 | |
293 | // Add the extensions to the partial schedule. |
294 | OuterExtensions = isl::union_map::empty(ctx: NewChildExtensions.ctx()); |
295 | isl::union_map NewPartialSchedMap = isl::union_map::from(mupa: PartialSched); |
296 | unsigned BandDims = isl_schedule_node_band_n_member(node: OldNode.get()); |
297 | for (isl::map Ext : NewChildExtensions.get_map_list()) { |
298 | unsigned ExtDims = unsignedFromIslSize(Size: Ext.domain_tuple_dim()); |
299 | assert(ExtDims >= BandDims); |
300 | unsigned OuterDims = ExtDims - BandDims; |
301 | |
302 | isl::map BandSched = |
303 | Ext.project_out(type: isl::dim::in, first: 0, n: OuterDims).reverse(); |
304 | NewPartialSchedMap = NewPartialSchedMap.unite(umap2: BandSched); |
305 | |
306 | // There might be more outer bands that have to schedule the extensions. |
307 | if (OuterDims > 0) { |
308 | isl::map OuterSched = |
309 | Ext.project_out(type: isl::dim::in, first: OuterDims, n: BandDims); |
310 | OuterExtensions = OuterExtensions.unite(umap2: OuterSched); |
311 | } |
312 | } |
313 | isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff = |
314 | isl::multi_union_pw_aff::from_union_map(umap: NewPartialSchedMap); |
315 | isl::schedule_node NewNode = |
316 | NewChild.insert_partial_schedule(partial: NewPartialSchedAsAsMultiUnionPwAff) |
317 | .get_root() |
318 | .child(pos: 0); |
319 | |
320 | // Reapply permutability and coincidence attributes. |
321 | NewNode = isl::manage(ptr: isl_schedule_node_band_set_permutable( |
322 | node: NewNode.release(), |
323 | permutable: isl_schedule_node_band_get_permutable(node: OldNode.get()))); |
324 | for (unsigned i = 0; i < BandDims; i += 1) |
325 | NewNode = applyBandMemberAttributes(Target: NewNode.as<isl::schedule_node_band>(), |
326 | TargetIdx: i, Source: OldNode, SourceIdx: i); |
327 | |
328 | return NewNode.get_schedule(); |
329 | } |
330 | |
331 | isl::schedule visitFilter(isl::schedule_node_filter Filter, |
332 | const isl::union_set &Domain, |
333 | isl::union_map &Extensions) { |
334 | isl::union_set FilterDomain = |
335 | Filter.as<isl::schedule_node_filter>().get_filter(); |
336 | isl::union_set NewDomain = Domain.intersect(uset2: FilterDomain); |
337 | |
338 | // A filter is added implicitly if necessary when joining schedule trees. |
339 | return visit(Node: Filter.first_child(), args: NewDomain, args&: Extensions); |
340 | } |
341 | |
342 | isl::schedule visitExtension(isl::schedule_node_extension Extension, |
343 | const isl::union_set &Domain, |
344 | isl::union_map &Extensions) { |
345 | isl::union_map ExtDomain = |
346 | Extension.as<isl::schedule_node_extension>().get_extension(); |
347 | isl::union_set NewDomain = Domain.unite(uset2: ExtDomain.range()); |
348 | isl::union_map ChildExtensions; |
349 | isl::schedule NewChild = |
350 | visit(Node: Extension.first_child(), args: NewDomain, args&: ChildExtensions); |
351 | Extensions = ChildExtensions.unite(umap2: ExtDomain); |
352 | return NewChild; |
353 | } |
354 | }; |
355 | |
356 | /// Collect all AST build options in any schedule tree band. |
357 | /// |
358 | /// ScheduleTreeRewriter cannot apply the schedule tree options. This class |
359 | /// collects these options to apply them later. |
360 | struct CollectASTBuildOptions final |
361 | : RecursiveScheduleTreeVisitor<CollectASTBuildOptions> { |
362 | using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>; |
363 | BaseTy &getBase() { return *this; } |
364 | const BaseTy &getBase() const { return *this; } |
365 | |
366 | llvm::SmallVector<isl::union_set, 8> ASTBuildOptions; |
367 | |
368 | void visitBand(isl::schedule_node_band Band) { |
369 | ASTBuildOptions.push_back( |
370 | Elt: isl::manage(ptr: isl_schedule_node_band_get_ast_build_options(node: Band.get()))); |
371 | return getBase().visitBand(Band); |
372 | } |
373 | }; |
374 | |
375 | /// Apply AST build options to the bands in a schedule tree. |
376 | /// |
377 | /// This rewrites a schedule tree with the AST build options applied. We assume |
378 | /// that the band nodes are visited in the same order as they were when the |
379 | /// build options were collected, typically by CollectASTBuildOptions. |
380 | struct ApplyASTBuildOptions final : ScheduleNodeRewriter<ApplyASTBuildOptions> { |
381 | using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>; |
382 | BaseTy &getBase() { return *this; } |
383 | const BaseTy &getBase() const { return *this; } |
384 | |
385 | size_t Pos; |
386 | llvm::ArrayRef<isl::union_set> ASTBuildOptions; |
387 | |
388 | ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions) |
389 | : ASTBuildOptions(ASTBuildOptions) {} |
390 | |
391 | isl::schedule visitSchedule(isl::schedule Schedule) { |
392 | Pos = 0; |
393 | isl::schedule Result = visit(Schedule).get_schedule(); |
394 | assert(Pos == ASTBuildOptions.size() && |
395 | "AST build options must match to band nodes" ); |
396 | return Result; |
397 | } |
398 | |
399 | isl::schedule_node visitBand(isl::schedule_node_band Band) { |
400 | isl::schedule_node_band Result = |
401 | Band.set_ast_build_options(ASTBuildOptions[Pos]); |
402 | Pos += 1; |
403 | return getBase().visitBand(Band: Result); |
404 | } |
405 | }; |
406 | |
407 | /// Return whether the schedule contains an extension node. |
408 | static bool containsExtensionNode(isl::schedule Schedule) { |
409 | assert(!Schedule.is_null()); |
410 | |
411 | auto Callback = [](__isl_keep isl_schedule_node *Node, |
412 | void *User) -> isl_bool { |
413 | if (isl_schedule_node_get_type(node: Node) == isl_schedule_node_extension) { |
414 | // Stop walking the schedule tree. |
415 | return isl_bool_error; |
416 | } |
417 | |
418 | // Continue searching the subtree. |
419 | return isl_bool_true; |
420 | }; |
421 | isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down( |
422 | sched: Schedule.get(), fn: Callback, user: nullptr); |
423 | |
424 | // We assume that the traversal itself does not fail, i.e. the only reason to |
425 | // return isl_stat_error is that an extension node was found. |
426 | return RetVal == isl_stat_error; |
427 | } |
428 | |
429 | /// Find a named MDNode property in a LoopID. |
430 | static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) { |
431 | return dyn_cast_or_null<MDNode>( |
432 | Val: findMetadataOperand(LoopMD, Name).value_or(u: nullptr)); |
433 | } |
434 | |
435 | /// Is this node of type mark? |
436 | static bool isMark(const isl::schedule_node &Node) { |
437 | return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_mark; |
438 | } |
439 | |
440 | /// Is this node of type band? |
441 | static bool isBand(const isl::schedule_node &Node) { |
442 | return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_band; |
443 | } |
444 | |
445 | #ifndef NDEBUG |
446 | /// Is this node a band of a single dimension (i.e. could represent a loop)? |
447 | static bool isBandWithSingleLoop(const isl::schedule_node &Node) { |
448 | return isBand(Node) && isl_schedule_node_band_n_member(node: Node.get()) == 1; |
449 | } |
450 | #endif |
451 | |
452 | static bool isLeaf(const isl::schedule_node &Node) { |
453 | return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_leaf; |
454 | } |
455 | |
456 | /// Create an isl::id representing the output loop after a transformation. |
457 | static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) { |
458 | // Don't need to id the followup. |
459 | // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by |
460 | // user followup-MD |
461 | if (!FollowupLoopMD) |
462 | return {}; |
463 | |
464 | BandAttr *Attr = new BandAttr(); |
465 | Attr->Metadata = FollowupLoopMD; |
466 | return getIslLoopAttr(Ctx, Attr); |
467 | } |
468 | |
469 | /// A loop consists of a band and an optional marker that wraps it. Return the |
470 | /// outermost of the two. |
471 | |
472 | /// That is, either the mark or, if there is not mark, the loop itself. Can |
473 | /// start with either the mark or the band. |
474 | static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) { |
475 | if (isBandMark(Node: BandOrMark)) { |
476 | assert(isBandWithSingleLoop(BandOrMark.child(0))); |
477 | return BandOrMark; |
478 | } |
479 | assert(isBandWithSingleLoop(BandOrMark)); |
480 | |
481 | isl::schedule_node Mark = BandOrMark.parent(); |
482 | if (isBandMark(Node: Mark)) |
483 | return Mark; |
484 | |
485 | // Band has no loop marker. |
486 | return BandOrMark; |
487 | } |
488 | |
489 | static isl::schedule_node removeMark(isl::schedule_node MarkOrBand, |
490 | BandAttr *&Attr) { |
491 | MarkOrBand = moveToBandMark(BandOrMark: MarkOrBand); |
492 | |
493 | isl::schedule_node Band; |
494 | if (isMark(Node: MarkOrBand)) { |
495 | Attr = getLoopAttr(Id: MarkOrBand.as<isl::schedule_node_mark>().get_id()); |
496 | Band = isl::manage(ptr: isl_schedule_node_delete(node: MarkOrBand.release())); |
497 | } else { |
498 | Attr = nullptr; |
499 | Band = MarkOrBand; |
500 | } |
501 | |
502 | assert(isBandWithSingleLoop(Band)); |
503 | return Band; |
504 | } |
505 | |
506 | /// Remove the mark that wraps a loop. Return the band representing the loop. |
507 | static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) { |
508 | BandAttr *Attr; |
509 | return removeMark(MarkOrBand, Attr); |
510 | } |
511 | |
512 | static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) { |
513 | assert(isBand(Band)); |
514 | assert(moveToBandMark(Band).is_equal(Band) && |
515 | "Don't add a two marks for a band" ); |
516 | |
517 | return Band.insert_mark(mark: Mark).child(pos: 0); |
518 | } |
519 | |
520 | /// Return the (one-dimensional) set of numbers that are divisible by @p Factor |
521 | /// with remainder @p Offset. |
522 | /// |
523 | /// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 } |
524 | /// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 } |
525 | /// |
526 | static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor, |
527 | long Offset) { |
528 | isl::val ValFactor{Ctx, Factor}; |
529 | isl::val ValOffset{Ctx, Offset}; |
530 | |
531 | isl::space Unispace{Ctx, 0, 1}; |
532 | isl::local_space LUnispace{Unispace}; |
533 | isl::aff AffFactor{LUnispace, ValFactor}; |
534 | isl::aff AffOffset{LUnispace, ValOffset}; |
535 | |
536 | isl::aff Id = isl::aff::var_on_domain(ls: LUnispace, type: isl::dim::out, pos: 0); |
537 | isl::aff DivMul = Id.mod(mod: ValFactor); |
538 | isl::basic_map Divisible = isl::basic_map::from_aff(aff: DivMul); |
539 | isl::basic_map Modulo = Divisible.fix_val(type: isl::dim::out, pos: 0, v: ValOffset); |
540 | return Modulo.domain(); |
541 | } |
542 | |
543 | /// Make the last dimension of Set to take values from 0 to VectorWidth - 1. |
544 | /// |
545 | /// @param Set A set, which should be modified. |
546 | /// @param VectorWidth A parameter, which determines the constraint. |
547 | static isl::set addExtentConstraints(isl::set Set, int VectorWidth) { |
548 | unsigned Dims = unsignedFromIslSize(Size: Set.tuple_dim()); |
549 | assert(Dims >= 1); |
550 | isl::space Space = Set.get_space(); |
551 | isl::local_space LocalSpace = isl::local_space(Space); |
552 | isl::constraint ExtConstr = isl::constraint::alloc_inequality(ls: LocalSpace); |
553 | ExtConstr = ExtConstr.set_constant_si(0); |
554 | ExtConstr = ExtConstr.set_coefficient_si(type: isl::dim::set, pos: Dims - 1, v: 1); |
555 | Set = Set.add_constraint(constraint: ExtConstr); |
556 | ExtConstr = isl::constraint::alloc_inequality(ls: LocalSpace); |
557 | ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1); |
558 | ExtConstr = ExtConstr.set_coefficient_si(type: isl::dim::set, pos: Dims - 1, v: -1); |
559 | return Set.add_constraint(constraint: ExtConstr); |
560 | } |
561 | |
562 | /// Collapse perfectly nested bands into a single band. |
563 | class BandCollapseRewriter final |
564 | : public ScheduleTreeRewriter<BandCollapseRewriter> { |
565 | private: |
566 | using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>; |
567 | BaseTy &getBase() { return *this; } |
568 | const BaseTy &getBase() const { return *this; } |
569 | |
570 | public: |
571 | isl::schedule visitBand(isl::schedule_node_band RootBand) { |
572 | isl::schedule_node_band Band = RootBand; |
573 | isl::ctx Ctx = Band.ctx(); |
574 | |
575 | // Do not merge permutable band to avoid loosing the permutability property. |
576 | // Cannot collapse even two permutable loops, they might be permutable |
577 | // individually, but not necassarily across. |
578 | if (unsignedFromIslSize(Size: Band.n_member()) > 1u && Band.permutable()) |
579 | return getBase().visitBand(Band); |
580 | |
581 | // Find collapsable bands. |
582 | SmallVector<isl::schedule_node_band> Nest; |
583 | int NumTotalLoops = 0; |
584 | isl::schedule_node Body; |
585 | while (true) { |
586 | Nest.push_back(Elt: Band); |
587 | NumTotalLoops += unsignedFromIslSize(Size: Band.n_member()); |
588 | Body = Band.first_child(); |
589 | if (!Body.isa<isl::schedule_node_band>()) |
590 | break; |
591 | Band = Body.as<isl::schedule_node_band>(); |
592 | |
593 | // Do not include next band if it is permutable to not lose its |
594 | // permutability property. |
595 | if (unsignedFromIslSize(Size: Band.n_member()) > 1u && Band.permutable()) |
596 | break; |
597 | } |
598 | |
599 | // Nothing to collapse, preserve permutability. |
600 | if (Nest.size() <= 1) |
601 | return getBase().visitBand(Band); |
602 | |
603 | POLLY_DEBUG({ |
604 | dbgs() << "Found loops to collapse between\n" ; |
605 | dumpIslObj(RootBand, dbgs()); |
606 | dbgs() << "and\n" ; |
607 | dumpIslObj(Body, dbgs()); |
608 | dbgs() << "\n" ; |
609 | }); |
610 | |
611 | isl::schedule NewBody = visit(Node: Body); |
612 | |
613 | // Collect partial schedules from all members. |
614 | isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops}; |
615 | for (isl::schedule_node_band Band : Nest) { |
616 | int NumLoops = unsignedFromIslSize(Size: Band.n_member()); |
617 | isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule(); |
618 | for (auto j : seq<int>(Begin: 0, End: NumLoops)) |
619 | PartScheds = PartScheds.add(el: BandScheds.at(pos: j)); |
620 | } |
621 | isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops); |
622 | isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds}; |
623 | |
624 | isl::schedule_node_band CollapsedBand = |
625 | NewBody.insert_partial_schedule(partial: PartSchedsMulti) |
626 | .get_root() |
627 | .first_child() |
628 | .as<isl::schedule_node_band>(); |
629 | |
630 | // Copy over loop attributes form original bands. |
631 | int LoopIdx = 0; |
632 | for (isl::schedule_node_band Band : Nest) { |
633 | int NumLoops = unsignedFromIslSize(Size: Band.n_member()); |
634 | for (int i : seq<int>(Begin: 0, End: NumLoops)) { |
635 | CollapsedBand = applyBandMemberAttributes(Target: std::move(CollapsedBand), |
636 | TargetIdx: LoopIdx, Source: Band, SourceIdx: i); |
637 | LoopIdx += 1; |
638 | } |
639 | } |
640 | assert(LoopIdx == NumTotalLoops && |
641 | "Expect the same number of loops to add up again" ); |
642 | |
643 | return CollapsedBand.get_schedule(); |
644 | } |
645 | }; |
646 | |
647 | static isl::schedule collapseBands(isl::schedule Sched) { |
648 | POLLY_DEBUG(dbgs() << "Collapse bands in schedule\n" ); |
649 | BandCollapseRewriter Rewriter; |
650 | return Rewriter.visit(Schedule: Sched); |
651 | } |
652 | |
653 | /// Collect sequentially executed bands (or anything else), even if nested in a |
654 | /// mark or other nodes whose child is executed just once. If we can |
655 | /// successfully fuse the bands, we allow them to be removed. |
656 | static void collectPotentiallyFusableBands( |
657 | isl::schedule_node Node, |
658 | SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>> |
659 | &ScheduleBands, |
660 | const isl::schedule_node &DirectChild) { |
661 | switch (isl_schedule_node_get_type(node: Node.get())) { |
662 | case isl_schedule_node_sequence: |
663 | case isl_schedule_node_set: |
664 | case isl_schedule_node_mark: |
665 | case isl_schedule_node_domain: |
666 | case isl_schedule_node_filter: |
667 | if (Node.has_children()) { |
668 | isl::schedule_node C = Node.first_child(); |
669 | while (true) { |
670 | collectPotentiallyFusableBands(Node: C, ScheduleBands, DirectChild); |
671 | if (!C.has_next_sibling()) |
672 | break; |
673 | C = C.next_sibling(); |
674 | } |
675 | } |
676 | break; |
677 | |
678 | default: |
679 | // Something that does not execute suquentially (e.g. a band) |
680 | ScheduleBands.push_back(Elt: {Node, DirectChild}); |
681 | break; |
682 | } |
683 | } |
684 | |
685 | /// Remove dependencies that are resolved by @p PartSched. That is, remove |
686 | /// everything that we already know is executed in-order. |
687 | static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched, |
688 | isl::union_map Deps) { |
689 | unsigned NumDims = getNumScatterDims(Schedule: PartSched); |
690 | auto ParamSpace = PartSched.get_space().params(); |
691 | |
692 | // { Scatter[] } |
693 | isl::space ScatterSpace = |
694 | ParamSpace.set_from_params().add_dims(type: isl::dim::set, n: NumDims); |
695 | |
696 | // { Scatter[] -> Domain[] } |
697 | isl::union_map PartSchedRev = PartSched.reverse(); |
698 | |
699 | // { Scatter[] -> Scatter[] } |
700 | isl::map MaybeBefore = isl::map::lex_le(set_space: ScatterSpace); |
701 | |
702 | // { Domain[] -> Domain[] } |
703 | isl::union_map DomMaybeBefore = |
704 | MaybeBefore.apply_domain(umap2: PartSchedRev).apply_range(umap2: PartSchedRev); |
705 | |
706 | // { Domain[] -> Domain[] } |
707 | isl::union_map ChildRemainingDeps = Deps.intersect(umap2: DomMaybeBefore); |
708 | |
709 | return ChildRemainingDeps; |
710 | } |
711 | |
712 | /// Remove dependencies that are resolved by executing them in the order |
713 | /// specified by @p Domains; |
714 | static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains, |
715 | isl::union_map Deps) { |
716 | isl::ctx Ctx = Deps.ctx(); |
717 | isl::space ParamSpace = Deps.get_space().params(); |
718 | |
719 | // Create a partial schedule mapping to constants that reflect the execution |
720 | // order. |
721 | isl::union_map PartialSchedules = isl::union_map::empty(ctx: Ctx); |
722 | for (auto P : enumerate(First&: Domains)) { |
723 | isl::val ExecTime = isl::val(Ctx, P.index()); |
724 | isl::union_pw_aff DomSched{P.value(), ExecTime}; |
725 | PartialSchedules = PartialSchedules.unite(umap2: DomSched.as_union_map()); |
726 | } |
727 | |
728 | return remainingDepsFromPartialSchedule(PartSched: PartialSchedules, Deps); |
729 | } |
730 | |
731 | /// Determine whether the outermost loop of to bands can be fused while |
732 | /// respecting validity dependencies. |
733 | static bool canFuseOutermost(const isl::schedule_node_band &LHS, |
734 | const isl::schedule_node_band &RHS, |
735 | const isl::union_map &Deps) { |
736 | // { LHSDomain[] -> Scatter[] } |
737 | isl::union_map LHSPartSched = |
738 | LHS.get_partial_schedule().get_at(pos: 0).as_union_map(); |
739 | |
740 | // { Domain[] -> Scatter[] } |
741 | isl::union_map RHSPartSched = |
742 | RHS.get_partial_schedule().get_at(pos: 0).as_union_map(); |
743 | |
744 | // Dependencies that are already resolved because LHS executes before RHS, but |
745 | // will not be anymore after fusion. { DefDomain[] -> UseDomain[] } |
746 | isl::union_map OrderedBySequence = |
747 | Deps.intersect_domain(uset: LHSPartSched.domain()) |
748 | .intersect_range(uset: RHSPartSched.domain()); |
749 | |
750 | isl::space ParamSpace = OrderedBySequence.get_space().params(); |
751 | isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(dim: 1); |
752 | |
753 | // { Scatter[] -> Scatter[] } |
754 | isl::map After = isl::map::lex_gt(set_space: NewScatterSpace); |
755 | |
756 | // After fusion, instances with smaller (or equal, which means they will be |
757 | // executed in the same iteration, but the LHS instance is still sequenced |
758 | // before RHS) scatter value will still be executed before. This are the |
759 | // orderings where this is not necessarily the case. |
760 | // { LHSDomain[] -> RHSDomain[] } |
761 | isl::union_map MightBeAfterDoms = After.apply_domain(umap2: LHSPartSched.reverse()) |
762 | .apply_range(umap2: RHSPartSched.reverse()); |
763 | |
764 | // Dependencies that are not resolved by the new execution order. |
765 | isl::union_map WithBefore = OrderedBySequence.intersect(umap2: MightBeAfterDoms); |
766 | |
767 | return WithBefore.is_empty(); |
768 | } |
769 | |
770 | /// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies. |
771 | static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS, |
772 | isl::schedule_node_band RHS, |
773 | const isl::union_map &Deps) { |
774 | if (!canFuseOutermost(LHS, RHS, Deps)) |
775 | return {}; |
776 | |
777 | POLLY_DEBUG({ |
778 | dbgs() << "Found loops for greedy fusion:\n" ; |
779 | dumpIslObj(LHS, dbgs()); |
780 | dbgs() << "and\n" ; |
781 | dumpIslObj(RHS, dbgs()); |
782 | dbgs() << "\n" ; |
783 | }); |
784 | |
785 | // The partial schedule of the bands outermost loop that we need to combine |
786 | // for the fusion. |
787 | isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(pos: 0); |
788 | isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(pos: 0); |
789 | |
790 | // Isolate band bodies as roots of their own schedule trees. |
791 | IdentityRewriter Rewriter; |
792 | isl::schedule LHSBody = Rewriter.visit(Node: LHS.first_child()); |
793 | isl::schedule RHSBody = Rewriter.visit(Node: RHS.first_child()); |
794 | |
795 | // Reconstruct the non-outermost (not going to be fused) loops from both |
796 | // bands. |
797 | // TODO: Maybe it is possibly to transfer the 'permutability' property from |
798 | // LHS+RHS. At minimum we need merge multiple band members at once, otherwise |
799 | // permutability has no meaning. |
800 | isl::schedule LHSNewBody = |
801 | rebuildBand(OldBand: LHS, Body: LHSBody, IncludeCb: [](int i) { return i > 0; }); |
802 | isl::schedule RHSNewBody = |
803 | rebuildBand(OldBand: RHS, Body: RHSBody, IncludeCb: [](int i) { return i > 0; }); |
804 | |
805 | // The loop body of the fused loop. |
806 | isl::schedule NewCommonBody = LHSNewBody.sequence(schedule2: RHSNewBody); |
807 | |
808 | // Combine the partial schedules of both loops to a new one. Instances with |
809 | // the same scatter value are put together. |
810 | isl::union_map NewCommonPartialSched = |
811 | LHSPartOuterSched.as_union_map().unite(umap2: RHSPartOuterSched.as_union_map()); |
812 | isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule( |
813 | partial: NewCommonPartialSched.as_multi_union_pw_aff()); |
814 | |
815 | return NewCommonSchedule; |
816 | } |
817 | |
818 | static isl::schedule tryGreedyFuse(isl::schedule_node LHS, |
819 | isl::schedule_node RHS, |
820 | const isl::union_map &Deps) { |
821 | // TODO: Non-bands could be interpreted as a band with just as single |
822 | // iteration. However, this is only useful if both ends of a fused loop were |
823 | // originally loops themselves. |
824 | if (!LHS.isa<isl::schedule_node_band>()) |
825 | return {}; |
826 | if (!RHS.isa<isl::schedule_node_band>()) |
827 | return {}; |
828 | |
829 | return tryGreedyFuse(LHS: LHS.as<isl::schedule_node_band>(), |
830 | RHS: RHS.as<isl::schedule_node_band>(), Deps); |
831 | } |
832 | |
833 | /// Fuse all fusable loop top-down in a schedule tree. |
834 | /// |
835 | /// The isl::union_map parameters is the set of validity dependencies that have |
836 | /// not been resolved/carried by a parent schedule node. |
837 | class GreedyFusionRewriter final |
838 | : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> { |
839 | private: |
840 | using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>; |
841 | BaseTy &getBase() { return *this; } |
842 | const BaseTy &getBase() const { return *this; } |
843 | |
844 | public: |
845 | /// Is set to true if anything has been fused. |
846 | bool AnyChange = false; |
847 | |
848 | isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) { |
849 | // { Domain[] -> Scatter[] } |
850 | isl::union_map PartSched = |
851 | isl::union_map::from(mupa: Band.get_partial_schedule()); |
852 | assert(getNumScatterDims(PartSched) == |
853 | unsignedFromIslSize(Band.n_member())); |
854 | isl::space ParamSpace = PartSched.get_space().params(); |
855 | |
856 | // { Scatter[] -> Domain[] } |
857 | isl::union_map PartSchedRev = PartSched.reverse(); |
858 | |
859 | // Possible within the same iteration. Dependencies with smaller scatter |
860 | // value are carried by this loop and therefore have been resolved by the |
861 | // in-order execution if the loop iteration. A dependency with small scatter |
862 | // value would be a dependency violation that we assume did not happen. { |
863 | // Domain[] -> Domain[] } |
864 | isl::union_map Unsequenced = PartSchedRev.apply_domain(umap2: PartSchedRev); |
865 | |
866 | // Actual dependencies within the same iteration. |
867 | // { DefDomain[] -> UseDomain[] } |
868 | isl::union_map RemDeps = Deps.intersect(umap2: Unsequenced); |
869 | |
870 | return getBase().visitBand(Band, args: RemDeps); |
871 | } |
872 | |
873 | isl::schedule visitSequence(isl::schedule_node_sequence Sequence, |
874 | isl::union_map Deps) { |
875 | int NumChildren = isl_schedule_node_n_children(node: Sequence.get()); |
876 | |
877 | // List of fusion candidates. The first element is the fusion candidate, the |
878 | // second is candidate's ancestor that is the sequence's direct child. It is |
879 | // preferable to use the direct child if not if its non-direct children is |
880 | // fused to preserve its structure such as mark nodes. |
881 | SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands; |
882 | for (auto i : seq<int>(Begin: 0, End: NumChildren)) { |
883 | isl::schedule_node Child = Sequence.child(pos: i); |
884 | collectPotentiallyFusableBands(Node: Child, ScheduleBands&: Bands, DirectChild: Child); |
885 | } |
886 | |
887 | // Direct children that had at least one of its decendants fused. |
888 | SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren; |
889 | |
890 | // Fuse neigboring bands until reaching the end of candidates. |
891 | int i = 0; |
892 | while (i + 1 < (int)Bands.size()) { |
893 | isl::schedule Fused = |
894 | tryGreedyFuse(LHS: Bands[i].first, RHS: Bands[i + 1].first, Deps); |
895 | if (Fused.is_null()) { |
896 | // Cannot merge this node with the next; look at next pair. |
897 | i += 1; |
898 | continue; |
899 | } |
900 | |
901 | // Mark the direct children as (partially) fused. |
902 | if (!Bands[i].second.is_null()) |
903 | ChangedDirectChildren.insert(V: Bands[i].second.get()); |
904 | if (!Bands[i + 1].second.is_null()) |
905 | ChangedDirectChildren.insert(V: Bands[i + 1].second.get()); |
906 | |
907 | // Collapse the neigbros to a single new candidate that could be fused |
908 | // with the next candidate. |
909 | Bands[i] = {Fused.get_root(), {}}; |
910 | Bands.erase(CI: Bands.begin() + i + 1); |
911 | |
912 | AnyChange = true; |
913 | } |
914 | |
915 | // By construction equal if done with collectPotentiallyFusableBands's |
916 | // output. |
917 | SmallVector<isl::union_set> SubDomains; |
918 | SubDomains.reserve(N: NumChildren); |
919 | for (int i = 0; i < NumChildren; i += 1) |
920 | SubDomains.push_back(Elt: Sequence.child(pos: i).domain()); |
921 | auto SubRemainingDeps = remainigDepsFromSequence(Domains: SubDomains, Deps); |
922 | |
923 | // We may iterate over direct children multiple times, be sure to add each |
924 | // at most once. |
925 | SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded; |
926 | |
927 | isl::schedule Result; |
928 | for (auto &P : Bands) { |
929 | isl::schedule_node MaybeFused = P.first; |
930 | isl::schedule_node DirectChild = P.second; |
931 | |
932 | // If not modified, use the direct child. |
933 | if (!DirectChild.is_null() && |
934 | !ChangedDirectChildren.count(V: DirectChild.get())) { |
935 | if (AlreadyAdded.count(V: DirectChild.get())) |
936 | continue; |
937 | AlreadyAdded.insert(V: DirectChild.get()); |
938 | MaybeFused = DirectChild; |
939 | } else { |
940 | assert(AnyChange && |
941 | "Need changed flag for be consistent with actual change" ); |
942 | } |
943 | |
944 | // Top-down recursion: If the outermost loop has been fused, their nested |
945 | // bands might be fusable now as well. |
946 | isl::schedule InnerFused = visit(Node: MaybeFused, args: SubRemainingDeps); |
947 | |
948 | // Reconstruct the sequence, with some of the children fused. |
949 | if (Result.is_null()) |
950 | Result = InnerFused; |
951 | else |
952 | Result = Result.sequence(schedule2: InnerFused); |
953 | } |
954 | |
955 | return Result; |
956 | } |
957 | }; |
958 | |
959 | } // namespace |
960 | |
961 | bool polly::isBandMark(const isl::schedule_node &Node) { |
962 | return isMark(Node) && |
963 | isLoopAttr(Id: Node.as<isl::schedule_node_mark>().get_id()); |
964 | } |
965 | |
966 | BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) { |
967 | MarkOrBand = moveToBandMark(BandOrMark: MarkOrBand); |
968 | if (!isMark(Node: MarkOrBand)) |
969 | return nullptr; |
970 | |
971 | return getLoopAttr(Id: MarkOrBand.as<isl::schedule_node_mark>().get_id()); |
972 | } |
973 | |
974 | isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) { |
975 | // If there is no extension node in the first place, return the original |
976 | // schedule tree. |
977 | if (!containsExtensionNode(Schedule: Sched)) |
978 | return Sched; |
979 | |
980 | // Build options can anchor schedule nodes, such that the schedule tree cannot |
981 | // be modified anymore. Therefore, apply build options after the tree has been |
982 | // created. |
983 | CollectASTBuildOptions Collector; |
984 | Collector.visit(Schedule: Sched); |
985 | |
986 | // Rewrite the schedule tree without extension nodes. |
987 | ExtensionNodeRewriter Rewriter; |
988 | isl::schedule NewSched = Rewriter.visitSchedule(Schedule: Sched); |
989 | |
990 | // Reapply the AST build options. The rewriter must not change the iteration |
991 | // order of bands. Any other node type is ignored. |
992 | ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions); |
993 | NewSched = Applicator.visitSchedule(Schedule: NewSched); |
994 | |
995 | return NewSched; |
996 | } |
997 | |
998 | isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) { |
999 | isl::ctx Ctx = BandToUnroll.ctx(); |
1000 | |
1001 | // Remove the loop's mark, the loop will disappear anyway. |
1002 | BandToUnroll = removeMark(MarkOrBand: BandToUnroll); |
1003 | assert(isBandWithSingleLoop(BandToUnroll)); |
1004 | |
1005 | isl::multi_union_pw_aff PartialSched = isl::manage( |
1006 | ptr: isl_schedule_node_band_get_partial_schedule(node: BandToUnroll.get())); |
1007 | assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u && |
1008 | "Can only unroll a single dimension" ); |
1009 | isl::union_pw_aff PartialSchedUAff = PartialSched.at(pos: 0); |
1010 | |
1011 | isl::union_set Domain = BandToUnroll.get_domain(); |
1012 | PartialSchedUAff = PartialSchedUAff.intersect_domain(uset: Domain); |
1013 | isl::union_map PartialSchedUMap = |
1014 | isl::union_map::from(upma: isl::union_pw_multi_aff(PartialSchedUAff)); |
1015 | |
1016 | // Enumerator only the scatter elements. |
1017 | isl::union_set ScatterList = PartialSchedUMap.range(); |
1018 | |
1019 | // Enumerate all loop iterations. |
1020 | // TODO: Diagnose if not enumerable or depends on a parameter. |
1021 | SmallVector<isl::point, 16> Elts; |
1022 | ScatterList.foreach_point(fn: [&Elts](isl::point P) -> isl::stat { |
1023 | Elts.push_back(Elt: P); |
1024 | return isl::stat::ok(); |
1025 | }); |
1026 | |
1027 | // Don't assume that foreach_point returns in execution order. |
1028 | llvm::sort(C&: Elts, Comp: [](isl::point P1, isl::point P2) -> bool { |
1029 | isl::val C1 = P1.get_coordinate_val(type: isl::dim::set, pos: 0); |
1030 | isl::val C2 = P2.get_coordinate_val(type: isl::dim::set, pos: 0); |
1031 | return C1.lt(v2: C2); |
1032 | }); |
1033 | |
1034 | // Convert the points to a sequence of filters. |
1035 | isl::union_set_list List = isl::union_set_list(Ctx, Elts.size()); |
1036 | for (isl::point P : Elts) { |
1037 | // Determine the domains that map this scatter element. |
1038 | isl::union_set DomainFilter = PartialSchedUMap.intersect_range(uset: P).domain(); |
1039 | |
1040 | List = List.add(el: DomainFilter); |
1041 | } |
1042 | |
1043 | // Replace original band with unrolled sequence. |
1044 | isl::schedule_node Body = |
1045 | isl::manage(ptr: isl_schedule_node_delete(node: BandToUnroll.release())); |
1046 | Body = Body.insert_sequence(filters: List); |
1047 | return Body.get_schedule(); |
1048 | } |
1049 | |
1050 | isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll, |
1051 | int Factor) { |
1052 | assert(Factor > 0 && "Positive unroll factor required" ); |
1053 | isl::ctx Ctx = BandToUnroll.ctx(); |
1054 | |
1055 | // Remove the mark, save the attribute for later use. |
1056 | BandAttr *Attr; |
1057 | BandToUnroll = removeMark(MarkOrBand: BandToUnroll, Attr); |
1058 | assert(isBandWithSingleLoop(BandToUnroll)); |
1059 | |
1060 | isl::multi_union_pw_aff PartialSched = isl::manage( |
1061 | ptr: isl_schedule_node_band_get_partial_schedule(node: BandToUnroll.get())); |
1062 | |
1063 | // { Stmt[] -> [x] } |
1064 | isl::union_pw_aff PartialSchedUAff = PartialSched.at(pos: 0); |
1065 | |
1066 | // Here we assume the schedule stride is one and starts with 0, which is not |
1067 | // necessarily the case. |
1068 | isl::union_pw_aff StridedPartialSchedUAff = |
1069 | isl::union_pw_aff::empty(space: PartialSchedUAff.get_space()); |
1070 | isl::val ValFactor{Ctx, Factor}; |
1071 | PartialSchedUAff.foreach_pw_aff(fn: [&StridedPartialSchedUAff, |
1072 | &ValFactor](isl::pw_aff PwAff) -> isl::stat { |
1073 | isl::space Space = PwAff.get_space(); |
1074 | isl::set Universe = isl::set::universe(space: Space.domain()); |
1075 | isl::pw_aff AffFactor{Universe, ValFactor}; |
1076 | isl::pw_aff DivSchedAff = PwAff.div(pa2: AffFactor).floor().mul(pwaff2: AffFactor); |
1077 | StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(upa2: DivSchedAff); |
1078 | return isl::stat::ok(); |
1079 | }); |
1080 | |
1081 | isl::union_set_list List = isl::union_set_list(Ctx, Factor); |
1082 | for (auto i : seq<int>(Begin: 0, End: Factor)) { |
1083 | // { Stmt[] -> [x] } |
1084 | isl::union_map UMap = |
1085 | isl::union_map::from(upma: isl::union_pw_multi_aff(PartialSchedUAff)); |
1086 | |
1087 | // { [x] } |
1088 | isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, Offset: i); |
1089 | |
1090 | // { Stmt[] } |
1091 | isl::union_set UnrolledDomain = UMap.intersect_range(uset: Divisible).domain(); |
1092 | |
1093 | List = List.add(el: UnrolledDomain); |
1094 | } |
1095 | |
1096 | isl::schedule_node Body = |
1097 | isl::manage(ptr: isl_schedule_node_delete(node: BandToUnroll.copy())); |
1098 | Body = Body.insert_sequence(filters: List); |
1099 | isl::schedule_node NewLoop = |
1100 | Body.insert_partial_schedule(schedule: StridedPartialSchedUAff); |
1101 | |
1102 | MDNode *FollowupMD = nullptr; |
1103 | if (Attr && Attr->Metadata) |
1104 | FollowupMD = |
1105 | findOptionalNodeOperand(LoopMD: Attr->Metadata, Name: LLVMLoopUnrollFollowupUnrolled); |
1106 | |
1107 | isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupLoopMD: FollowupMD); |
1108 | if (!NewBandId.is_null()) |
1109 | NewLoop = insertMark(Band: NewLoop, Mark: NewBandId); |
1110 | |
1111 | return NewLoop.get_schedule(); |
1112 | } |
1113 | |
1114 | isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange, |
1115 | int VectorWidth) { |
1116 | unsigned Dims = unsignedFromIslSize(Size: ScheduleRange.tuple_dim()); |
1117 | assert(Dims >= 1); |
1118 | isl::set LoopPrefixes = |
1119 | ScheduleRange.drop_constraints_involving_dims(type: isl::dim::set, first: Dims - 1, n: 1); |
1120 | auto ExtentPrefixes = addExtentConstraints(Set: LoopPrefixes, VectorWidth); |
1121 | isl::set BadPrefixes = ExtentPrefixes.subtract(set2: ScheduleRange); |
1122 | BadPrefixes = BadPrefixes.project_out(type: isl::dim::set, first: Dims - 1, n: 1); |
1123 | LoopPrefixes = LoopPrefixes.project_out(type: isl::dim::set, first: Dims - 1, n: 1); |
1124 | return LoopPrefixes.subtract(set2: BadPrefixes); |
1125 | } |
1126 | |
1127 | isl::union_set polly::getIsolateOptions(isl::set IsolateDomain, |
1128 | unsigned OutDimsNum) { |
1129 | unsigned Dims = unsignedFromIslSize(Size: IsolateDomain.tuple_dim()); |
1130 | assert(OutDimsNum <= Dims && |
1131 | "The isl::set IsolateDomain is used to describe the range of schedule " |
1132 | "dimensions values, which should be isolated. Consequently, the " |
1133 | "number of its dimensions should be greater than or equal to the " |
1134 | "number of the schedule dimensions." ); |
1135 | isl::map IsolateRelation = isl::map::from_domain(set: IsolateDomain); |
1136 | IsolateRelation = IsolateRelation.move_dims(dst_type: isl::dim::out, dst_pos: 0, src_type: isl::dim::in, |
1137 | src_pos: Dims - OutDimsNum, n: OutDimsNum); |
1138 | isl::set IsolateOption = IsolateRelation.wrap(); |
1139 | isl::id Id = isl::id::alloc(ctx: IsolateOption.ctx(), name: "isolate" , user: nullptr); |
1140 | IsolateOption = IsolateOption.set_tuple_id(Id); |
1141 | return isl::union_set(IsolateOption); |
1142 | } |
1143 | |
1144 | isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) { |
1145 | isl::space Space(Ctx, 0, 1); |
1146 | auto DimOption = isl::set::universe(space: Space); |
1147 | auto Id = isl::id::alloc(ctx: Ctx, name: Option, user: nullptr); |
1148 | DimOption = DimOption.set_tuple_id(Id); |
1149 | return isl::union_set(DimOption); |
1150 | } |
1151 | |
1152 | isl::schedule_node polly::tileNode(isl::schedule_node Node, |
1153 | const char *Identifier, |
1154 | ArrayRef<int> TileSizes, |
1155 | int DefaultTileSize) { |
1156 | auto Space = isl::manage(ptr: isl_schedule_node_band_get_space(node: Node.get())); |
1157 | auto Dims = Space.dim(type: isl::dim::set); |
1158 | auto Sizes = isl::multi_val::zero(space: Space); |
1159 | std::string IdentifierString(Identifier); |
1160 | for (unsigned i : rangeIslSize(Begin: 0, End: Dims)) { |
1161 | unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize; |
1162 | Sizes = Sizes.set_val(pos: i, el: isl::val(Node.ctx(), tileSize)); |
1163 | } |
1164 | auto TileLoopMarkerStr = IdentifierString + " - Tiles" ; |
1165 | auto TileLoopMarker = isl::id::alloc(ctx: Node.ctx(), name: TileLoopMarkerStr, user: nullptr); |
1166 | Node = Node.insert_mark(mark: TileLoopMarker); |
1167 | Node = Node.child(pos: 0); |
1168 | Node = |
1169 | isl::manage(ptr: isl_schedule_node_band_tile(node: Node.release(), sizes: Sizes.release())); |
1170 | Node = Node.child(pos: 0); |
1171 | auto PointLoopMarkerStr = IdentifierString + " - Points" ; |
1172 | auto PointLoopMarker = |
1173 | isl::id::alloc(ctx: Node.ctx(), name: PointLoopMarkerStr, user: nullptr); |
1174 | Node = Node.insert_mark(mark: PointLoopMarker); |
1175 | return Node.child(pos: 0); |
1176 | } |
1177 | |
1178 | isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node, |
1179 | ArrayRef<int> TileSizes, |
1180 | int DefaultTileSize) { |
1181 | Node = tileNode(Node, Identifier: "Register tiling" , TileSizes, DefaultTileSize); |
1182 | auto Ctx = Node.ctx(); |
1183 | return Node.as<isl::schedule_node_band>().set_ast_build_options( |
1184 | isl::union_set(Ctx, "{unroll[x]}" )); |
1185 | } |
1186 | |
1187 | /// Find statements and sub-loops in (possibly nested) sequences. |
1188 | static void |
1189 | collectFissionableStmts(isl::schedule_node Node, |
1190 | SmallVectorImpl<isl::schedule_node> &ScheduleStmts) { |
1191 | if (isBand(Node) || isLeaf(Node)) { |
1192 | ScheduleStmts.push_back(Elt: Node); |
1193 | return; |
1194 | } |
1195 | |
1196 | if (Node.has_children()) { |
1197 | isl::schedule_node C = Node.first_child(); |
1198 | while (true) { |
1199 | collectFissionableStmts(Node: C, ScheduleStmts); |
1200 | if (!C.has_next_sibling()) |
1201 | break; |
1202 | C = C.next_sibling(); |
1203 | } |
1204 | } |
1205 | } |
1206 | |
1207 | isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) { |
1208 | isl::ctx Ctx = BandToFission.ctx(); |
1209 | BandToFission = removeMark(MarkOrBand: BandToFission); |
1210 | isl::schedule_node BandBody = BandToFission.child(pos: 0); |
1211 | |
1212 | SmallVector<isl::schedule_node> FissionableStmts; |
1213 | collectFissionableStmts(Node: BandBody, ScheduleStmts&: FissionableStmts); |
1214 | size_t N = FissionableStmts.size(); |
1215 | |
1216 | // Collect the domain for each of the statements that will get their own loop. |
1217 | isl::union_set_list DomList = isl::union_set_list(Ctx, N); |
1218 | for (size_t i = 0; i < N; ++i) { |
1219 | isl::schedule_node BodyPart = FissionableStmts[i]; |
1220 | DomList = DomList.add(el: BodyPart.get_domain()); |
1221 | } |
1222 | |
1223 | // Apply the fission by copying the entire loop, but inserting a filter for |
1224 | // the statement domains for each fissioned loop. |
1225 | isl::schedule_node Fissioned = BandToFission.insert_sequence(filters: DomList); |
1226 | |
1227 | return Fissioned.get_schedule(); |
1228 | } |
1229 | |
1230 | isl::schedule polly::applyGreedyFusion(isl::schedule Sched, |
1231 | const isl::union_map &Deps) { |
1232 | POLLY_DEBUG(dbgs() << "Greedy loop fusion\n" ); |
1233 | |
1234 | GreedyFusionRewriter Rewriter; |
1235 | isl::schedule Result = Rewriter.visit(Schedule: Sched, args: Deps); |
1236 | if (!Rewriter.AnyChange) { |
1237 | POLLY_DEBUG(dbgs() << "Found nothing to fuse\n" ); |
1238 | return Sched; |
1239 | } |
1240 | |
1241 | // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops |
1242 | // may have been split into multiple bands. |
1243 | return collapseBands(Sched: Result); |
1244 | } |
1245 | |