1//===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Make changes to isl's schedule tree data structure.
10//
11//===----------------------------------------------------------------------===//
12
13#include "polly/ScheduleTreeTransform.h"
14#include "polly/Support/GICHelper.h"
15#include "polly/Support/ISLTools.h"
16#include "polly/Support/ScopHelper.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/Sequence.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/IR/Constants.h"
21#include "llvm/IR/Metadata.h"
22#include "llvm/Transforms/Utils/UnrollLoop.h"
23
24#include "polly/Support/PollyDebug.h"
25#define DEBUG_TYPE "polly-opt-isl"
26
27using namespace polly;
28using namespace llvm;
29
30namespace {
31
32/// Copy the band member attributes (coincidence, loop type, isolate ast loop
33/// type) from one band to another.
34static isl::schedule_node_band
35applyBandMemberAttributes(isl::schedule_node_band Target, int TargetIdx,
36 const isl::schedule_node_band &Source,
37 int SourceIdx) {
38 bool Coincident = Source.member_get_coincident(pos: SourceIdx).release();
39 Target = Target.member_set_coincident(pos: TargetIdx, coincident: Coincident);
40
41 isl_ast_loop_type LoopType =
42 isl_schedule_node_band_member_get_ast_loop_type(node: Source.get(), pos: SourceIdx);
43 Target = isl::manage(ptr: isl_schedule_node_band_member_set_ast_loop_type(
44 node: Target.release(), pos: TargetIdx, type: LoopType))
45 .as<isl::schedule_node_band>();
46
47 isl_ast_loop_type IsolateType =
48 isl_schedule_node_band_member_get_isolate_ast_loop_type(node: Source.get(),
49 pos: SourceIdx);
50 Target = isl::manage(ptr: isl_schedule_node_band_member_set_isolate_ast_loop_type(
51 node: Target.release(), pos: TargetIdx, type: IsolateType))
52 .as<isl::schedule_node_band>();
53
54 return Target;
55}
56
57/// Create a new band by copying members from another @p Band. @p IncludeCb
58/// decides which band indices are copied to the result.
59template <typename CbTy>
60static isl::schedule rebuildBand(isl::schedule_node_band OldBand,
61 isl::schedule Body, CbTy IncludeCb) {
62 int NumBandDims = unsignedFromIslSize(Size: OldBand.n_member());
63
64 bool ExcludeAny = false;
65 bool IncludeAny = false;
66 for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) {
67 if (IncludeCb(OldIdx))
68 IncludeAny = true;
69 else
70 ExcludeAny = true;
71 }
72
73 // Instead of creating a zero-member band, don't create a band at all.
74 if (!IncludeAny)
75 return Body;
76
77 isl::multi_union_pw_aff PartialSched = OldBand.get_partial_schedule();
78 isl::multi_union_pw_aff NewPartialSched;
79 if (ExcludeAny) {
80 // Select the included partial scatter functions.
81 isl::union_pw_aff_list List = PartialSched.list();
82 int NewIdx = 0;
83 for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) {
84 if (IncludeCb(OldIdx))
85 NewIdx += 1;
86 else
87 List = List.drop(first: NewIdx, n: 1);
88 }
89 isl::space ParamSpace = PartialSched.get_space().params();
90 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(dim: NewIdx);
91 NewPartialSched = isl::multi_union_pw_aff(NewScatterSpace, List);
92 } else {
93 // Just reuse original scatter function of copying all of them.
94 NewPartialSched = PartialSched;
95 }
96
97 // Create the new band node.
98 isl::schedule_node_band NewBand =
99 Body.insert_partial_schedule(partial: NewPartialSched)
100 .get_root()
101 .child(pos: 0)
102 .as<isl::schedule_node_band>();
103
104 // If OldBand was permutable, so is the new one, even if some dimensions are
105 // missing.
106 bool IsPermutable = OldBand.permutable().release();
107 NewBand = NewBand.set_permutable(IsPermutable);
108
109 // Reapply member attributes.
110 int NewIdx = 0;
111 for (auto OldIdx : seq<int>(Begin: 0, End: NumBandDims)) {
112 if (!IncludeCb(OldIdx))
113 continue;
114 NewBand =
115 applyBandMemberAttributes(Target: std::move(NewBand), TargetIdx: NewIdx, Source: OldBand, SourceIdx: OldIdx);
116 NewIdx += 1;
117 }
118
119 return NewBand.get_schedule();
120}
121
122/// Rewrite a schedule tree by reconstructing it bottom-up.
123///
124/// By default, the original schedule tree is reconstructed. To build a
125/// different tree, redefine visitor methods in a derived class (CRTP).
126///
127/// Note that AST build options are not applied; Setting the isolate[] option
128/// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
129/// AST build options must be set after the tree has been constructed.
130template <typename Derived, typename... Args>
131struct ScheduleTreeRewriter
132 : RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
133 Derived &getDerived() { return *static_cast<Derived *>(this); }
134 const Derived &getDerived() const {
135 return *static_cast<const Derived *>(this);
136 }
137
138 isl::schedule visitDomain(isl::schedule_node_domain Node, Args... args) {
139 // Every schedule_tree already has a domain node, no need to add one.
140 return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
141 }
142
143 isl::schedule visitBand(isl::schedule_node_band Band, Args... args) {
144 isl::schedule NewChild =
145 getDerived().visit(Band.child(pos: 0), std::forward<Args>(args)...);
146 return rebuildBand(Band, NewChild, [](int) { return true; });
147 }
148
149 isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
150 Args... args) {
151 int NumChildren = isl_schedule_node_n_children(node: Sequence.get());
152 isl::schedule Result =
153 getDerived().visit(Sequence.child(pos: 0), std::forward<Args>(args)...);
154 for (int i = 1; i < NumChildren; i += 1)
155 Result = Result.sequence(
156 schedule2: getDerived().visit(Sequence.child(pos: i), std::forward<Args>(args)...));
157 return Result;
158 }
159
160 isl::schedule visitSet(isl::schedule_node_set Set, Args... args) {
161 int NumChildren = isl_schedule_node_n_children(node: Set.get());
162 isl::schedule Result =
163 getDerived().visit(Set.child(pos: 0), std::forward<Args>(args)...);
164 for (int i = 1; i < NumChildren; i += 1)
165 Result = isl::manage(
166 isl_schedule_set(Result.release(),
167 getDerived()
168 .visit(Set.child(pos: i), std::forward<Args>(args)...)
169 .release()));
170 return Result;
171 }
172
173 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
174 return isl::schedule::from_domain(domain: Leaf.get_domain());
175 }
176
177 isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
178
179 isl::id TheMark = Mark.as<isl::schedule_node_mark>().get_id();
180 isl::schedule_node NewChild =
181 getDerived()
182 .visit(Mark.first_child(), std::forward<Args>(args)...)
183 .get_root()
184 .first_child();
185 return NewChild.insert_mark(mark: TheMark).get_schedule();
186 }
187
188 isl::schedule visitExtension(isl::schedule_node_extension Extension,
189 Args... args) {
190 isl::union_map TheExtension =
191 Extension.as<isl::schedule_node_extension>().get_extension();
192 isl::schedule_node NewChild = getDerived()
193 .visit(Extension.child(pos: 0), args...)
194 .get_root()
195 .first_child();
196 isl::schedule_node NewExtension =
197 isl::schedule_node::from_extension(extension: TheExtension);
198 return NewChild.graft_before(graft: NewExtension).get_schedule();
199 }
200
201 isl::schedule visitFilter(isl::schedule_node_filter Filter, Args... args) {
202 isl::union_set FilterDomain =
203 Filter.as<isl::schedule_node_filter>().get_filter();
204 isl::schedule NewSchedule =
205 getDerived().visit(Filter.child(pos: 0), std::forward<Args>(args)...);
206 return NewSchedule.intersect_domain(domain: FilterDomain);
207 }
208
209 isl::schedule visitNode(isl::schedule_node Node, Args... args) {
210 llvm_unreachable("Not implemented");
211 }
212};
213
214/// Rewrite the schedule tree without any changes. Useful to copy a subtree into
215/// a new schedule, discarding everything but.
216struct IdentityRewriter : ScheduleTreeRewriter<IdentityRewriter> {};
217
218/// Rewrite a schedule tree to an equivalent one without extension nodes.
219///
220/// Each visit method takes two additional arguments:
221///
222/// * The new domain the node, which is the inherited domain plus any domains
223/// added by extension nodes.
224///
225/// * A map of extension domains of all children is returned; it is required by
226/// band nodes to schedule the additional domains at the same position as the
227/// extension node would.
228///
229struct ExtensionNodeRewriter final
230 : ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
231 isl::union_map &> {
232 using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
233 const isl::union_set &, isl::union_map &>;
234 BaseTy &getBase() { return *this; }
235 const BaseTy &getBase() const { return *this; }
236
237 isl::schedule visitSchedule(isl::schedule Schedule) {
238 isl::union_map Extensions;
239 isl::schedule Result =
240 visit(Node: Schedule.get_root(), args: Schedule.get_domain(), args&: Extensions);
241 assert(!Extensions.is_null() && Extensions.is_empty());
242 return Result;
243 }
244
245 isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
246 const isl::union_set &Domain,
247 isl::union_map &Extensions) {
248 int NumChildren = isl_schedule_node_n_children(node: Sequence.get());
249 isl::schedule NewNode = visit(Node: Sequence.first_child(), args: Domain, args&: Extensions);
250 for (int i = 1; i < NumChildren; i += 1) {
251 isl::schedule_node OldChild = Sequence.child(pos: i);
252 isl::union_map NewChildExtensions;
253 isl::schedule NewChildNode = visit(Node: OldChild, args: Domain, args&: NewChildExtensions);
254 NewNode = NewNode.sequence(schedule2: NewChildNode);
255 Extensions = Extensions.unite(umap2: NewChildExtensions);
256 }
257 return NewNode;
258 }
259
260 isl::schedule visitSet(isl::schedule_node_set Set,
261 const isl::union_set &Domain,
262 isl::union_map &Extensions) {
263 int NumChildren = isl_schedule_node_n_children(node: Set.get());
264 isl::schedule NewNode = visit(Node: Set.first_child(), args: Domain, args&: Extensions);
265 for (int i = 1; i < NumChildren; i += 1) {
266 isl::schedule_node OldChild = Set.child(pos: i);
267 isl::union_map NewChildExtensions;
268 isl::schedule NewChildNode = visit(Node: OldChild, args: Domain, args&: NewChildExtensions);
269 NewNode = isl::manage(
270 ptr: isl_schedule_set(schedule1: NewNode.release(), schedule2: NewChildNode.release()));
271 Extensions = Extensions.unite(umap2: NewChildExtensions);
272 }
273 return NewNode;
274 }
275
276 isl::schedule visitLeaf(isl::schedule_node_leaf Leaf,
277 const isl::union_set &Domain,
278 isl::union_map &Extensions) {
279 Extensions = isl::union_map::empty(ctx: Leaf.ctx());
280 return isl::schedule::from_domain(domain: Domain);
281 }
282
283 isl::schedule visitBand(isl::schedule_node_band OldNode,
284 const isl::union_set &Domain,
285 isl::union_map &OuterExtensions) {
286 isl::schedule_node OldChild = OldNode.first_child();
287 isl::multi_union_pw_aff PartialSched =
288 isl::manage(ptr: isl_schedule_node_band_get_partial_schedule(node: OldNode.get()));
289
290 isl::union_map NewChildExtensions;
291 isl::schedule NewChild = visit(Node: OldChild, args: Domain, args&: NewChildExtensions);
292
293 // Add the extensions to the partial schedule.
294 OuterExtensions = isl::union_map::empty(ctx: NewChildExtensions.ctx());
295 isl::union_map NewPartialSchedMap = isl::union_map::from(mupa: PartialSched);
296 unsigned BandDims = isl_schedule_node_band_n_member(node: OldNode.get());
297 for (isl::map Ext : NewChildExtensions.get_map_list()) {
298 unsigned ExtDims = unsignedFromIslSize(Size: Ext.domain_tuple_dim());
299 assert(ExtDims >= BandDims);
300 unsigned OuterDims = ExtDims - BandDims;
301
302 isl::map BandSched =
303 Ext.project_out(type: isl::dim::in, first: 0, n: OuterDims).reverse();
304 NewPartialSchedMap = NewPartialSchedMap.unite(umap2: BandSched);
305
306 // There might be more outer bands that have to schedule the extensions.
307 if (OuterDims > 0) {
308 isl::map OuterSched =
309 Ext.project_out(type: isl::dim::in, first: OuterDims, n: BandDims);
310 OuterExtensions = OuterExtensions.unite(umap2: OuterSched);
311 }
312 }
313 isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
314 isl::multi_union_pw_aff::from_union_map(umap: NewPartialSchedMap);
315 isl::schedule_node NewNode =
316 NewChild.insert_partial_schedule(partial: NewPartialSchedAsAsMultiUnionPwAff)
317 .get_root()
318 .child(pos: 0);
319
320 // Reapply permutability and coincidence attributes.
321 NewNode = isl::manage(ptr: isl_schedule_node_band_set_permutable(
322 node: NewNode.release(),
323 permutable: isl_schedule_node_band_get_permutable(node: OldNode.get())));
324 for (unsigned i = 0; i < BandDims; i += 1)
325 NewNode = applyBandMemberAttributes(Target: NewNode.as<isl::schedule_node_band>(),
326 TargetIdx: i, Source: OldNode, SourceIdx: i);
327
328 return NewNode.get_schedule();
329 }
330
331 isl::schedule visitFilter(isl::schedule_node_filter Filter,
332 const isl::union_set &Domain,
333 isl::union_map &Extensions) {
334 isl::union_set FilterDomain =
335 Filter.as<isl::schedule_node_filter>().get_filter();
336 isl::union_set NewDomain = Domain.intersect(uset2: FilterDomain);
337
338 // A filter is added implicitly if necessary when joining schedule trees.
339 return visit(Node: Filter.first_child(), args: NewDomain, args&: Extensions);
340 }
341
342 isl::schedule visitExtension(isl::schedule_node_extension Extension,
343 const isl::union_set &Domain,
344 isl::union_map &Extensions) {
345 isl::union_map ExtDomain =
346 Extension.as<isl::schedule_node_extension>().get_extension();
347 isl::union_set NewDomain = Domain.unite(uset2: ExtDomain.range());
348 isl::union_map ChildExtensions;
349 isl::schedule NewChild =
350 visit(Node: Extension.first_child(), args: NewDomain, args&: ChildExtensions);
351 Extensions = ChildExtensions.unite(umap2: ExtDomain);
352 return NewChild;
353 }
354};
355
356/// Collect all AST build options in any schedule tree band.
357///
358/// ScheduleTreeRewriter cannot apply the schedule tree options. This class
359/// collects these options to apply them later.
360struct CollectASTBuildOptions final
361 : RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
362 using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
363 BaseTy &getBase() { return *this; }
364 const BaseTy &getBase() const { return *this; }
365
366 llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;
367
368 void visitBand(isl::schedule_node_band Band) {
369 ASTBuildOptions.push_back(
370 Elt: isl::manage(ptr: isl_schedule_node_band_get_ast_build_options(node: Band.get())));
371 return getBase().visitBand(Band);
372 }
373};
374
375/// Apply AST build options to the bands in a schedule tree.
376///
377/// This rewrites a schedule tree with the AST build options applied. We assume
378/// that the band nodes are visited in the same order as they were when the
379/// build options were collected, typically by CollectASTBuildOptions.
380struct ApplyASTBuildOptions final : ScheduleNodeRewriter<ApplyASTBuildOptions> {
381 using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
382 BaseTy &getBase() { return *this; }
383 const BaseTy &getBase() const { return *this; }
384
385 size_t Pos;
386 llvm::ArrayRef<isl::union_set> ASTBuildOptions;
387
388 ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
389 : ASTBuildOptions(ASTBuildOptions) {}
390
391 isl::schedule visitSchedule(isl::schedule Schedule) {
392 Pos = 0;
393 isl::schedule Result = visit(Schedule).get_schedule();
394 assert(Pos == ASTBuildOptions.size() &&
395 "AST build options must match to band nodes");
396 return Result;
397 }
398
399 isl::schedule_node visitBand(isl::schedule_node_band Band) {
400 isl::schedule_node_band Result =
401 Band.set_ast_build_options(ASTBuildOptions[Pos]);
402 Pos += 1;
403 return getBase().visitBand(Band: Result);
404 }
405};
406
407/// Return whether the schedule contains an extension node.
408static bool containsExtensionNode(isl::schedule Schedule) {
409 assert(!Schedule.is_null());
410
411 auto Callback = [](__isl_keep isl_schedule_node *Node,
412 void *User) -> isl_bool {
413 if (isl_schedule_node_get_type(node: Node) == isl_schedule_node_extension) {
414 // Stop walking the schedule tree.
415 return isl_bool_error;
416 }
417
418 // Continue searching the subtree.
419 return isl_bool_true;
420 };
421 isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down(
422 sched: Schedule.get(), fn: Callback, user: nullptr);
423
424 // We assume that the traversal itself does not fail, i.e. the only reason to
425 // return isl_stat_error is that an extension node was found.
426 return RetVal == isl_stat_error;
427}
428
429/// Find a named MDNode property in a LoopID.
430static MDNode *findOptionalNodeOperand(MDNode *LoopMD, StringRef Name) {
431 return dyn_cast_or_null<MDNode>(
432 Val: findMetadataOperand(LoopMD, Name).value_or(u: nullptr));
433}
434
435/// Is this node of type mark?
436static bool isMark(const isl::schedule_node &Node) {
437 return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_mark;
438}
439
440/// Is this node of type band?
441static bool isBand(const isl::schedule_node &Node) {
442 return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_band;
443}
444
445#ifndef NDEBUG
446/// Is this node a band of a single dimension (i.e. could represent a loop)?
447static bool isBandWithSingleLoop(const isl::schedule_node &Node) {
448 return isBand(Node) && isl_schedule_node_band_n_member(node: Node.get()) == 1;
449}
450#endif
451
452static bool isLeaf(const isl::schedule_node &Node) {
453 return isl_schedule_node_get_type(node: Node.get()) == isl_schedule_node_leaf;
454}
455
456/// Create an isl::id representing the output loop after a transformation.
457static isl::id createGeneratedLoopAttr(isl::ctx Ctx, MDNode *FollowupLoopMD) {
458 // Don't need to id the followup.
459 // TODO: Append llvm.loop.disable_heustistics metadata unless overridden by
460 // user followup-MD
461 if (!FollowupLoopMD)
462 return {};
463
464 BandAttr *Attr = new BandAttr();
465 Attr->Metadata = FollowupLoopMD;
466 return getIslLoopAttr(Ctx, Attr);
467}
468
469/// A loop consists of a band and an optional marker that wraps it. Return the
470/// outermost of the two.
471
472/// That is, either the mark or, if there is not mark, the loop itself. Can
473/// start with either the mark or the band.
474static isl::schedule_node moveToBandMark(isl::schedule_node BandOrMark) {
475 if (isBandMark(Node: BandOrMark)) {
476 assert(isBandWithSingleLoop(BandOrMark.child(0)));
477 return BandOrMark;
478 }
479 assert(isBandWithSingleLoop(BandOrMark));
480
481 isl::schedule_node Mark = BandOrMark.parent();
482 if (isBandMark(Node: Mark))
483 return Mark;
484
485 // Band has no loop marker.
486 return BandOrMark;
487}
488
489static isl::schedule_node removeMark(isl::schedule_node MarkOrBand,
490 BandAttr *&Attr) {
491 MarkOrBand = moveToBandMark(BandOrMark: MarkOrBand);
492
493 isl::schedule_node Band;
494 if (isMark(Node: MarkOrBand)) {
495 Attr = getLoopAttr(Id: MarkOrBand.as<isl::schedule_node_mark>().get_id());
496 Band = isl::manage(ptr: isl_schedule_node_delete(node: MarkOrBand.release()));
497 } else {
498 Attr = nullptr;
499 Band = MarkOrBand;
500 }
501
502 assert(isBandWithSingleLoop(Band));
503 return Band;
504}
505
506/// Remove the mark that wraps a loop. Return the band representing the loop.
507static isl::schedule_node removeMark(isl::schedule_node MarkOrBand) {
508 BandAttr *Attr;
509 return removeMark(MarkOrBand, Attr);
510}
511
512static isl::schedule_node insertMark(isl::schedule_node Band, isl::id Mark) {
513 assert(isBand(Band));
514 assert(moveToBandMark(Band).is_equal(Band) &&
515 "Don't add a two marks for a band");
516
517 return Band.insert_mark(mark: Mark).child(pos: 0);
518}
519
520/// Return the (one-dimensional) set of numbers that are divisible by @p Factor
521/// with remainder @p Offset.
522///
523/// isDivisibleBySet(Ctx, 4, 0) = { [i] : floord(i,4) = 0 }
524/// isDivisibleBySet(Ctx, 4, 1) = { [i] : floord(i,4) = 1 }
525///
526static isl::basic_set isDivisibleBySet(isl::ctx &Ctx, long Factor,
527 long Offset) {
528 isl::val ValFactor{Ctx, Factor};
529 isl::val ValOffset{Ctx, Offset};
530
531 isl::space Unispace{Ctx, 0, 1};
532 isl::local_space LUnispace{Unispace};
533 isl::aff AffFactor{LUnispace, ValFactor};
534 isl::aff AffOffset{LUnispace, ValOffset};
535
536 isl::aff Id = isl::aff::var_on_domain(ls: LUnispace, type: isl::dim::out, pos: 0);
537 isl::aff DivMul = Id.mod(mod: ValFactor);
538 isl::basic_map Divisible = isl::basic_map::from_aff(aff: DivMul);
539 isl::basic_map Modulo = Divisible.fix_val(type: isl::dim::out, pos: 0, v: ValOffset);
540 return Modulo.domain();
541}
542
543/// Make the last dimension of Set to take values from 0 to VectorWidth - 1.
544///
545/// @param Set A set, which should be modified.
546/// @param VectorWidth A parameter, which determines the constraint.
547static isl::set addExtentConstraints(isl::set Set, int VectorWidth) {
548 unsigned Dims = unsignedFromIslSize(Size: Set.tuple_dim());
549 assert(Dims >= 1);
550 isl::space Space = Set.get_space();
551 isl::local_space LocalSpace = isl::local_space(Space);
552 isl::constraint ExtConstr = isl::constraint::alloc_inequality(ls: LocalSpace);
553 ExtConstr = ExtConstr.set_constant_si(0);
554 ExtConstr = ExtConstr.set_coefficient_si(type: isl::dim::set, pos: Dims - 1, v: 1);
555 Set = Set.add_constraint(constraint: ExtConstr);
556 ExtConstr = isl::constraint::alloc_inequality(ls: LocalSpace);
557 ExtConstr = ExtConstr.set_constant_si(VectorWidth - 1);
558 ExtConstr = ExtConstr.set_coefficient_si(type: isl::dim::set, pos: Dims - 1, v: -1);
559 return Set.add_constraint(constraint: ExtConstr);
560}
561
562/// Collapse perfectly nested bands into a single band.
563class BandCollapseRewriter final
564 : public ScheduleTreeRewriter<BandCollapseRewriter> {
565private:
566 using BaseTy = ScheduleTreeRewriter<BandCollapseRewriter>;
567 BaseTy &getBase() { return *this; }
568 const BaseTy &getBase() const { return *this; }
569
570public:
571 isl::schedule visitBand(isl::schedule_node_band RootBand) {
572 isl::schedule_node_band Band = RootBand;
573 isl::ctx Ctx = Band.ctx();
574
575 // Do not merge permutable band to avoid loosing the permutability property.
576 // Cannot collapse even two permutable loops, they might be permutable
577 // individually, but not necassarily across.
578 if (unsignedFromIslSize(Size: Band.n_member()) > 1u && Band.permutable())
579 return getBase().visitBand(Band);
580
581 // Find collapsable bands.
582 SmallVector<isl::schedule_node_band> Nest;
583 int NumTotalLoops = 0;
584 isl::schedule_node Body;
585 while (true) {
586 Nest.push_back(Elt: Band);
587 NumTotalLoops += unsignedFromIslSize(Size: Band.n_member());
588 Body = Band.first_child();
589 if (!Body.isa<isl::schedule_node_band>())
590 break;
591 Band = Body.as<isl::schedule_node_band>();
592
593 // Do not include next band if it is permutable to not lose its
594 // permutability property.
595 if (unsignedFromIslSize(Size: Band.n_member()) > 1u && Band.permutable())
596 break;
597 }
598
599 // Nothing to collapse, preserve permutability.
600 if (Nest.size() <= 1)
601 return getBase().visitBand(Band);
602
603 POLLY_DEBUG({
604 dbgs() << "Found loops to collapse between\n";
605 dumpIslObj(RootBand, dbgs());
606 dbgs() << "and\n";
607 dumpIslObj(Body, dbgs());
608 dbgs() << "\n";
609 });
610
611 isl::schedule NewBody = visit(Node: Body);
612
613 // Collect partial schedules from all members.
614 isl::union_pw_aff_list PartScheds{Ctx, NumTotalLoops};
615 for (isl::schedule_node_band Band : Nest) {
616 int NumLoops = unsignedFromIslSize(Size: Band.n_member());
617 isl::multi_union_pw_aff BandScheds = Band.get_partial_schedule();
618 for (auto j : seq<int>(Begin: 0, End: NumLoops))
619 PartScheds = PartScheds.add(el: BandScheds.at(pos: j));
620 }
621 isl::space ScatterSpace = isl::space(Ctx, 0, NumTotalLoops);
622 isl::multi_union_pw_aff PartSchedsMulti{ScatterSpace, PartScheds};
623
624 isl::schedule_node_band CollapsedBand =
625 NewBody.insert_partial_schedule(partial: PartSchedsMulti)
626 .get_root()
627 .first_child()
628 .as<isl::schedule_node_band>();
629
630 // Copy over loop attributes form original bands.
631 int LoopIdx = 0;
632 for (isl::schedule_node_band Band : Nest) {
633 int NumLoops = unsignedFromIslSize(Size: Band.n_member());
634 for (int i : seq<int>(Begin: 0, End: NumLoops)) {
635 CollapsedBand = applyBandMemberAttributes(Target: std::move(CollapsedBand),
636 TargetIdx: LoopIdx, Source: Band, SourceIdx: i);
637 LoopIdx += 1;
638 }
639 }
640 assert(LoopIdx == NumTotalLoops &&
641 "Expect the same number of loops to add up again");
642
643 return CollapsedBand.get_schedule();
644 }
645};
646
647static isl::schedule collapseBands(isl::schedule Sched) {
648 POLLY_DEBUG(dbgs() << "Collapse bands in schedule\n");
649 BandCollapseRewriter Rewriter;
650 return Rewriter.visit(Schedule: Sched);
651}
652
653/// Collect sequentially executed bands (or anything else), even if nested in a
654/// mark or other nodes whose child is executed just once. If we can
655/// successfully fuse the bands, we allow them to be removed.
656static void collectPotentiallyFusableBands(
657 isl::schedule_node Node,
658 SmallVectorImpl<std::pair<isl::schedule_node, isl::schedule_node>>
659 &ScheduleBands,
660 const isl::schedule_node &DirectChild) {
661 switch (isl_schedule_node_get_type(node: Node.get())) {
662 case isl_schedule_node_sequence:
663 case isl_schedule_node_set:
664 case isl_schedule_node_mark:
665 case isl_schedule_node_domain:
666 case isl_schedule_node_filter:
667 if (Node.has_children()) {
668 isl::schedule_node C = Node.first_child();
669 while (true) {
670 collectPotentiallyFusableBands(Node: C, ScheduleBands, DirectChild);
671 if (!C.has_next_sibling())
672 break;
673 C = C.next_sibling();
674 }
675 }
676 break;
677
678 default:
679 // Something that does not execute suquentially (e.g. a band)
680 ScheduleBands.push_back(Elt: {Node, DirectChild});
681 break;
682 }
683}
684
685/// Remove dependencies that are resolved by @p PartSched. That is, remove
686/// everything that we already know is executed in-order.
687static isl::union_map remainingDepsFromPartialSchedule(isl::union_map PartSched,
688 isl::union_map Deps) {
689 unsigned NumDims = getNumScatterDims(Schedule: PartSched);
690 auto ParamSpace = PartSched.get_space().params();
691
692 // { Scatter[] }
693 isl::space ScatterSpace =
694 ParamSpace.set_from_params().add_dims(type: isl::dim::set, n: NumDims);
695
696 // { Scatter[] -> Domain[] }
697 isl::union_map PartSchedRev = PartSched.reverse();
698
699 // { Scatter[] -> Scatter[] }
700 isl::map MaybeBefore = isl::map::lex_le(set_space: ScatterSpace);
701
702 // { Domain[] -> Domain[] }
703 isl::union_map DomMaybeBefore =
704 MaybeBefore.apply_domain(umap2: PartSchedRev).apply_range(umap2: PartSchedRev);
705
706 // { Domain[] -> Domain[] }
707 isl::union_map ChildRemainingDeps = Deps.intersect(umap2: DomMaybeBefore);
708
709 return ChildRemainingDeps;
710}
711
712/// Remove dependencies that are resolved by executing them in the order
713/// specified by @p Domains;
714static isl::union_map remainigDepsFromSequence(ArrayRef<isl::union_set> Domains,
715 isl::union_map Deps) {
716 isl::ctx Ctx = Deps.ctx();
717 isl::space ParamSpace = Deps.get_space().params();
718
719 // Create a partial schedule mapping to constants that reflect the execution
720 // order.
721 isl::union_map PartialSchedules = isl::union_map::empty(ctx: Ctx);
722 for (auto P : enumerate(First&: Domains)) {
723 isl::val ExecTime = isl::val(Ctx, P.index());
724 isl::union_pw_aff DomSched{P.value(), ExecTime};
725 PartialSchedules = PartialSchedules.unite(umap2: DomSched.as_union_map());
726 }
727
728 return remainingDepsFromPartialSchedule(PartSched: PartialSchedules, Deps);
729}
730
731/// Determine whether the outermost loop of to bands can be fused while
732/// respecting validity dependencies.
733static bool canFuseOutermost(const isl::schedule_node_band &LHS,
734 const isl::schedule_node_band &RHS,
735 const isl::union_map &Deps) {
736 // { LHSDomain[] -> Scatter[] }
737 isl::union_map LHSPartSched =
738 LHS.get_partial_schedule().get_at(pos: 0).as_union_map();
739
740 // { Domain[] -> Scatter[] }
741 isl::union_map RHSPartSched =
742 RHS.get_partial_schedule().get_at(pos: 0).as_union_map();
743
744 // Dependencies that are already resolved because LHS executes before RHS, but
745 // will not be anymore after fusion. { DefDomain[] -> UseDomain[] }
746 isl::union_map OrderedBySequence =
747 Deps.intersect_domain(uset: LHSPartSched.domain())
748 .intersect_range(uset: RHSPartSched.domain());
749
750 isl::space ParamSpace = OrderedBySequence.get_space().params();
751 isl::space NewScatterSpace = ParamSpace.add_unnamed_tuple(dim: 1);
752
753 // { Scatter[] -> Scatter[] }
754 isl::map After = isl::map::lex_gt(set_space: NewScatterSpace);
755
756 // After fusion, instances with smaller (or equal, which means they will be
757 // executed in the same iteration, but the LHS instance is still sequenced
758 // before RHS) scatter value will still be executed before. This are the
759 // orderings where this is not necessarily the case.
760 // { LHSDomain[] -> RHSDomain[] }
761 isl::union_map MightBeAfterDoms = After.apply_domain(umap2: LHSPartSched.reverse())
762 .apply_range(umap2: RHSPartSched.reverse());
763
764 // Dependencies that are not resolved by the new execution order.
765 isl::union_map WithBefore = OrderedBySequence.intersect(umap2: MightBeAfterDoms);
766
767 return WithBefore.is_empty();
768}
769
770/// Fuse @p LHS and @p RHS if possible while preserving validity dependenvies.
771static isl::schedule tryGreedyFuse(isl::schedule_node_band LHS,
772 isl::schedule_node_band RHS,
773 const isl::union_map &Deps) {
774 if (!canFuseOutermost(LHS, RHS, Deps))
775 return {};
776
777 POLLY_DEBUG({
778 dbgs() << "Found loops for greedy fusion:\n";
779 dumpIslObj(LHS, dbgs());
780 dbgs() << "and\n";
781 dumpIslObj(RHS, dbgs());
782 dbgs() << "\n";
783 });
784
785 // The partial schedule of the bands outermost loop that we need to combine
786 // for the fusion.
787 isl::union_pw_aff LHSPartOuterSched = LHS.get_partial_schedule().get_at(pos: 0);
788 isl::union_pw_aff RHSPartOuterSched = RHS.get_partial_schedule().get_at(pos: 0);
789
790 // Isolate band bodies as roots of their own schedule trees.
791 IdentityRewriter Rewriter;
792 isl::schedule LHSBody = Rewriter.visit(Node: LHS.first_child());
793 isl::schedule RHSBody = Rewriter.visit(Node: RHS.first_child());
794
795 // Reconstruct the non-outermost (not going to be fused) loops from both
796 // bands.
797 // TODO: Maybe it is possibly to transfer the 'permutability' property from
798 // LHS+RHS. At minimum we need merge multiple band members at once, otherwise
799 // permutability has no meaning.
800 isl::schedule LHSNewBody =
801 rebuildBand(OldBand: LHS, Body: LHSBody, IncludeCb: [](int i) { return i > 0; });
802 isl::schedule RHSNewBody =
803 rebuildBand(OldBand: RHS, Body: RHSBody, IncludeCb: [](int i) { return i > 0; });
804
805 // The loop body of the fused loop.
806 isl::schedule NewCommonBody = LHSNewBody.sequence(schedule2: RHSNewBody);
807
808 // Combine the partial schedules of both loops to a new one. Instances with
809 // the same scatter value are put together.
810 isl::union_map NewCommonPartialSched =
811 LHSPartOuterSched.as_union_map().unite(umap2: RHSPartOuterSched.as_union_map());
812 isl::schedule NewCommonSchedule = NewCommonBody.insert_partial_schedule(
813 partial: NewCommonPartialSched.as_multi_union_pw_aff());
814
815 return NewCommonSchedule;
816}
817
818static isl::schedule tryGreedyFuse(isl::schedule_node LHS,
819 isl::schedule_node RHS,
820 const isl::union_map &Deps) {
821 // TODO: Non-bands could be interpreted as a band with just as single
822 // iteration. However, this is only useful if both ends of a fused loop were
823 // originally loops themselves.
824 if (!LHS.isa<isl::schedule_node_band>())
825 return {};
826 if (!RHS.isa<isl::schedule_node_band>())
827 return {};
828
829 return tryGreedyFuse(LHS: LHS.as<isl::schedule_node_band>(),
830 RHS: RHS.as<isl::schedule_node_band>(), Deps);
831}
832
833/// Fuse all fusable loop top-down in a schedule tree.
834///
835/// The isl::union_map parameters is the set of validity dependencies that have
836/// not been resolved/carried by a parent schedule node.
837class GreedyFusionRewriter final
838 : public ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map> {
839private:
840 using BaseTy = ScheduleTreeRewriter<GreedyFusionRewriter, isl::union_map>;
841 BaseTy &getBase() { return *this; }
842 const BaseTy &getBase() const { return *this; }
843
844public:
845 /// Is set to true if anything has been fused.
846 bool AnyChange = false;
847
848 isl::schedule visitBand(isl::schedule_node_band Band, isl::union_map Deps) {
849 // { Domain[] -> Scatter[] }
850 isl::union_map PartSched =
851 isl::union_map::from(mupa: Band.get_partial_schedule());
852 assert(getNumScatterDims(PartSched) ==
853 unsignedFromIslSize(Band.n_member()));
854 isl::space ParamSpace = PartSched.get_space().params();
855
856 // { Scatter[] -> Domain[] }
857 isl::union_map PartSchedRev = PartSched.reverse();
858
859 // Possible within the same iteration. Dependencies with smaller scatter
860 // value are carried by this loop and therefore have been resolved by the
861 // in-order execution if the loop iteration. A dependency with small scatter
862 // value would be a dependency violation that we assume did not happen. {
863 // Domain[] -> Domain[] }
864 isl::union_map Unsequenced = PartSchedRev.apply_domain(umap2: PartSchedRev);
865
866 // Actual dependencies within the same iteration.
867 // { DefDomain[] -> UseDomain[] }
868 isl::union_map RemDeps = Deps.intersect(umap2: Unsequenced);
869
870 return getBase().visitBand(Band, args: RemDeps);
871 }
872
873 isl::schedule visitSequence(isl::schedule_node_sequence Sequence,
874 isl::union_map Deps) {
875 int NumChildren = isl_schedule_node_n_children(node: Sequence.get());
876
877 // List of fusion candidates. The first element is the fusion candidate, the
878 // second is candidate's ancestor that is the sequence's direct child. It is
879 // preferable to use the direct child if not if its non-direct children is
880 // fused to preserve its structure such as mark nodes.
881 SmallVector<std::pair<isl::schedule_node, isl::schedule_node>> Bands;
882 for (auto i : seq<int>(Begin: 0, End: NumChildren)) {
883 isl::schedule_node Child = Sequence.child(pos: i);
884 collectPotentiallyFusableBands(Node: Child, ScheduleBands&: Bands, DirectChild: Child);
885 }
886
887 // Direct children that had at least one of its decendants fused.
888 SmallDenseSet<isl_schedule_node *, 4> ChangedDirectChildren;
889
890 // Fuse neigboring bands until reaching the end of candidates.
891 int i = 0;
892 while (i + 1 < (int)Bands.size()) {
893 isl::schedule Fused =
894 tryGreedyFuse(LHS: Bands[i].first, RHS: Bands[i + 1].first, Deps);
895 if (Fused.is_null()) {
896 // Cannot merge this node with the next; look at next pair.
897 i += 1;
898 continue;
899 }
900
901 // Mark the direct children as (partially) fused.
902 if (!Bands[i].second.is_null())
903 ChangedDirectChildren.insert(V: Bands[i].second.get());
904 if (!Bands[i + 1].second.is_null())
905 ChangedDirectChildren.insert(V: Bands[i + 1].second.get());
906
907 // Collapse the neigbros to a single new candidate that could be fused
908 // with the next candidate.
909 Bands[i] = {Fused.get_root(), {}};
910 Bands.erase(CI: Bands.begin() + i + 1);
911
912 AnyChange = true;
913 }
914
915 // By construction equal if done with collectPotentiallyFusableBands's
916 // output.
917 SmallVector<isl::union_set> SubDomains;
918 SubDomains.reserve(N: NumChildren);
919 for (int i = 0; i < NumChildren; i += 1)
920 SubDomains.push_back(Elt: Sequence.child(pos: i).domain());
921 auto SubRemainingDeps = remainigDepsFromSequence(Domains: SubDomains, Deps);
922
923 // We may iterate over direct children multiple times, be sure to add each
924 // at most once.
925 SmallDenseSet<isl_schedule_node *, 4> AlreadyAdded;
926
927 isl::schedule Result;
928 for (auto &P : Bands) {
929 isl::schedule_node MaybeFused = P.first;
930 isl::schedule_node DirectChild = P.second;
931
932 // If not modified, use the direct child.
933 if (!DirectChild.is_null() &&
934 !ChangedDirectChildren.count(V: DirectChild.get())) {
935 if (AlreadyAdded.count(V: DirectChild.get()))
936 continue;
937 AlreadyAdded.insert(V: DirectChild.get());
938 MaybeFused = DirectChild;
939 } else {
940 assert(AnyChange &&
941 "Need changed flag for be consistent with actual change");
942 }
943
944 // Top-down recursion: If the outermost loop has been fused, their nested
945 // bands might be fusable now as well.
946 isl::schedule InnerFused = visit(Node: MaybeFused, args: SubRemainingDeps);
947
948 // Reconstruct the sequence, with some of the children fused.
949 if (Result.is_null())
950 Result = InnerFused;
951 else
952 Result = Result.sequence(schedule2: InnerFused);
953 }
954
955 return Result;
956 }
957};
958
959} // namespace
960
961bool polly::isBandMark(const isl::schedule_node &Node) {
962 return isMark(Node) &&
963 isLoopAttr(Id: Node.as<isl::schedule_node_mark>().get_id());
964}
965
966BandAttr *polly::getBandAttr(isl::schedule_node MarkOrBand) {
967 MarkOrBand = moveToBandMark(BandOrMark: MarkOrBand);
968 if (!isMark(Node: MarkOrBand))
969 return nullptr;
970
971 return getLoopAttr(Id: MarkOrBand.as<isl::schedule_node_mark>().get_id());
972}
973
974isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
975 // If there is no extension node in the first place, return the original
976 // schedule tree.
977 if (!containsExtensionNode(Schedule: Sched))
978 return Sched;
979
980 // Build options can anchor schedule nodes, such that the schedule tree cannot
981 // be modified anymore. Therefore, apply build options after the tree has been
982 // created.
983 CollectASTBuildOptions Collector;
984 Collector.visit(Schedule: Sched);
985
986 // Rewrite the schedule tree without extension nodes.
987 ExtensionNodeRewriter Rewriter;
988 isl::schedule NewSched = Rewriter.visitSchedule(Schedule: Sched);
989
990 // Reapply the AST build options. The rewriter must not change the iteration
991 // order of bands. Any other node type is ignored.
992 ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
993 NewSched = Applicator.visitSchedule(Schedule: NewSched);
994
995 return NewSched;
996}
997
998isl::schedule polly::applyFullUnroll(isl::schedule_node BandToUnroll) {
999 isl::ctx Ctx = BandToUnroll.ctx();
1000
1001 // Remove the loop's mark, the loop will disappear anyway.
1002 BandToUnroll = removeMark(MarkOrBand: BandToUnroll);
1003 assert(isBandWithSingleLoop(BandToUnroll));
1004
1005 isl::multi_union_pw_aff PartialSched = isl::manage(
1006 ptr: isl_schedule_node_band_get_partial_schedule(node: BandToUnroll.get()));
1007 assert(unsignedFromIslSize(PartialSched.dim(isl::dim::out)) == 1u &&
1008 "Can only unroll a single dimension");
1009 isl::union_pw_aff PartialSchedUAff = PartialSched.at(pos: 0);
1010
1011 isl::union_set Domain = BandToUnroll.get_domain();
1012 PartialSchedUAff = PartialSchedUAff.intersect_domain(uset: Domain);
1013 isl::union_map PartialSchedUMap =
1014 isl::union_map::from(upma: isl::union_pw_multi_aff(PartialSchedUAff));
1015
1016 // Enumerator only the scatter elements.
1017 isl::union_set ScatterList = PartialSchedUMap.range();
1018
1019 // Enumerate all loop iterations.
1020 // TODO: Diagnose if not enumerable or depends on a parameter.
1021 SmallVector<isl::point, 16> Elts;
1022 ScatterList.foreach_point(fn: [&Elts](isl::point P) -> isl::stat {
1023 Elts.push_back(Elt: P);
1024 return isl::stat::ok();
1025 });
1026
1027 // Don't assume that foreach_point returns in execution order.
1028 llvm::sort(C&: Elts, Comp: [](isl::point P1, isl::point P2) -> bool {
1029 isl::val C1 = P1.get_coordinate_val(type: isl::dim::set, pos: 0);
1030 isl::val C2 = P2.get_coordinate_val(type: isl::dim::set, pos: 0);
1031 return C1.lt(v2: C2);
1032 });
1033
1034 // Convert the points to a sequence of filters.
1035 isl::union_set_list List = isl::union_set_list(Ctx, Elts.size());
1036 for (isl::point P : Elts) {
1037 // Determine the domains that map this scatter element.
1038 isl::union_set DomainFilter = PartialSchedUMap.intersect_range(uset: P).domain();
1039
1040 List = List.add(el: DomainFilter);
1041 }
1042
1043 // Replace original band with unrolled sequence.
1044 isl::schedule_node Body =
1045 isl::manage(ptr: isl_schedule_node_delete(node: BandToUnroll.release()));
1046 Body = Body.insert_sequence(filters: List);
1047 return Body.get_schedule();
1048}
1049
1050isl::schedule polly::applyPartialUnroll(isl::schedule_node BandToUnroll,
1051 int Factor) {
1052 assert(Factor > 0 && "Positive unroll factor required");
1053 isl::ctx Ctx = BandToUnroll.ctx();
1054
1055 // Remove the mark, save the attribute for later use.
1056 BandAttr *Attr;
1057 BandToUnroll = removeMark(MarkOrBand: BandToUnroll, Attr);
1058 assert(isBandWithSingleLoop(BandToUnroll));
1059
1060 isl::multi_union_pw_aff PartialSched = isl::manage(
1061 ptr: isl_schedule_node_band_get_partial_schedule(node: BandToUnroll.get()));
1062
1063 // { Stmt[] -> [x] }
1064 isl::union_pw_aff PartialSchedUAff = PartialSched.at(pos: 0);
1065
1066 // Here we assume the schedule stride is one and starts with 0, which is not
1067 // necessarily the case.
1068 isl::union_pw_aff StridedPartialSchedUAff =
1069 isl::union_pw_aff::empty(space: PartialSchedUAff.get_space());
1070 isl::val ValFactor{Ctx, Factor};
1071 PartialSchedUAff.foreach_pw_aff(fn: [&StridedPartialSchedUAff,
1072 &ValFactor](isl::pw_aff PwAff) -> isl::stat {
1073 isl::space Space = PwAff.get_space();
1074 isl::set Universe = isl::set::universe(space: Space.domain());
1075 isl::pw_aff AffFactor{Universe, ValFactor};
1076 isl::pw_aff DivSchedAff = PwAff.div(pa2: AffFactor).floor().mul(pwaff2: AffFactor);
1077 StridedPartialSchedUAff = StridedPartialSchedUAff.union_add(upa2: DivSchedAff);
1078 return isl::stat::ok();
1079 });
1080
1081 isl::union_set_list List = isl::union_set_list(Ctx, Factor);
1082 for (auto i : seq<int>(Begin: 0, End: Factor)) {
1083 // { Stmt[] -> [x] }
1084 isl::union_map UMap =
1085 isl::union_map::from(upma: isl::union_pw_multi_aff(PartialSchedUAff));
1086
1087 // { [x] }
1088 isl::basic_set Divisible = isDivisibleBySet(Ctx, Factor, Offset: i);
1089
1090 // { Stmt[] }
1091 isl::union_set UnrolledDomain = UMap.intersect_range(uset: Divisible).domain();
1092
1093 List = List.add(el: UnrolledDomain);
1094 }
1095
1096 isl::schedule_node Body =
1097 isl::manage(ptr: isl_schedule_node_delete(node: BandToUnroll.copy()));
1098 Body = Body.insert_sequence(filters: List);
1099 isl::schedule_node NewLoop =
1100 Body.insert_partial_schedule(schedule: StridedPartialSchedUAff);
1101
1102 MDNode *FollowupMD = nullptr;
1103 if (Attr && Attr->Metadata)
1104 FollowupMD =
1105 findOptionalNodeOperand(LoopMD: Attr->Metadata, Name: LLVMLoopUnrollFollowupUnrolled);
1106
1107 isl::id NewBandId = createGeneratedLoopAttr(Ctx, FollowupLoopMD: FollowupMD);
1108 if (!NewBandId.is_null())
1109 NewLoop = insertMark(Band: NewLoop, Mark: NewBandId);
1110
1111 return NewLoop.get_schedule();
1112}
1113
1114isl::set polly::getPartialTilePrefixes(isl::set ScheduleRange,
1115 int VectorWidth) {
1116 unsigned Dims = unsignedFromIslSize(Size: ScheduleRange.tuple_dim());
1117 assert(Dims >= 1);
1118 isl::set LoopPrefixes =
1119 ScheduleRange.drop_constraints_involving_dims(type: isl::dim::set, first: Dims - 1, n: 1);
1120 auto ExtentPrefixes = addExtentConstraints(Set: LoopPrefixes, VectorWidth);
1121 isl::set BadPrefixes = ExtentPrefixes.subtract(set2: ScheduleRange);
1122 BadPrefixes = BadPrefixes.project_out(type: isl::dim::set, first: Dims - 1, n: 1);
1123 LoopPrefixes = LoopPrefixes.project_out(type: isl::dim::set, first: Dims - 1, n: 1);
1124 return LoopPrefixes.subtract(set2: BadPrefixes);
1125}
1126
1127isl::union_set polly::getIsolateOptions(isl::set IsolateDomain,
1128 unsigned OutDimsNum) {
1129 unsigned Dims = unsignedFromIslSize(Size: IsolateDomain.tuple_dim());
1130 assert(OutDimsNum <= Dims &&
1131 "The isl::set IsolateDomain is used to describe the range of schedule "
1132 "dimensions values, which should be isolated. Consequently, the "
1133 "number of its dimensions should be greater than or equal to the "
1134 "number of the schedule dimensions.");
1135 isl::map IsolateRelation = isl::map::from_domain(set: IsolateDomain);
1136 IsolateRelation = IsolateRelation.move_dims(dst_type: isl::dim::out, dst_pos: 0, src_type: isl::dim::in,
1137 src_pos: Dims - OutDimsNum, n: OutDimsNum);
1138 isl::set IsolateOption = IsolateRelation.wrap();
1139 isl::id Id = isl::id::alloc(ctx: IsolateOption.ctx(), name: "isolate", user: nullptr);
1140 IsolateOption = IsolateOption.set_tuple_id(Id);
1141 return isl::union_set(IsolateOption);
1142}
1143
1144isl::union_set polly::getDimOptions(isl::ctx Ctx, const char *Option) {
1145 isl::space Space(Ctx, 0, 1);
1146 auto DimOption = isl::set::universe(space: Space);
1147 auto Id = isl::id::alloc(ctx: Ctx, name: Option, user: nullptr);
1148 DimOption = DimOption.set_tuple_id(Id);
1149 return isl::union_set(DimOption);
1150}
1151
1152isl::schedule_node polly::tileNode(isl::schedule_node Node,
1153 const char *Identifier,
1154 ArrayRef<int> TileSizes,
1155 int DefaultTileSize) {
1156 auto Space = isl::manage(ptr: isl_schedule_node_band_get_space(node: Node.get()));
1157 auto Dims = Space.dim(type: isl::dim::set);
1158 auto Sizes = isl::multi_val::zero(space: Space);
1159 std::string IdentifierString(Identifier);
1160 for (unsigned i : rangeIslSize(Begin: 0, End: Dims)) {
1161 unsigned tileSize = i < TileSizes.size() ? TileSizes[i] : DefaultTileSize;
1162 Sizes = Sizes.set_val(pos: i, el: isl::val(Node.ctx(), tileSize));
1163 }
1164 auto TileLoopMarkerStr = IdentifierString + " - Tiles";
1165 auto TileLoopMarker = isl::id::alloc(ctx: Node.ctx(), name: TileLoopMarkerStr, user: nullptr);
1166 Node = Node.insert_mark(mark: TileLoopMarker);
1167 Node = Node.child(pos: 0);
1168 Node =
1169 isl::manage(ptr: isl_schedule_node_band_tile(node: Node.release(), sizes: Sizes.release()));
1170 Node = Node.child(pos: 0);
1171 auto PointLoopMarkerStr = IdentifierString + " - Points";
1172 auto PointLoopMarker =
1173 isl::id::alloc(ctx: Node.ctx(), name: PointLoopMarkerStr, user: nullptr);
1174 Node = Node.insert_mark(mark: PointLoopMarker);
1175 return Node.child(pos: 0);
1176}
1177
1178isl::schedule_node polly::applyRegisterTiling(isl::schedule_node Node,
1179 ArrayRef<int> TileSizes,
1180 int DefaultTileSize) {
1181 Node = tileNode(Node, Identifier: "Register tiling", TileSizes, DefaultTileSize);
1182 auto Ctx = Node.ctx();
1183 return Node.as<isl::schedule_node_band>().set_ast_build_options(
1184 isl::union_set(Ctx, "{unroll[x]}"));
1185}
1186
1187/// Find statements and sub-loops in (possibly nested) sequences.
1188static void
1189collectFissionableStmts(isl::schedule_node Node,
1190 SmallVectorImpl<isl::schedule_node> &ScheduleStmts) {
1191 if (isBand(Node) || isLeaf(Node)) {
1192 ScheduleStmts.push_back(Elt: Node);
1193 return;
1194 }
1195
1196 if (Node.has_children()) {
1197 isl::schedule_node C = Node.first_child();
1198 while (true) {
1199 collectFissionableStmts(Node: C, ScheduleStmts);
1200 if (!C.has_next_sibling())
1201 break;
1202 C = C.next_sibling();
1203 }
1204 }
1205}
1206
1207isl::schedule polly::applyMaxFission(isl::schedule_node BandToFission) {
1208 isl::ctx Ctx = BandToFission.ctx();
1209 BandToFission = removeMark(MarkOrBand: BandToFission);
1210 isl::schedule_node BandBody = BandToFission.child(pos: 0);
1211
1212 SmallVector<isl::schedule_node> FissionableStmts;
1213 collectFissionableStmts(Node: BandBody, ScheduleStmts&: FissionableStmts);
1214 size_t N = FissionableStmts.size();
1215
1216 // Collect the domain for each of the statements that will get their own loop.
1217 isl::union_set_list DomList = isl::union_set_list(Ctx, N);
1218 for (size_t i = 0; i < N; ++i) {
1219 isl::schedule_node BodyPart = FissionableStmts[i];
1220 DomList = DomList.add(el: BodyPart.get_domain());
1221 }
1222
1223 // Apply the fission by copying the entire loop, but inserting a filter for
1224 // the statement domains for each fissioned loop.
1225 isl::schedule_node Fissioned = BandToFission.insert_sequence(filters: DomList);
1226
1227 return Fissioned.get_schedule();
1228}
1229
1230isl::schedule polly::applyGreedyFusion(isl::schedule Sched,
1231 const isl::union_map &Deps) {
1232 POLLY_DEBUG(dbgs() << "Greedy loop fusion\n");
1233
1234 GreedyFusionRewriter Rewriter;
1235 isl::schedule Result = Rewriter.visit(Schedule: Sched, args: Deps);
1236 if (!Rewriter.AnyChange) {
1237 POLLY_DEBUG(dbgs() << "Found nothing to fuse\n");
1238 return Sched;
1239 }
1240
1241 // GreedyFusionRewriter due to working loop-by-loop, bands with multiple loops
1242 // may have been split into multiple bands.
1243 return collapseBands(Sched: Result);
1244}
1245

// Source: polly/lib/Transform/ScheduleTreeTransform.cpp