1 | //===- MatmulOptimizer.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "polly/MatmulOptimizer.h" |
10 | #include "polly/DependenceInfo.h" |
11 | #include "polly/Options.h" |
12 | #include "polly/ScheduleTreeTransform.h" |
13 | #include "polly/ScopInfo.h" |
14 | #include "polly/ScopPass.h" |
15 | #include "polly/Simplify.h" |
16 | #include "polly/Support/GICHelper.h" |
17 | #include "polly/Support/ISLTools.h" |
18 | #include "llvm/ADT/ArrayRef.h" |
19 | #include "llvm/ADT/DenseSet.h" |
20 | #include "llvm/ADT/Sequence.h" |
21 | #include "llvm/ADT/SetOperations.h" |
22 | #include "llvm/ADT/SmallVector.h" |
23 | #include "llvm/ADT/StringRef.h" |
24 | #include "llvm/ADT/iterator_range.h" |
25 | #include "llvm/Analysis/TargetTransformInfo.h" |
26 | #include "llvm/IR/DataLayout.h" |
27 | #include "llvm/IR/Function.h" |
28 | #include "llvm/IR/Module.h" |
29 | #include "llvm/Support/CommandLine.h" |
30 | #include "llvm/Support/Debug.h" |
31 | #include "llvm/Support/TypeSize.h" |
32 | #include "llvm/Support/raw_ostream.h" |
33 | #include "isl/ctx.h" |
34 | #include "isl/schedule_node.h" |
35 | #include "isl/schedule_type.h" |
36 | #include "isl/union_map.h" |
37 | #include "isl/union_set.h" |
38 | #include <algorithm> |
39 | #include <cassert> |
40 | #include <cmath> |
41 | #include <cstdint> |
42 | #include <string> |
43 | #include <vector> |
44 | |
45 | #include "polly/Support/PollyDebug.h" |
46 | #define DEBUG_TYPE "polly-opt-isl" |
47 | |
48 | using namespace llvm; |
49 | using namespace polly; |
50 | |
51 | namespace llvm { |
52 | class Value; |
53 | } |
54 | |
55 | static cl::opt<int> LatencyVectorFma( |
56 | "polly-target-latency-vector-fma" , |
57 | cl::desc("The minimal number of cycles between issuing two " |
58 | "dependent consecutive vector fused multiply-add " |
59 | "instructions." ), |
60 | cl::Hidden, cl::init(Val: 8), cl::cat(PollyCategory)); |
61 | |
62 | static cl::opt<int> ThroughputVectorFma( |
63 | "polly-target-throughput-vector-fma" , |
64 | cl::desc("A throughput of the processor floating-point arithmetic units " |
65 | "expressed in the number of vector fused multiply-add " |
66 | "instructions per clock cycle." ), |
67 | cl::Hidden, cl::init(Val: 1), cl::cat(PollyCategory)); |
68 | |
69 | static cl::opt<int> FirstCacheLevelSize( |
70 | "polly-target-1st-cache-level-size" , |
71 | cl::desc("The size of the first cache level specified in bytes." ), |
72 | cl::Hidden, cl::init(Val: -1), cl::cat(PollyCategory)); |
73 | |
74 | static cl::opt<int> FirstCacheLevelDefaultSize( |
75 | "polly-target-1st-cache-level-default-size" , |
76 | cl::desc("The default size of the first cache level specified in bytes" |
77 | " (if not enough were provided by the TargetTransformInfo)." ), |
78 | cl::Hidden, cl::init(Val: 32768), cl::cat(PollyCategory)); |
79 | |
80 | static cl::opt<int> SecondCacheLevelSize( |
81 | "polly-target-2nd-cache-level-size" , |
82 | cl::desc("The size of the second level specified in bytes." ), cl::Hidden, |
83 | cl::init(Val: -1), cl::cat(PollyCategory)); |
84 | |
85 | static cl::opt<int> SecondCacheLevelDefaultSize( |
86 | "polly-target-2nd-cache-level-default-size" , |
87 | cl::desc("The default size of the second cache level specified in bytes" |
88 | " (if not enough were provided by the TargetTransformInfo)." ), |
89 | cl::Hidden, cl::init(Val: 262144), cl::cat(PollyCategory)); |
90 | |
91 | // This option, along with --polly-target-2nd-cache-level-associativity, |
92 | // --polly-target-1st-cache-level-size, and --polly-target-2st-cache-level-size |
93 | // represent the parameters of the target cache, which do not have typical |
94 | // values that can be used by default. However, to apply the pattern matching |
95 | // optimizations, we use the values of the parameters of Intel Core i7-3820 |
96 | // SandyBridge in case the parameters are not specified or not provided by the |
97 | // TargetTransformInfo. |
98 | static cl::opt<int> FirstCacheLevelAssociativity( |
99 | "polly-target-1st-cache-level-associativity" , |
100 | cl::desc("The associativity of the first cache level." ), cl::Hidden, |
101 | cl::init(Val: -1), cl::cat(PollyCategory)); |
102 | |
103 | static cl::opt<int> FirstCacheLevelDefaultAssociativity( |
104 | "polly-target-1st-cache-level-default-associativity" , |
105 | cl::desc("The default associativity of the first cache level" |
106 | " (if not enough were provided by the TargetTransformInfo)." ), |
107 | cl::Hidden, cl::init(Val: 8), cl::cat(PollyCategory)); |
108 | |
109 | static cl::opt<int> SecondCacheLevelAssociativity( |
110 | "polly-target-2nd-cache-level-associativity" , |
111 | cl::desc("The associativity of the second cache level." ), cl::Hidden, |
112 | cl::init(Val: -1), cl::cat(PollyCategory)); |
113 | |
114 | static cl::opt<int> SecondCacheLevelDefaultAssociativity( |
115 | "polly-target-2nd-cache-level-default-associativity" , |
116 | cl::desc("The default associativity of the second cache level" |
117 | " (if not enough were provided by the TargetTransformInfo)." ), |
118 | cl::Hidden, cl::init(Val: 8), cl::cat(PollyCategory)); |
119 | |
120 | static cl::opt<int> VectorRegisterBitwidth( |
121 | "polly-target-vector-register-bitwidth" , |
122 | cl::desc("The size in bits of a vector register (if not set, this " |
123 | "information is taken from LLVM's target information." ), |
124 | cl::Hidden, cl::init(Val: -1), cl::cat(PollyCategory)); |
125 | |
126 | static cl::opt<int> PollyPatternMatchingNcQuotient( |
127 | "polly-pattern-matching-nc-quotient" , |
128 | cl::desc("Quotient that is obtained by dividing Nc, the parameter of the" |
129 | "macro-kernel, by Nr, the parameter of the micro-kernel" ), |
130 | cl::Hidden, cl::init(Val: 256), cl::cat(PollyCategory)); |
131 | |
132 | static cl::opt<bool> |
133 | PMBasedTCOpts("polly-tc-opt" , |
134 | cl::desc("Perform optimizations of tensor contractions based " |
135 | "on pattern matching" ), |
136 | cl::init(Val: false), cl::ZeroOrMore, cl::cat(PollyCategory)); |
137 | |
138 | static cl::opt<bool> |
139 | PMBasedMMMOpts("polly-matmul-opt" , |
140 | cl::desc("Perform optimizations of matrix multiplications " |
141 | "based on pattern matching" ), |
142 | cl::init(Val: true), cl::ZeroOrMore, cl::cat(PollyCategory)); |
143 | |
144 | static cl::opt<int> OptComputeOut( |
145 | "polly-tc-dependences-computeout" , |
146 | cl::desc("Bound the dependence analysis by a maximal amount of " |
147 | "computational steps (0 means no bound)" ), |
148 | cl::Hidden, cl::init(Val: 500000), cl::ZeroOrMore, cl::cat(PollyCategory)); |
149 | |
150 | namespace { |
/// Parameters of the micro kernel.
///
/// Parameters, which determine sizes of rank-1 (i.e., outer product) update
/// used in the optimized matrix multiplication.
struct MicroKernelParamsTy {
  int Mr; ///< Register-tile size of the first dimension (see createMicroKernel).
  int Nr; ///< Register-tile size of the second dimension (see createMicroKernel).
};
159 | |
/// Parameters of the macro kernel.
///
/// Parameters, which determine sizes of blocks of partitioned matrices
/// used in the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc; ///< Tile size of the third-to-last band dimension (createMacroKernel).
  int Nc; ///< Tile size of the second-to-last band dimension.
  int Kc; ///< Tile size of the innermost band dimension.
};
169 | |
/// Parameters of the matrix multiplication operands.
///
/// Parameters, which describe access relations that represent operands of the
/// matrix multiplication.
struct MatMulInfoTy {
  MemoryAccess *A = nullptr;         ///< Read of the first operand, A[i][k].
  MemoryAccess *B = nullptr;         ///< Read of the second operand, B[k][j].
  MemoryAccess *ReadFromC = nullptr; ///< Read of the result matrix, C[i][j].
  MemoryAccess *WriteToC = nullptr;  ///< Write to the result matrix, C[i][j].
  /// @{
  /// Positions of the schedule input dimensions that act as the i, j, and k
  /// loops of the matrix multiplication; -1 until determined by the pattern
  /// matching (see containsMatrMult and containsOnlyMatMulDep).
  int i = -1;
  int j = -1;
  int k = -1;
  /// @}
};
183 | |
/// Parameters of the tensor contraction operands.
///
/// A general d-dimensional tensor T ∈ R ^ Nu0 x ... x Nud−1 can be defined
/// as the set of scalar elements indexed by the set of indices u0 ... ud,
///
/// T ≡ {Anu0...nud−1 ∈ R | (u0,...,ud−1) ∈ Nu0 x ... x Nud−1}.
///
/// Let A, B, and C be dA, dB, and dC-dimensional tensors, respectively.
/// Let the free and the contracted indices of the tensor A be grouped into
/// two bundles I = i0...ir−1 and P = p0...pt−1, respectively. Similarly,
/// the free and the contracted indices of B are grouped into bundles
/// J = j0..js−1 and P and the free indices of C are grouped into
/// bundles I and J.
///
/// Tensor contraction (TC) of tensors A, B into tensor C can be represented as
/// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)),
/// where ∑ is a summation over all contracted indices of P,
/// α, β ∈ R, Npi is the length of the tensor dimension that corresponds
/// to the index pi, A(shuffle(I, P)), B(shuffle(P, J)), C(shuffle(I, J)) are
/// accesses to tensors A, B, C, respectively,
/// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of
/// the enclosed indices.
///
/// Multiplication of C(shuffle(I,J)) by β can be moved into a different SCoP
/// statement by loop distribution, which is done by the isl scheduler.
/// If β is not equal to one, the optimization of TC of Polly requires
/// such a transformation.
///
/// TCInfoTy contains parameters, which describe access relations that represent
/// operands of the tensor contraction.
struct TCInfoTy {
  /// @{
  /// Memory accesses that represent reading from tensors, which are operands of
  /// the tensor contraction.
  MemoryAccess *A = nullptr;
  MemoryAccess *B = nullptr;
  /// @}

  /// @{
  /// Memory accesses that represent reading from and writing into the tensor,
  /// which contains the result of the tensor contraction.
  MemoryAccess *ReadFromC = nullptr;
  MemoryAccess *WriteToC = nullptr;
  /// @}

  /// @{
  /// Input dimensions of the schedule space, which represent free
  /// indices of tensors.
  SmallDenseSet<int> I;
  SmallDenseSet<int> J;
  /// @}

  /// Input dimension of the schedule space, which represents contracted
  /// indices of tensors.
  SmallDenseSet<int> P;

  /// @{
  /// Sizes of tensor dimensions for corresponding input dimensions of
  /// the schedule space. The size of the tensor dimension can be larger than
  /// the size of the corresponding input dimension of the schedule space.
  /// This does not correspond to a tensor contraction. However, such a pattern
  /// will be optimized by the transformation.
  SmallVector<int> DimensionSizes;
  SmallVector<int> ADimensions;
  SmallVector<int> BDimensions;
  SmallVector<int> CDimensions;
  /// @}

  /// @{
  /// Permutations of indices of I, J, and P, which describe operands of
  /// the tensor contraction and its result.
  SmallVector<int> OrderedI;
  SmallVector<int> OrderedJ;
  SmallVector<int> OrderedP;
  /// @}
};
260 | |
261 | /// Create an isl::union_set, which describes the option of the form |
262 | /// [isolate[] -> unroll[x]]. |
263 | /// |
264 | /// @param Ctx An isl::ctx, which is used to create the isl::union_set. |
265 | static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) { |
266 | isl::space Space = isl::space(Ctx, 0, 0, 1); |
267 | isl::map UnrollIsolatedSetOption = isl::map::universe(space: Space); |
268 | isl::id DimInId = isl::id::alloc(ctx: Ctx, name: "isolate" , user: nullptr); |
269 | isl::id DimOutId = isl::id::alloc(ctx: Ctx, name: "unroll" , user: nullptr); |
270 | UnrollIsolatedSetOption = |
271 | UnrollIsolatedSetOption.set_tuple_id(type: isl::dim::in, id: DimInId); |
272 | UnrollIsolatedSetOption = |
273 | UnrollIsolatedSetOption.set_tuple_id(type: isl::dim::out, id: DimOutId); |
274 | return UnrollIsolatedSetOption.wrap(); |
275 | } |
276 | |
277 | /// Permute the two dimensions of the isl map. |
278 | /// |
279 | /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that |
280 | /// have type @p DimType. |
281 | /// |
282 | /// @param Map The isl map to be modified. |
283 | /// @param DimType The type of the dimensions. |
284 | /// @param DstPos The first dimension. |
285 | /// @param SrcPos The second dimension. |
286 | /// @return The modified map. |
287 | static isl::map permuteDimensions(isl::map Map, isl::dim DimType, |
288 | unsigned DstPos, unsigned SrcPos) { |
289 | assert(DstPos < unsignedFromIslSize(Map.dim(DimType)) && |
290 | SrcPos < unsignedFromIslSize(Map.dim(DimType))); |
291 | if (DstPos == SrcPos) |
292 | return Map; |
293 | isl::id DimId; |
294 | if (Map.has_tuple_id(type: DimType)) |
295 | DimId = Map.get_tuple_id(type: DimType); |
296 | auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in; |
297 | isl::id FreeDimId; |
298 | if (Map.has_tuple_id(type: FreeDim)) |
299 | FreeDimId = Map.get_tuple_id(type: FreeDim); |
300 | auto MaxDim = std::max(a: DstPos, b: SrcPos); |
301 | auto MinDim = std::min(a: DstPos, b: SrcPos); |
302 | Map = Map.move_dims(dst_type: FreeDim, dst_pos: 0, src_type: DimType, src_pos: MaxDim, n: 1); |
303 | Map = Map.move_dims(dst_type: FreeDim, dst_pos: 0, src_type: DimType, src_pos: MinDim, n: 1); |
304 | Map = Map.move_dims(dst_type: DimType, dst_pos: MinDim, src_type: FreeDim, src_pos: 1, n: 1); |
305 | Map = Map.move_dims(dst_type: DimType, dst_pos: MaxDim, src_type: FreeDim, src_pos: 0, n: 1); |
306 | if (!DimId.is_null()) |
307 | Map = Map.set_tuple_id(type: DimType, id: DimId); |
308 | if (!FreeDimId.is_null()) |
309 | Map = Map.set_tuple_id(type: FreeDim, id: FreeDimId); |
310 | return Map; |
311 | } |
312 | |
313 | /// Check the form of the access relation. |
314 | /// |
315 | /// Check that the access relation @p AccMap has the form M[i][j], where i |
316 | /// is a @p FirstPos and j is a @p SecondPos. |
317 | /// |
318 | /// @param AccMap The access relation to be checked. |
319 | /// @param FirstPos The index of the input dimension that is mapped to |
320 | /// the first output dimension. |
321 | /// @param SecondPos The index of the input dimension that is mapped to the |
322 | /// second output dimension. |
323 | /// @return True in case @p AccMap has the expected form and false, |
324 | /// otherwise. |
325 | static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, |
326 | int &SecondPos) { |
327 | isl::space Space = AccMap.get_space(); |
328 | isl::map Universe = isl::map::universe(space: Space); |
329 | |
330 | if (unsignedFromIslSize(Size: Space.dim(type: isl::dim::out)) != 2) |
331 | return false; |
332 | |
333 | // MatMul has the form: |
334 | // for (i = 0; i < N; i++) |
335 | // for (j = 0; j < M; j++) |
336 | // for (k = 0; k < P; k++) |
337 | // C[i, j] += A[i, k] * B[k, j] |
338 | // |
339 | // Permutation of three outer loops: 3! = 6 possibilities. |
340 | int FirstDims[] = {0, 0, 1, 1, 2, 2}; |
341 | int SecondDims[] = {1, 2, 2, 0, 0, 1}; |
342 | for (int i = 0; i < 6; i += 1) { |
343 | auto PossibleMatMul = |
344 | Universe.equate(type1: isl::dim::in, pos1: FirstDims[i], type2: isl::dim::out, pos2: 0) |
345 | .equate(type1: isl::dim::in, pos1: SecondDims[i], type2: isl::dim::out, pos2: 1); |
346 | |
347 | AccMap = AccMap.intersect_domain(set: Domain); |
348 | PossibleMatMul = PossibleMatMul.intersect_domain(set: Domain); |
349 | |
350 | // If AccMap spans entire domain (Non-partial write), |
351 | // compute FirstPos and SecondPos. |
352 | // If AccMap != PossibleMatMul here (the two maps have been gisted at |
353 | // this point), it means that the writes are not complete, or in other |
354 | // words, it is a Partial write and Partial writes must be rejected. |
355 | if (AccMap.is_equal(map2: PossibleMatMul)) { |
356 | if (FirstPos != -1 && FirstPos != FirstDims[i]) |
357 | continue; |
358 | FirstPos = FirstDims[i]; |
359 | if (SecondPos != -1 && SecondPos != SecondDims[i]) |
360 | continue; |
361 | SecondPos = SecondDims[i]; |
362 | return true; |
363 | } |
364 | } |
365 | |
366 | return false; |
367 | } |
368 | |
369 | /// Does the memory access represent a non-scalar operand of the matrix |
370 | /// multiplication. |
371 | /// |
372 | /// Check that the memory access @p MemAccess is the read access to a non-scalar |
373 | /// operand of the matrix multiplication or its result. |
374 | /// |
375 | /// @param MemAccess The memory access to be checked. |
376 | /// @param MMI Parameters of the matrix multiplication operands. |
377 | /// @return True in case the memory access represents the read access |
378 | /// to a non-scalar operand of the matrix multiplication and |
379 | /// false, otherwise. |
380 | static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, |
381 | MatMulInfoTy &MMI) { |
382 | if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) |
383 | return false; |
384 | auto AccMap = MemAccess->getLatestAccessRelation(); |
385 | isl::set StmtDomain = MemAccess->getStatement()->getDomain(); |
386 | if (isMatMulOperandAcc(Domain: StmtDomain, AccMap, FirstPos&: MMI.i, SecondPos&: MMI.j) && !MMI.ReadFromC) { |
387 | MMI.ReadFromC = MemAccess; |
388 | return true; |
389 | } |
390 | if (isMatMulOperandAcc(Domain: StmtDomain, AccMap, FirstPos&: MMI.i, SecondPos&: MMI.k) && !MMI.A) { |
391 | MMI.A = MemAccess; |
392 | return true; |
393 | } |
394 | if (isMatMulOperandAcc(Domain: StmtDomain, AccMap, FirstPos&: MMI.k, SecondPos&: MMI.j) && !MMI.B) { |
395 | MMI.B = MemAccess; |
396 | return true; |
397 | } |
398 | return false; |
399 | } |
400 | |
/// Check accesses to operands of the matrix multiplication.
///
/// Check that accesses of the SCoP statement, which corresponds to
/// the partial schedule @p PartialSchedule, are scalar in terms of loops
/// containing the matrix multiplication, in case they do not represent
/// accesses to the non-scalar operands of the matrix multiplication or
/// its result.
///
/// @param PartialSchedule The partial schedule of the SCoP statement.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return True in case the corresponding SCoP statement
///         represents matrix multiplication and false,
///         otherwise.
static bool containsOnlyMatrMultAcc(isl::map PartialSchedule,
                                    MatMulInfoTy &MMI) {
  // The ScopStmt is attached to the schedule's input tuple id as user data.
  auto InputDimId = PartialSchedule.get_tuple_id(type: isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user());
  unsigned OutDimNum = unsignedFromIslSize(Size: PartialSchedule.range_tuple_dim());
  assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions." );
  // Build three schedules that make each of the i, j, and k dimensions
  // innermost in turn; other accesses must have stride 0 with respect to
  // every one of them, i.e. be scalar in terms of the matmul loops.
  auto MapI =
      permuteDimensions(Map: PartialSchedule, DimType: isl::dim::out, DstPos: MMI.i, SrcPos: OutDimNum - 1);
  auto MapJ =
      permuteDimensions(Map: PartialSchedule, DimType: isl::dim::out, DstPos: MMI.j, SrcPos: OutDimNum - 1);
  auto MapK =
      permuteDimensions(Map: PartialSchedule, DimType: isl::dim::out, DstPos: MMI.k, SrcPos: OutDimNum - 1);

  // Visit all accesses except the last one. NOTE(review): the last access is
  // presumably the write to C already validated by the caller
  // (containsMatrMult) — confirm; accesses equal to MMI.WriteToC are also
  // skipped explicitly below.
  auto Accesses = getAccessesInOrder(Stmt&: *Stmt);
  for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) {
    auto *MemAccessPtr = *MemA;
    // Reject any array access that is neither a recognized matmul operand
    // nor scalar with respect to the i, j, and k loops.
    if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC &&
        !isMatMulNonScalarReadAccess(MemAccess: MemAccessPtr, MMI) &&
        !(MemAccessPtr->isStrideZero(Schedule: MapI) &&
          MemAccessPtr->isStrideZero(Schedule: MapJ) && MemAccessPtr->isStrideZero(Schedule: MapK)))
      return false;
  }
  return true;
}
440 | |
/// Check for dependencies corresponding to the matrix multiplication.
///
/// Check that there is only true dependence of the form
/// S(..., k, ...) -> S(..., k + 1, …), where S is the SCoP statement
/// represented by @p Schedule and k is @p Pos. Such a dependence corresponds
/// to the dependency produced by the matrix multiplication.
///
/// @param Schedule The schedule of the SCoP statement.
/// @param D The SCoP dependencies.
/// @param Pos The parameter to describe an acceptable true dependence.
///            In case it has a negative value, try to determine its
///            acceptable value.
/// @return True in case dependencies correspond to the matrix multiplication
///         and false, otherwise.
static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
                                  int &Pos) {
  // Consider RAW dependences together with reduction dependences, if any.
  isl::union_map Dep = D->getDependences(Kinds: Dependences::TYPE_RAW);
  isl::union_map Red = D->getDependences(Kinds: Dependences::TYPE_RED);
  if (!Red.is_null())
    Dep = Dep.unite(umap2: Red);
  // Restrict attention to self-dependences of this statement and compute
  // the dependence distance vectors.
  auto DomainSpace = Schedule.get_space().domain();
  auto Space = DomainSpace.map_from_domain_and_range(range: DomainSpace);
  auto Deltas = Dep.extract_map(space: Space).deltas();
  int DeltasDimNum = unsignedFromIslSize(Size: Deltas.dim(type: isl::dim::set));
  for (int i = 0; i < DeltasDimNum; i++) {
    // plain_get_val_if_fixed yields NaN when the distance is not a constant.
    auto Val = Deltas.plain_get_val_if_fixed(type: isl::dim::set, pos: i);
    // If Pos is undetermined, the first dimension with distance one becomes
    // the candidate k dimension.
    Pos = Pos < 0 && Val.is_one() ? i : Pos;
    // Every other dimension must carry a constant distance of zero.
    if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one())))
      return false;
  }
  // Reject statements without any dependence distance or without a distance-
  // one dimension: a matmul carries exactly one such loop (the k loop).
  if (DeltasDimNum == 0 || Pos < 0)
    return false;
  return true;
}
475 | |
/// Check if the SCoP statement could probably be optimized with analytical
/// modeling.
///
/// containsMatrMult tries to determine whether the following conditions
/// are true:
/// 1. The last memory access modeling an array, MA1, represents writing to
///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
///    under consideration.
/// 2. There is only one loop-carried true dependency, and it has the
///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no
///    loop-carried or anti dependencies.
/// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
///    and all memory accesses of the SCoP that are different from MA1, MA2,
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
///    of loops i1, i2 and i3.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///        to check.
/// @D     The SCoP dependencies.
/// @MMI   Parameters of the matrix multiplication operands.
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
                             MatMulInfoTy &MMI) {
  // The ScopStmt is attached to the schedule's input tuple id as user data.
  auto InputDimsId = PartialSchedule.get_tuple_id(type: isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  // A matmul statement needs more than one memory access.
  if (Stmt->size() <= 1)
    return false;

  // Condition 1: scan backwards for the last array access; it must be a
  // write of the form C[i][j] (this also determines MMI.i and MMI.j).
  auto Accesses = getAccessesInOrder(Stmt&: *Stmt);
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
    auto *MemAccessPtr = *MemA;
    if (!MemAccessPtr->isLatestArrayKind())
      continue;
    if (!MemAccessPtr->isWrite())
      return false;
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!isMatMulOperandAcc(Domain: Stmt->getDomain(), AccMap, FirstPos&: MMI.i, SecondPos&: MMI.j))
      return false;
    MMI.WriteToC = MemAccessPtr;
    break;
  }

  // Condition 2: the only loop-carried dependence is a distance-one
  // dependence in the k dimension (this determines MMI.k).
  if (!containsOnlyMatMulDep(Schedule: PartialSchedule, D, Pos&: MMI.k))
    return false;

  // Condition 3: the remaining accesses read A, B, and C, or are scalar in
  // terms of the matmul loops.
  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
    return false;

  // All three operands must have been recognized.
  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
    return false;
  return true;
}
530 | |
531 | /// Permute two dimensions of the band node. |
532 | /// |
533 | /// Permute FirstDim and SecondDim dimensions of the Node. |
534 | /// |
535 | /// @param Node The band node to be modified. |
536 | /// @param FirstDim The first dimension to be permuted. |
537 | /// @param SecondDim The second dimension to be permuted. |
538 | static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node, |
539 | unsigned FirstDim, |
540 | unsigned SecondDim) { |
541 | assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band && |
542 | (unsigned)isl_schedule_node_band_n_member(Node.get()) > |
543 | std::max(FirstDim, SecondDim)); |
544 | auto PartialSchedule = |
545 | isl::manage(ptr: isl_schedule_node_band_get_partial_schedule(node: Node.get())); |
546 | auto PartialScheduleFirstDim = PartialSchedule.at(pos: FirstDim); |
547 | auto PartialScheduleSecondDim = PartialSchedule.at(pos: SecondDim); |
548 | PartialSchedule = |
549 | PartialSchedule.set_union_pw_aff(pos: SecondDim, el: PartialScheduleFirstDim); |
550 | PartialSchedule = |
551 | PartialSchedule.set_union_pw_aff(pos: FirstDim, el: PartialScheduleSecondDim); |
552 | Node = isl::manage(ptr: isl_schedule_node_delete(node: Node.release())); |
553 | return Node.insert_partial_schedule(schedule: PartialSchedule); |
554 | } |
555 | |
/// Create the BLIS micro-kernel.
///
/// Apply register tiling to the band node with the Mr and Nr parameters of
/// the micro kernel as tile sizes, walk two levels up, swap the first two
/// band dimensions, and return the node two children down.
///
/// @param Node              The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro kernel to be used as
///                          register tile sizes.
/// @see MicroKernelParamsTy
static isl::schedule_node
createMicroKernel(isl::schedule_node Node,
                  MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, TileSizes: {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             DefaultTileSize: 1);
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, FirstDim: 0, SecondDim: 1).child(pos: 0).child(pos: 0);
}
564 | |
565 | /// Create the BLIS macro-kernel. |
566 | /// |
567 | /// We create the BLIS macro-kernel by applying a combination of tiling |
568 | /// of dimensions of the band node and interchanging of two innermost |
569 | /// modified dimensions. The values of MacroKernelParams's fields are used |
570 | /// as tile sizes. |
571 | /// |
572 | /// @param Node The schedule node to be modified. |
573 | /// @param MacroKernelParams Parameters of the macro kernel |
574 | /// to be used as tile sizes. |
575 | static isl::schedule_node |
576 | createMacroKernel(isl::schedule_node Node, |
577 | MacroKernelParamsTy MacroKernelParams) { |
578 | assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band); |
579 | if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 && |
580 | MacroKernelParams.Kc == 1) |
581 | return Node; |
582 | int DimOutNum = isl_schedule_node_band_n_member(node: Node.get()); |
583 | std::vector<int> TileSizes(DimOutNum, 1); |
584 | TileSizes[DimOutNum - 3] = MacroKernelParams.Mc; |
585 | TileSizes[DimOutNum - 2] = MacroKernelParams.Nc; |
586 | TileSizes[DimOutNum - 1] = MacroKernelParams.Kc; |
587 | Node = tileNode(Node, Identifier: "1st level tiling" , TileSizes, DefaultTileSize: 1); |
588 | Node = Node.parent().parent(); |
589 | Node = permuteBandNodeDimensions(Node, FirstDim: DimOutNum - 2, SecondDim: DimOutNum - 1); |
590 | Node = permuteBandNodeDimensions(Node, FirstDim: DimOutNum - 3, SecondDim: DimOutNum - 1); |
591 | |
592 | return Node.child(pos: 0).child(pos: 0); |
593 | } |
594 | |
595 | /// Get the size of the widest type of the matrix multiplication operands |
596 | /// in bytes, including alignment padding. |
597 | /// |
598 | /// @param MMI Parameters of the matrix multiplication operands. |
599 | /// @return The size of the widest type of the matrix multiplication operands |
600 | /// in bytes, including alignment padding. |
601 | static uint64_t getMatMulAlignTypeSize(const MatMulInfoTy &MMI) { |
602 | auto *S = MMI.A->getStatement()->getParent(); |
603 | auto &DL = S->getFunction().getParent()->getDataLayout(); |
604 | auto ElementSizeA = DL.getTypeAllocSize(Ty: MMI.A->getElementType()); |
605 | auto ElementSizeB = DL.getTypeAllocSize(Ty: MMI.B->getElementType()); |
606 | auto ElementSizeC = DL.getTypeAllocSize(Ty: MMI.WriteToC->getElementType()); |
607 | return std::max(l: {ElementSizeA, ElementSizeB, ElementSizeC}); |
608 | } |
609 | |
610 | /// Get the size of the widest type of the matrix multiplication operands |
611 | /// in bits. |
612 | /// |
613 | /// @param MMI Parameters of the matrix multiplication operands. |
614 | /// @return The size of the widest type of the matrix multiplication operands |
615 | /// in bits. |
616 | static uint64_t getMatMulTypeSize(const MatMulInfoTy &MMI) { |
617 | auto *S = MMI.A->getStatement()->getParent(); |
618 | auto &DL = S->getFunction().getParent()->getDataLayout(); |
619 | auto ElementSizeA = DL.getTypeSizeInBits(Ty: MMI.A->getElementType()); |
620 | auto ElementSizeB = DL.getTypeSizeInBits(Ty: MMI.B->getElementType()); |
621 | auto ElementSizeC = DL.getTypeSizeInBits(Ty: MMI.WriteToC->getElementType()); |
622 | return std::max(l: {ElementSizeA, ElementSizeB, ElementSizeC}); |
623 | } |
624 | |
625 | /// Get parameters of the BLIS micro kernel. |
626 | /// |
627 | /// We choose the Mr and Nr parameters of the micro kernel to be large enough |
628 | /// such that no stalls caused by the combination of latencies and dependencies |
629 | /// are introduced during the updates of the resulting matrix of the matrix |
630 | /// multiplication. However, they should also be as small as possible to |
631 | /// release more registers for entries of multiplied matrices. |
632 | /// |
633 | /// @param TTI Target Transform Info. |
634 | /// @param MMI Parameters of the matrix multiplication operands. |
635 | /// @return The structure of type MicroKernelParamsTy. |
636 | /// @see MicroKernelParamsTy |
637 | static MicroKernelParamsTy getMicroKernelParams(const TargetTransformInfo *TTI, |
638 | const MatMulInfoTy &MMI) { |
639 | assert(TTI && "The target transform info should be provided." ); |
640 | |
641 | // Nvec - Number of double-precision floating-point numbers that can be hold |
642 | // by a vector register. Use 2 by default. |
643 | long RegisterBitwidth = VectorRegisterBitwidth; |
644 | |
645 | if (RegisterBitwidth == -1) |
646 | RegisterBitwidth = |
647 | TTI->getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector); |
648 | auto ElementSize = getMatMulTypeSize(MMI); |
649 | assert(ElementSize > 0 && "The element size of the matrix multiplication " |
650 | "operands should be greater than zero." ); |
651 | auto Nvec = RegisterBitwidth / ElementSize; |
652 | if (Nvec == 0) |
653 | Nvec = 2; |
654 | int Nr = ceil(x: sqrt(x: (double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) / |
655 | Nvec) * |
656 | Nvec; |
657 | int Mr = ceil(x: (double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr)); |
658 | return {.Mr: Mr, .Nr: Nr}; |
659 | } |
660 | |
661 | /// Determine parameters of the target cache. |
662 | /// |
663 | /// @param TTI Target Transform Info. |
664 | static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) { |
665 | auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D; |
666 | auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D; |
667 | if (FirstCacheLevelSize == -1) { |
668 | if (TTI->getCacheSize(Level: L1DCache)) |
669 | FirstCacheLevelSize = TTI->getCacheSize(Level: L1DCache).value(); |
670 | else |
671 | FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize); |
672 | } |
673 | if (SecondCacheLevelSize == -1) { |
674 | if (TTI->getCacheSize(Level: L2DCache)) |
675 | SecondCacheLevelSize = TTI->getCacheSize(Level: L2DCache).value(); |
676 | else |
677 | SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize); |
678 | } |
679 | if (FirstCacheLevelAssociativity == -1) { |
680 | if (TTI->getCacheAssociativity(Level: L1DCache)) |
681 | FirstCacheLevelAssociativity = |
682 | TTI->getCacheAssociativity(Level: L1DCache).value(); |
683 | else |
684 | FirstCacheLevelAssociativity = |
685 | static_cast<int>(FirstCacheLevelDefaultAssociativity); |
686 | } |
687 | if (SecondCacheLevelAssociativity == -1) { |
688 | if (TTI->getCacheAssociativity(Level: L2DCache)) |
689 | SecondCacheLevelAssociativity = |
690 | TTI->getCacheAssociativity(Level: L2DCache).value(); |
691 | else |
692 | SecondCacheLevelAssociativity = |
693 | static_cast<int>(SecondCacheLevelDefaultAssociativity); |
694 | } |
695 | } |
696 | |
697 | /// Get parameters of the BLIS macro kernel. |
698 | /// |
699 | /// During the computation of matrix multiplication, blocks of partitioned |
700 | /// matrices are mapped to different layers of the memory hierarchy. |
701 | /// To optimize data reuse, blocks should be ideally kept in cache between |
702 | /// iterations. Since parameters of the macro kernel determine sizes of these |
703 | /// blocks, there are upper and lower bounds on these parameters. |
704 | /// |
705 | /// @param TTI Target Transform Info. |
706 | /// @param MicroKernelParams Parameters of the micro-kernel |
707 | /// to be taken into account. |
708 | /// @param MMI Parameters of the matrix multiplication operands. |
709 | /// @return The structure of type MacroKernelParamsTy. |
710 | /// @see MacroKernelParamsTy |
711 | /// @see MicroKernelParamsTy |
712 | static MacroKernelParamsTy |
713 | getMacroKernelParams(const llvm::TargetTransformInfo *TTI, |
714 | const MicroKernelParamsTy &MicroKernelParams, |
715 | const MatMulInfoTy &MMI) { |
716 | getTargetCacheParameters(TTI); |
717 | // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf, |
718 | // it requires information about the first two levels of a cache to determine |
719 | // all the parameters of a macro-kernel. It also checks that an associativity |
720 | // degree of a cache level is greater than two. Otherwise, another algorithm |
721 | // for determination of the parameters should be used. |
722 | if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 && |
723 | FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 && |
724 | FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2)) |
725 | return {.Mc: 1, .Nc: 1, .Kc: 1}; |
726 | // The quotient should be greater than zero. |
727 | if (PollyPatternMatchingNcQuotient <= 0) |
728 | return {.Mc: 1, .Nc: 1, .Kc: 1}; |
729 | int Car = floor( |
730 | x: (FirstCacheLevelAssociativity - 1) / |
731 | (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr)); |
732 | |
733 | // Car can be computed to be zero since it is floor to int. |
734 | // On Mac OS, division by 0 does not raise a signal. This causes negative |
735 | // tile sizes to be computed. Prevent division by Cac==0 by early returning |
736 | // if this happens. |
737 | if (Car == 0) |
738 | return {.Mc: 1, .Nc: 1, .Kc: 1}; |
739 | |
740 | auto ElementSize = getMatMulAlignTypeSize(MMI); |
741 | assert(ElementSize > 0 && "The element size of the matrix multiplication " |
742 | "operands should be greater than zero." ); |
743 | int Kc = (Car * FirstCacheLevelSize) / |
744 | (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize); |
745 | double Cac = |
746 | static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) / |
747 | SecondCacheLevelSize; |
748 | int Mc = floor(x: (SecondCacheLevelAssociativity - 2) / Cac); |
749 | int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr; |
750 | |
751 | assert(Mc > 0 && Nc > 0 && Kc > 0 && |
752 | "Matrix block sizes should be greater than zero" ); |
753 | return {.Mc: Mc, .Nc: Nc, .Kc: Kc}; |
754 | } |
755 | |
756 | /// Create an access relation that is specific to |
757 | /// the matrix multiplication pattern. |
758 | /// |
759 | /// Create an access relation of the following form: |
760 | /// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ] |
761 | /// where I is @p FirstDim, J is @p SecondDim. |
762 | /// |
763 | /// It can be used, for example, to create relations that helps to consequently |
764 | /// access elements of operands of a matrix multiplication after creation of |
765 | /// the BLIS micro and macro kernels. |
766 | /// |
767 | /// @see ScheduleTreeOptimizer::createMicroKernel |
768 | /// @see ScheduleTreeOptimizer::createMacroKernel |
769 | /// |
770 | /// Subsequently, the described access relation is applied to the range of |
771 | /// @p MapOldIndVar, that is used to map original induction variables to |
772 | /// the ones, which are produced by schedule transformations. It helps to |
773 | /// define relations using a new space and, at the same time, keep them |
774 | /// in the original one. |
775 | /// |
776 | /// @param MapOldIndVar The relation, which maps original induction variables |
777 | /// to the ones, which are produced by schedule |
778 | /// transformations. |
779 | /// @param FirstDim, SecondDim The input dimensions that are used to define |
780 | /// the specified access relation. |
781 | /// @return The specified access relation. |
782 | static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim, |
783 | unsigned SecondDim) { |
784 | auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3); |
785 | auto AccessRel = isl::map::universe(space: AccessRelSpace); |
786 | AccessRel = AccessRel.equate(type1: isl::dim::in, pos1: FirstDim, type2: isl::dim::out, pos2: 0); |
787 | AccessRel = AccessRel.equate(type1: isl::dim::in, pos1: 5, type2: isl::dim::out, pos2: 1); |
788 | AccessRel = AccessRel.equate(type1: isl::dim::in, pos1: SecondDim, type2: isl::dim::out, pos2: 2); |
789 | return MapOldIndVar.apply_range(map2: AccessRel); |
790 | } |
791 | |
792 | static isl::schedule_node createExtensionNode(isl::schedule_node Node, |
793 | isl::map ExtensionMap) { |
794 | auto Extension = isl::union_map(ExtensionMap); |
795 | auto NewNode = isl::schedule_node::from_extension(extension: Extension); |
796 | return Node.graft_before(graft: NewNode); |
797 | } |
798 | |
799 | static isl::schedule_node optimizePackedB(isl::schedule_node Node, |
800 | ScopStmt *Stmt, isl::map MapOldIndVar, |
801 | MicroKernelParamsTy MicroParams, |
802 | MacroKernelParamsTy MacroParams, |
803 | MatMulInfoTy &MMI) { |
804 | Scop *S = Stmt->getParent(); |
805 | isl::set Domain = Stmt->getDomain(); |
806 | |
807 | // Create packed array. |
808 | unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr; |
809 | unsigned SecondDimSize = MacroParams.Kc; |
810 | unsigned ThirdDimSize = MicroParams.Nr; |
811 | ScopArrayInfo *PackedB = |
812 | S->createScopArrayInfo(ElementType: MMI.B->getElementType(), BaseName: "Packed_B" , |
813 | Sizes: {FirstDimSize, SecondDimSize, ThirdDimSize}); |
814 | |
815 | // Compute the access relation for copying from B to PackedB. |
816 | isl::map AccRelB = MMI.B->getLatestAccessRelation(); |
817 | isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, FirstDim: 3, SecondDim: 7); |
818 | AccRelPackedB = |
819 | AccRelPackedB.set_tuple_id(type: isl::dim::out, id: PackedB->getBasePtrId()); |
820 | |
821 | // Create the copy statement and redirect access. |
822 | ScopStmt *CopyStmt = S->addScopStmt(SourceRel: AccRelB, TargetRel: AccRelPackedB, Domain); |
823 | MMI.B->setNewAccessRelation(AccRelPackedB); |
824 | |
825 | unsigned Dim = unsignedFromIslSize(Size: MapOldIndVar.range_tuple_dim()); |
826 | assert(Dim >= 2); |
827 | // Insert into the schedule tree. |
828 | isl::map ExtMap = MapOldIndVar.project_out(type: isl::dim::out, first: 2, n: Dim - 2); |
829 | ExtMap = ExtMap.reverse(); |
830 | ExtMap = ExtMap.fix_si(type: isl::dim::out, pos: MMI.i, value: 0); |
831 | ExtMap = ExtMap.intersect_range(set: Domain); |
832 | ExtMap = ExtMap.set_tuple_id(type: isl::dim::out, id: CopyStmt->getDomainId()); |
833 | return createExtensionNode(Node, ExtensionMap: ExtMap); |
834 | } |
835 | |
836 | static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *, |
837 | isl::map MapOldIndVar, |
838 | MicroKernelParamsTy MicroParams, |
839 | MacroKernelParamsTy MacroParams, |
840 | MatMulInfoTy &MMI) { |
841 | isl::id InputDimsId = MapOldIndVar.get_tuple_id(type: isl::dim::in); |
842 | ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); |
843 | isl::set Domain = Stmt->getDomain(); |
844 | isl::id DomainId = Domain.get_tuple_id(); |
845 | |
846 | // Create the packed array. |
847 | unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr; |
848 | unsigned SecondDimSize = MacroParams.Kc; |
849 | unsigned ThirdDimSize = MicroParams.Mr; |
850 | ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo( |
851 | ElementType: MMI.A->getElementType(), BaseName: "Packed_A" , |
852 | Sizes: {FirstDimSize, SecondDimSize, ThirdDimSize}); |
853 | |
854 | // Compute the access relation for copying from A to PackedA. |
855 | isl::map AccRelA = MMI.A->getLatestAccessRelation(); |
856 | isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, FirstDim: 4, SecondDim: 6); |
857 | AccRelPackedA = |
858 | AccRelPackedA.set_tuple_id(type: isl::dim::out, id: PackedA->getBasePtrId()); |
859 | // { MemrefA[] -> PackedA[] } |
860 | isl::map PackedATranslator = AccRelPackedA.apply_domain(map2: AccRelA); |
861 | |
862 | // Compute the domain for the copy statement. |
863 | // Construct the copy statement domain out of the 3 outermost scatter |
864 | // dimensions (to match the 3 band nodes surrounding the extension node) and |
865 | // the array elements to copy (one statement instance per array element). |
866 | // { Scatter[] } |
867 | isl::set ScatterDomain = MapOldIndVar.intersect_domain(set: Domain).range(); |
868 | // { Scatter[] -> OutermostScatter[] } |
869 | isl::map OuterDomainMap = |
870 | makeIdentityMap(Set: ScatterDomain, RestrictDomain: true).project_out(type: isl::dim::out, first: 3, n: 6); |
871 | // { Scatter[] -> MemrefA[] } |
872 | isl::map CopyFrom = MapOldIndVar.reverse().apply_range(map2: AccRelA); |
873 | // { Scatter[] -> CopyStmt[] } |
874 | isl::map DomainTranslator = OuterDomainMap.range_product(map2: CopyFrom); |
875 | // { CopyStmt[] } |
876 | isl::set CopyDomain = DomainTranslator.range(); |
877 | |
878 | // Translate the access relations to the new domain. |
879 | // { CopyStmt[] -> MemrefA[] } |
880 | CopyFrom = CopyFrom.apply_domain(map2: DomainTranslator); |
881 | // { CopyStmt[] -> PackedA[] } |
882 | isl::map CopyTo = CopyFrom.apply_range(map2: PackedATranslator); |
883 | |
884 | // Create the copy statement and redirect access. |
885 | ScopStmt *CopyStmt = |
886 | Stmt->getParent()->addScopStmt(SourceRel: CopyFrom, TargetRel: CopyTo, Domain: CopyDomain); |
887 | MMI.A->setNewAccessRelation(AccRelPackedA); |
888 | |
889 | // Insert into the schedule tree. |
890 | // { Scatter[] -> CopyStmt[] } |
891 | isl::map ExtScatterCopy = makeIdentityMap(Set: CopyStmt->getDomain(), RestrictDomain: true); |
892 | ExtScatterCopy = ExtScatterCopy.project_out(type: isl::dim::in, first: 3, n: 2); |
893 | return createExtensionNode(Node, ExtensionMap: ExtScatterCopy); |
894 | } |
895 | |
896 | /// Apply the packing transformation. |
897 | /// |
898 | /// The packing transformation can be described as a data-layout |
899 | /// transformation that requires to introduce a new array, copy data |
900 | /// to the array, and change memory access locations to reference the array. |
901 | /// It can be used to ensure that elements of the new array are read in-stride |
902 | /// access, aligned to cache lines boundaries, and preloaded into certain cache |
903 | /// levels. |
904 | /// |
905 | /// As an example let us consider the packing of the array A that would help |
906 | /// to read its elements with in-stride access. An access to the array A |
907 | /// is represented by an access relation that has the form |
908 | /// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has |
909 | /// the form S[i,j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr), |
910 | /// k mod Kc, j mod Nr, i mod Mr]. |
911 | /// |
912 | /// To ensure that elements of the array A are read in-stride access, we add |
913 | /// a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using |
914 | /// Scop::createScopArrayInfo, change the access relation |
915 | /// S[i, j, k] -> A[i, k] to |
916 | /// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using |
917 | /// MemoryAccess::setNewAccessRelation, and copy the data to the array, using |
918 | /// the copy statement created by Scop::addScopStmt. |
919 | /// |
920 | /// @param Node The schedule node to be optimized. |
921 | /// @param MapOldIndVar The relation, which maps original induction variables |
922 | /// to the ones, which are produced by schedule |
923 | /// transformations. |
924 | /// @param MicroParams, MacroParams Parameters of the BLIS kernel |
925 | /// to be taken into account. |
926 | /// @param MMI Parameters of the matrix multiplication operands. |
927 | /// @return The optimized schedule node. |
928 | static isl::schedule_node |
929 | optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar, |
930 | MicroKernelParamsTy MicroParams, |
931 | MacroKernelParamsTy MacroParams, |
932 | MatMulInfoTy &MMI) { |
933 | isl::id InputDimsId = MapOldIndVar.get_tuple_id(type: isl::dim::in); |
934 | ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); |
935 | |
936 | Node = Node.parent().parent().parent().parent().parent().parent(); |
937 | Node = isl::manage(ptr: isl_schedule_node_band_split(node: Node.release(), pos: 2)); |
938 | |
939 | Node = Node.child(pos: 0); |
940 | Node = |
941 | optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI); |
942 | |
943 | Node = Node.child(pos: 0); |
944 | Node = |
945 | optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI); |
946 | |
947 | return Node.child(pos: 0).child(pos: 0).child(pos: 0).child(pos: 0).child(pos: 0); |
948 | } |
949 | |
950 | /// Get a relation mapping induction variables produced by schedule |
951 | /// transformations to the original ones. |
952 | /// |
953 | /// @param Node The schedule node produced as the result of creation |
954 | /// of the BLIS kernels. |
955 | /// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel |
956 | /// to be taken into account. |
957 | /// @return The relation mapping original induction variables to the ones |
958 | /// produced by schedule transformation. |
959 | /// @see ScheduleTreeOptimizer::createMicroKernel |
960 | /// @see ScheduleTreeOptimizer::createMacroKernel |
961 | /// @see getMacroKernelParams |
962 | static isl::map |
963 | getInductionVariablesSubstitution(isl::schedule_node Node, |
964 | MicroKernelParamsTy MicroKernelParams, |
965 | MacroKernelParamsTy MacroKernelParams) { |
966 | auto Child = Node.child(pos: 0); |
967 | auto UnMapOldIndVar = Child.get_prefix_schedule_union_map(); |
968 | auto MapOldIndVar = isl::map::from_union_map(umap: UnMapOldIndVar); |
969 | unsigned Dim = unsignedFromIslSize(Size: MapOldIndVar.range_tuple_dim()); |
970 | if (Dim > 9u) |
971 | return MapOldIndVar.project_out(type: isl::dim::out, first: 0, n: Dim - 9); |
972 | return MapOldIndVar; |
973 | } |
974 | |
975 | /// Isolate a set of partial tile prefixes and unroll the isolated part. |
976 | /// |
977 | /// The set should ensure that it contains only partial tile prefixes that have |
978 | /// exactly Mr x Nr iterations of the two innermost loops produced by |
979 | /// the optimization of the matrix multiplication. Mr and Nr are parameters of |
980 | /// the micro-kernel. |
981 | /// |
982 | /// In case of parametric bounds, this helps to auto-vectorize the unrolled |
983 | /// innermost loops, using the SLP vectorizer. |
984 | /// |
985 | /// @param Node The schedule node to be modified. |
986 | /// @param MicroKernelParams Parameters of the micro-kernel |
987 | /// to be taken into account. |
988 | /// @return The modified isl_schedule_node. |
989 | static isl::schedule_node |
990 | isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node, |
991 | MicroKernelParamsTy MicroKernelParams) { |
992 | isl::schedule_node Child = Node.child(pos: 0); |
993 | isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation(); |
994 | isl::set Prefix = isl::map::from_union_map(umap: UnMapOldIndVar).range(); |
995 | unsigned Dims = unsignedFromIslSize(Size: Prefix.tuple_dim()); |
996 | assert(Dims >= 1); |
997 | Prefix = Prefix.project_out(type: isl::dim::set, first: Dims - 1, n: 1); |
998 | Prefix = getPartialTilePrefixes(ScheduleRange: Prefix, VectorWidth: MicroKernelParams.Nr); |
999 | Prefix = getPartialTilePrefixes(ScheduleRange: Prefix, VectorWidth: MicroKernelParams.Mr); |
1000 | |
1001 | isl::union_set IsolateOption = |
1002 | getIsolateOptions(IsolateDomain: Prefix.add_dims(type: isl::dim::set, n: 3), OutDimsNum: 3); |
1003 | isl::ctx Ctx = Node.ctx(); |
1004 | auto Options = IsolateOption.unite(uset2: getDimOptions(Ctx, Option: "unroll" )); |
1005 | Options = Options.unite(uset2: getUnrollIsolatedSetOptions(Ctx)); |
1006 | Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options); |
1007 | Node = Node.parent().parent().parent(); |
1008 | IsolateOption = getIsolateOptions(IsolateDomain: Prefix, OutDimsNum: 3); |
1009 | Options = IsolateOption.unite(uset2: getDimOptions(Ctx, Option: "separate" )); |
1010 | Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options); |
1011 | Node = Node.child(pos: 0).child(pos: 0).child(pos: 0); |
1012 | return Node; |
1013 | } |
1014 | |
1015 | /// Insert "Loop Vectorizer Disabled" mark node. |
1016 | /// |
1017 | /// @param Node The child of the mark node to be inserted. |
1018 | /// @return The modified isl_schedule_node. |
1019 | static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) { |
1020 | auto Id = isl::id::alloc(ctx: Node.ctx(), name: "Loop Vectorizer Disabled" , user: nullptr); |
1021 | return Node.insert_mark(mark: Id).child(pos: 0); |
1022 | } |
1023 | |
1024 | /// Restore the initial ordering of dimensions of the band node |
1025 | /// |
1026 | /// In case the band node represents all the dimensions of the iteration |
1027 | /// domain, recreate the band node to restore the initial ordering of the |
1028 | /// dimensions. |
1029 | /// |
1030 | /// @param Node The band node to be modified. |
1031 | /// @return The modified schedule node. |
1032 | static isl::schedule_node |
1033 | getBandNodeWithOriginDimOrder(isl::schedule_node Node) { |
1034 | assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band); |
1035 | if (isl_schedule_node_get_type(node: Node.child(pos: 0).get()) != isl_schedule_node_leaf) |
1036 | return Node; |
1037 | auto Domain = Node.get_universe_domain(); |
1038 | assert(isl_union_set_n_set(Domain.get()) == 1); |
1039 | if (Node.get_schedule_depth().release() != 0 || |
1040 | (unsignedFromIslSize(Size: isl::set(Domain).tuple_dim()) != |
1041 | unsignedFromIslSize(Size: Node.as<isl::schedule_node_band>().n_member()))) |
1042 | return Node; |
1043 | Node = isl::manage(ptr: isl_schedule_node_delete(node: Node.copy())); |
1044 | auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff(); |
1045 | auto PartialScheduleMultiPwAff = |
1046 | isl::multi_union_pw_aff(PartialSchedulePwAff); |
1047 | PartialScheduleMultiPwAff = |
1048 | PartialScheduleMultiPwAff.reset_tuple_id(type: isl::dim::set); |
1049 | return Node.insert_partial_schedule(schedule: PartialScheduleMultiPwAff); |
1050 | } |
1051 | |
1052 | static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node, |
1053 | const TargetTransformInfo *TTI, |
1054 | MatMulInfoTy &MMI) { |
1055 | assert(TTI && "The target transform info should be provided." ); |
1056 | int DimOutNum = isl_schedule_node_band_n_member(node: Node.get()); |
1057 | assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest " |
1058 | "and, consequently, the corresponding scheduling " |
1059 | "functions have at least three dimensions." ); |
1060 | Node = getBandNodeWithOriginDimOrder(Node); |
1061 | Node = permuteBandNodeDimensions(Node, FirstDim: MMI.i, SecondDim: DimOutNum - 3); |
1062 | int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j; |
1063 | int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k; |
1064 | Node = permuteBandNodeDimensions(Node, FirstDim: NewJ, SecondDim: DimOutNum - 2); |
1065 | NewK = NewK == DimOutNum - 2 ? NewJ : NewK; |
1066 | Node = permuteBandNodeDimensions(Node, FirstDim: NewK, SecondDim: DimOutNum - 1); |
1067 | auto MicroKernelParams = getMicroKernelParams(TTI, MMI); |
1068 | auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI); |
1069 | Node = createMacroKernel(Node, MacroKernelParams); |
1070 | Node = createMicroKernel(Node, MicroKernelParams); |
1071 | if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 || |
1072 | MacroKernelParams.Kc == 1) |
1073 | return Node; |
1074 | auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams, |
1075 | MacroKernelParams); |
1076 | if (MapOldIndVar.is_null()) |
1077 | return Node; |
1078 | Node = markLoopVectorizerDisabled(Node: Node.parent()).child(pos: 0); |
1079 | Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams); |
1080 | return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroParams: MicroKernelParams, |
1081 | MacroParams: MacroKernelParams, MMI); |
1082 | } |
1083 | |
1084 | /// Check if this node contains a partial schedule that could |
1085 | /// probably be optimized with analytical modeling. |
1086 | /// |
1087 | /// isMatrMultPattern tries to determine whether the following conditions |
1088 | /// are true: |
1089 | /// 1. the partial schedule contains only one statement. |
1090 | /// 2. there are exactly three input dimensions. |
1091 | /// 3. all memory accesses of the statement will have stride 0 or 1, if we |
1092 | /// interchange loops (switch the variable used in the inner loop to |
1093 | /// the outer loop). |
1094 | /// 4. all memory accesses of the statement except from the last one, are |
1095 | /// read memory access and the last one is write memory access. |
1096 | /// 5. all subscripts of the last memory access of the statement don't |
1097 | /// contain the variable used in the inner loop. |
1098 | /// If this is the case, we could try to use an approach that is similar to |
1099 | /// the one used to get close-to-peak performance of matrix multiplications. |
1100 | /// |
1101 | /// @param Node The node to check. |
1102 | /// @param D The SCoP dependencies. |
1103 | /// @param MMI Parameters of the matrix multiplication operands. |
1104 | static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D, |
1105 | MatMulInfoTy &MMI) { |
1106 | auto PartialSchedule = isl::manage( |
1107 | ptr: isl_schedule_node_band_get_partial_schedule_union_map(node: Node.get())); |
1108 | if (isl_schedule_node_band_n_member(node: Node.get()) < 3 || |
1109 | Node.get_schedule_depth().release() != 0 || |
1110 | isl_union_map_n_map(umap: PartialSchedule.get()) != 1) |
1111 | return false; |
1112 | auto NewPartialSchedule = isl::map::from_union_map(umap: PartialSchedule); |
1113 | if (containsMatrMult(PartialSchedule: NewPartialSchedule, D, MMI)) |
1114 | return true; |
1115 | return false; |
1116 | } |
1117 | |
1118 | /// Get the dimension size. |
1119 | /// |
1120 | /// Return the size of the dimension @p Pos, which is obtained from @p SAI. |
1121 | /// Return -1 in the case of the first dimension of a multi-dimensional array, |
1122 | /// since the ScopArrayInfo class does not carry size information. |
1123 | /// |
1124 | /// @param SAI The information about the array. |
1125 | /// @param Pos The position of the dimension. |
1126 | /// @return The size of the dimension. |
1127 | static int getDimSize(const ScopArrayInfo *SAI, unsigned Pos) { |
1128 | if (Pos == 0) |
1129 | return -1; |
1130 | const llvm::SCEV *SCEVDimSize = SAI->getDimensionSize(Dim: Pos); |
1131 | assert(SCEVDimSize); |
1132 | auto *ConstantDimSize = dyn_cast<const SCEVConstant>(Val: SCEVDimSize); |
1133 | assert(ConstantDimSize); |
1134 | auto *IntDimSize = dyn_cast<ConstantInt>(Val: ConstantDimSize->getValue()); |
1135 | assert(IntDimSize); |
1136 | return IntDimSize->getSExtValue(); |
1137 | } |
1138 | |
1139 | /// Check whether the access relation has the specified form. |
1140 | /// |
1141 | /// Check that the access relation @p AccMap has the form T[I0, …, In], where |
1142 | /// indexes I0, …, In are specified by @p Dimensions. |
1143 | /// |
1144 | /// @param Domain The domain of the access relation. |
1145 | /// @param AccMap The access relation to be checked. |
1146 | /// @param Dimensions The permutation of the subset of the input dimensions. |
1147 | /// @return True if @p AccMap has the expected form and false, |
1148 | /// otherwise. |
1149 | static bool isCorrectAccessMap(isl::set Domain, isl::map AccMap, |
1150 | ArrayRef<int> Dimensions) { |
1151 | isl::space Space = AccMap.get_space(); |
1152 | if (unsignedFromIslSize(Size: Space.dim(type: isl::dim::out)) != Dimensions.size()) |
1153 | return false; |
1154 | |
1155 | // Create an access relation of the following form: |
1156 | // [I0, …, Im] -> [Il, …, In], where indexes |
1157 | // Il, …, In are specified by @p Dimensions. |
1158 | isl::map PossibleTensor = isl::map::universe(space: Space); |
1159 | unsigned DimInSize = unsignedFromIslSize(Size: Space.dim(type: isl::dim::in)); |
1160 | for (unsigned i = 0; i < Dimensions.size(); i++) { |
1161 | const int InPos = Dimensions[i]; |
1162 | if ((InPos >= static_cast<int>(DimInSize)) || (InPos < 0)) |
1163 | return false; |
1164 | PossibleTensor = |
1165 | PossibleTensor.equate(type1: isl::dim::in, pos1: InPos, type2: isl::dim::out, pos2: i); |
1166 | } |
1167 | |
1168 | AccMap = AccMap.intersect_domain(set: Domain); |
1169 | PossibleTensor = PossibleTensor.intersect_domain(set: Domain); |
1170 | |
1171 | // If AccMap != PossibleTensor here (the two maps have been gisted at |
1172 | // this point), it means that the writes are not complete, or in other |
1173 | // words, it is a Partial write and Partial writes must be rejected. |
1174 | return AccMap.is_equal(map2: PossibleTensor); |
1175 | } |
1176 | |
1177 | /// Check whether the access represents the tensor contraction operand. |
1178 | /// |
1179 | /// Check that the access relation @p AccMap has the form T[i1, …, in]. |
1180 | /// Obtained indexes i1, …, in, their sizes and their permutation are stored |
1181 | /// into @p IndexSet, @p DimensionSizes, and @p Dimensions, respectively. |
1182 | /// |
1183 | /// @param Domain The domain of the access relation. |
1184 | /// @param AccMap The access relation to be checked. |
1185 | /// @param IndexSet The subset of the input dimensions. |
1186 | /// @param DimensionSizes Sizes of the input dimensions of @p Dimensions. |
1187 | /// @param Dimensions The permutation of the subset of the input dimensions. |
1188 | /// @return True if @p AccMap has the expected form and false, |
1189 | /// otherwise. |
1190 | static bool isTCOperandAcc(isl::set Domain, isl::map AccMap, |
1191 | SmallDenseSet<int> &IndexSet, |
1192 | SmallVectorImpl<int> &DimensionSizes, |
1193 | SmallVectorImpl<int> &Dimensions) { |
1194 | isl::id Id = AccMap.get_tuple_id(type: isl::dim::out); |
1195 | const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(Id); |
1196 | assert(SAI && "AccMap should represent memory access" ); |
1197 | |
1198 | // Fix values of output dimensions with respect to their positions. |
1199 | // In the case of the tensor contraction, values of output dimensions are |
1200 | // fixed and form a permutation of a subset of values of input dimensions. |
1201 | // |
1202 | // For example, in the case of Stmt[i][j][k] -> A[k][i], which represents |
1203 | // the operand of the tensor contraction, we get the following map by fixing |
1204 | // the output dimensions Stmt[1][j][0] -> A[0][1]. |
1205 | // |
1206 | // We store the permutation of the subset of the input dimensions {2, 0} into |
1207 | // @p Dimensions. |
1208 | // |
1209 | // The obtained permutation and the isCorrectAccessMap function are used to |
1210 | // check whether the access relation @p AccMap represents the tensor |
1211 | // contraction operand. For example, in the case of |
1212 | // Stmt[i][j][k] -> A[i-1][j+1], we get Stmt[1][0][k] -> A[0][1] and, |
1213 | // consequently, {1, 0}, which is rejected by isCorrectAccessMap, |
1214 | // since it corresponds to Stmt[i][j][k] -> A[j][i]. |
1215 | isl::map CheckMap = isl::manage(ptr: AccMap.copy()); |
1216 | unsigned OutDimNum = unsignedFromIslSize(Size: CheckMap.dim(type: isl::dim::out)); |
1217 | for (unsigned i = 0; i < OutDimNum; i++) |
1218 | CheckMap = CheckMap.fix_si(type: isl::dim::out, pos: i, value: i); |
1219 | |
1220 | // Try to obtain the permutation and sizes of corresponding input dimensions. |
1221 | Dimensions.assign(NumElts: OutDimNum, Elt: -1); |
1222 | for (unsigned i : rangeIslSize(Begin: 0, End: CheckMap.dim(type: isl::dim::in))) { |
1223 | isl::val Val = getConstant(Map: CheckMap, Dim: isl::dim::in, Pos: i); |
1224 | if (!Val.is_int()) |
1225 | continue; |
1226 | int OutPos = -1; |
1227 | llvm::APInt ValAPInt = APIntFromVal(V: Val); |
1228 | if (ValAPInt.isSignedIntN(N: 32)) |
1229 | OutPos = ValAPInt.getSExtValue(); |
1230 | if ((OutPos < 0) || (OutPos >= static_cast<int>(OutDimNum)) || |
1231 | IndexSet.count(V: i)) |
1232 | return false; |
1233 | IndexSet.insert(V: i); |
1234 | Dimensions[OutPos] = i; |
1235 | if (DimensionSizes[i] <= 0) |
1236 | DimensionSizes[i] = getDimSize(SAI, Pos: OutPos); |
1237 | } |
1238 | |
1239 | return isCorrectAccessMap(Domain, AccMap, Dimensions); |
1240 | } |
1241 | |
1242 | /// Find the intersection of two sets. |
1243 | /// |
1244 | /// Find the intersection of the set @p A and the set @p B. |
1245 | /// |
1246 | /// @param A, B Sets to intersect. |
1247 | /// @return The set intersection. |
1248 | static SmallDenseSet<int> intersect(const SmallDenseSet<int> &A, |
1249 | const SmallDenseSet<int> &B) { |
1250 | SmallDenseSet<int> Intersection = A; |
1251 | set_intersect(S1&: Intersection, S2: B); |
1252 | return Intersection; |
1253 | } |
1254 | |
1255 | /// Check whether the set is a superset. |
1256 | /// |
1257 | /// Check that the set @p A is a superset of @p B. |
1258 | /// |
1259 | /// @param A, B Sets to be checked. |
1260 | /// @return True if the set A is a superset of B. |
1261 | static bool isSuperset(const SmallDenseSet<int> &A, |
1262 | const SmallDenseSet<int> &B) { |
1263 | return intersect(A, B).size() == B.size(); |
1264 | } |
1265 | |
1266 | /// Find the union of two sets. |
1267 | /// |
1268 | /// Find the union of the set @p A and the set @p B. |
1269 | /// |
1270 | /// @param A, B Sets to unite. |
1271 | /// @return The set union. |
1272 | static SmallDenseSet<int> unite(const SmallDenseSet<int> &A, |
1273 | const SmallDenseSet<int> &B) { |
1274 | SmallDenseSet<int> Union = A; |
1275 | set_union(S1&: Union, S2: B); |
1276 | return Union; |
1277 | } |
1278 | |
1279 | /// Determine the access that writes to the tensor, which contains |
1280 | /// the result of the tensor contraction. |
1281 | /// |
1282 | /// @param Domain The domain of the statement. |
1283 | /// @param Stmt The statement, which writes to memory. |
1284 | /// @param TCI The information about the tensor contraction. |
1285 | /// @param IandJIndexSet The set, which contains free indexes of tensors. |
1286 | /// @return The determined MemoryAccess, or nullptr if there is no necessary |
1287 | /// access within the SCoP. |
1288 | static MemoryAccess *getWriteAccess(isl::set Domain, ScopStmt *Stmt, |
1289 | TCInfoTy &TCI, |
1290 | SmallDenseSet<int> &IandJIndexSet) { |
1291 | TCI.WriteToC = nullptr; |
1292 | SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(Stmt&: *Stmt); |
1293 | for (MemoryAccess *MemA : reverse(C&: Accesses)) { |
1294 | // A TC-like does not contain write scalar memory accesses |
1295 | if (!MemA->isLatestArrayKind()) |
1296 | return nullptr; |
1297 | // The last memory access should be a write memory access. |
1298 | if (!MemA->isWrite()) |
1299 | return nullptr; |
1300 | |
1301 | isl::map AccMap = MemA->getLatestAccessRelation(); |
1302 | if (!isTCOperandAcc(Domain, AccMap, IndexSet&: IandJIndexSet, DimensionSizes&: TCI.DimensionSizes, |
1303 | Dimensions&: TCI.CDimensions)) |
1304 | return nullptr; |
1305 | |
1306 | return MemA; |
1307 | } |
1308 | return nullptr; |
1309 | } |
1310 | |
/// Determine an access, which reads elements of an operand of the tensor
/// contraction.
1313 | /// |
1314 | /// @param MemAccessPtr The access, which reads elements of the tensor. |
1315 | /// @param IndexSet The set, which contains indexes of the tensors. |
1316 | /// @param IandJIndexSet The set, which contains free indexes of tensors. |
1317 | /// @param Dimensions The permutation of the subset of the input dimensions. |
1318 | /// @param TCI The information about the tensor contraction. |
1319 | /// @return True if the memory access @p MemAccessPtr corresponds |
1320 | /// to the tensor contraction. |
static bool setReadAccess(MemoryAccess *MemAccessPtr,
                          const SmallDenseSet<int> &IndexSet,
                          const SmallDenseSet<int> &IandJIndexSet,
                          ArrayRef<int> Dimensions, TCInfoTy &TCI) {
  // The first unmatched operand read is bound to A, the second to B, and a
  // third one rejects the statement as a tensor contraction.
  if (!TCI.A) {
    // Probably IndexSet is a union of I and P sets.
    if (!isSuperset(A: IndexSet, B: TCI.P))
      return false;

    // Obtain the set I.
    TCI.I = set_difference(S1: IndexSet, S2: TCI.P);
    if (!isSuperset(A: IandJIndexSet, B: TCI.I))
      return false;

    // Obtain the set J.
    TCI.J = set_difference(S1: IandJIndexSet, S2: TCI.I);

    // Set the first operand of the tensor contraction.
    TCI.A = MemAccessPtr;
    llvm::replace(Cont&: TCI.ADimensions, ContIt: TCI.ADimensions.begin(),
                  ContEnd: TCI.ADimensions.end(), ValIt: Dimensions.begin(), ValEnd: Dimensions.end());
    return true;
  }

  if (!TCI.B) {
    // IndexSet should be a union of J and P sets. The sets J and P were
    // determined while binding the first operand above.
    if (unite(A: TCI.P, B: TCI.J) != IndexSet)
      return false;

    // Set the second operand of the tensor contraction.
    TCI.B = MemAccessPtr;
    llvm::replace(Cont&: TCI.BDimensions, ContIt: TCI.BDimensions.begin(),
                  ContEnd: TCI.BDimensions.end(), ValIt: Dimensions.begin(), ValEnd: Dimensions.end());
    return true;
  }

  // Both operands are already bound; a tensor contraction has exactly two
  // non-result operands.
  return false;
}
1359 | |
/// Check that all memory accesses of the statement, except for the last
/// one, are read memory accesses, which read elements of operands of the
/// tensor contraction and its result.
1363 | /// |
1364 | /// @param Domain The domain of the statement. |
1365 | /// @param Stmt The statement, which writes to memory. |
1366 | /// @param TCI The information about the tensor contraction. |
1367 | /// @param IandJIndexSet The set, which contains free indexes of tensors. |
1368 | /// @return True if all read memory accesses of the statement @p Stmt correspond |
1369 | /// to the tensor contraction. |
1370 | static bool setReadAccesses(isl::set Domain, ScopStmt *Stmt, TCInfoTy &TCI, |
1371 | SmallDenseSet<int> &IandJIndexSet) { |
1372 | TCI.A = nullptr; |
1373 | TCI.B = nullptr; |
1374 | TCI.ReadFromC = nullptr; |
1375 | SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(Stmt&: *Stmt); |
1376 | for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) { |
1377 | MemoryAccess *MemAccessPtr = *MemA; |
1378 | |
1379 | // All memory accesses, except from the last one, should be read memory |
1380 | // accesses. |
1381 | if (MemAccessPtr->isWrite()) |
1382 | return false; |
1383 | |
1384 | isl::map AccMap = MemAccessPtr->getLatestAccessRelation(); |
1385 | |
1386 | if (!MemAccessPtr->isLatestArrayKind()) { |
1387 | // Check whether the scalar read memory access is not partial. |
1388 | if (!Domain.is_subset(set2: AccMap.domain())) |
1389 | return false; |
1390 | continue; |
1391 | return false; |
1392 | } |
1393 | |
1394 | // There is only one memory access, which reads elements of the result of |
1395 | // the tensor contraction. |
1396 | if (AccMap.is_equal(map2: TCI.WriteToC->getLatestAccessRelation())) { |
1397 | if (TCI.ReadFromC) |
1398 | return false; |
1399 | TCI.ReadFromC = MemAccessPtr; |
1400 | continue; |
1401 | } |
1402 | |
1403 | SmallVector<int> Dimensions; |
1404 | SmallDenseSet<int> IndexSet; |
1405 | if (!isTCOperandAcc(Domain, AccMap, IndexSet, DimensionSizes&: TCI.DimensionSizes, |
1406 | Dimensions)) |
1407 | return false; |
1408 | |
1409 | if (!setReadAccess(MemAccessPtr, IndexSet, IandJIndexSet, Dimensions, TCI)) |
1410 | return false; |
1411 | } |
1412 | |
1413 | // Check that there are read memory accesses, which read elements of operands |
1414 | // of the tensor contraction and its result. |
1415 | return TCI.ReadFromC && TCI.A && TCI.B; |
1416 | } |
1417 | |
1418 | /// Check accesses to operands of the tensor contraction. |
1419 | /// |
1420 | /// Check that accesses of the SCoP statement, which corresponds to |
1421 | /// the partial schedule @p PartialSchedule, represent accesses |
1422 | /// to the non-scalar operands of the tensor contraction. |
1423 | /// |
1424 | /// @param Domain The domain of the SCoP statement. |
1425 | /// @param PartialSchedule The partial schedule of the SCoP statement. |
1426 | /// @param TCI Parameters of the tensor contraction operands. |
1427 | /// @return True if the corresponding SCoP statement |
1428 | /// represents tensor contraction and false, |
1429 | /// otherwise. |
1430 | static bool containsOnlyTCAcc(isl::set Domain, isl::map PartialSchedule, |
1431 | TCInfoTy &TCI) { |
1432 | isl::id InputDimsId = PartialSchedule.get_tuple_id(type: isl::dim::in); |
1433 | ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); |
1434 | |
1435 | // In region statements, the order of memory accesses execution is not |
1436 | // predictable at compile-time. |
1437 | if ((Stmt->size() <= 1) || Stmt->isRegionStmt()) |
1438 | return false; |
1439 | |
1440 | unsigned DimNum = unsignedFromIslSize(Size: PartialSchedule.dim(type: isl::dim::in)); |
1441 | TCI.DimensionSizes.resize(N: DimNum); |
1442 | SmallDenseSet<int> IandJIndexSet; |
1443 | |
1444 | TCI.WriteToC = getWriteAccess(Domain, Stmt, TCI, IandJIndexSet); |
1445 | if (!TCI.WriteToC) |
1446 | return false; |
1447 | |
1448 | if (intersect(A: IandJIndexSet, B: TCI.P).size() != 0) |
1449 | return false; |
1450 | |
1451 | if (!setReadAccesses(Domain, Stmt, TCI, IandJIndexSet)) |
1452 | return false; |
1453 | |
1454 | return true; |
1455 | } |
1456 | |
1457 | /// Check that dependency corresponds to the tensor contraction carried over |
1458 | /// loop dimension @p Dim. |
1459 | /// |
1460 | /// Check that the dependency has the form |
1461 | /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> |
1462 | /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP |
1463 | /// statement. For this purpose, we analyze the set @p DepDelta, which |
1464 | /// represents the differences between image elements and domain elements of |
1465 | /// the corresponding map. |
1466 | /// |
1467 | /// @param DepDelta The set contains the differences between image elements |
1468 | /// and corresponding domain elements of the map, which |
1469 | /// represents the dependency. |
1470 | /// @param Dim The position of the index ki. |
1471 | /// @param BoundDeltas In the case of indexes of ki, the difference between |
1472 | /// image elements and corresponding domain elements |
1473 | /// corresponds to the difference between lexicographic |
1474 | /// minimum and lexicographic maximum of the corresponding |
1475 | /// dimension of the domain of the statement. |
1476 | /// @param IndexSet Obtained indexes ki, which describe the dependency. |
1477 | /// @return True if dependencies correspond to the tensor contraction |
1478 | /// and false, otherwise. |
static bool isReductionCarriedOverDim(isl::set DepDelta, unsigned Dim,
                                      isl::pw_multi_aff BoundDeltas,
                                      const SmallDenseSet<int> &IndexSet) {
  // Build the superset { [0, ..., 0, 1, *, ..., *] }: dimensions below Dim
  // are fixed to zero, dimension Dim itself is fixed to one, and the higher
  // dimensions are left unconstrained.
  isl::space Space = DepDelta.get_space();
  isl::set Superset = isl::set::universe(space: Space);
  for (unsigned i = 0; i < Dim; i += 1)
    Superset = Superset.fix_si(type: isl::dim::set, pos: i, value: 0);
  Superset = Superset.fix_si(type: isl::dim::set, pos: Dim, value: 1);

  // Check that the difference between the image element and the domain element
  // is equal to one in the case of the index ki. Image elements and
  // corresponding domain elements should be equal in the case of positions,
  // which are lower than the specified position.
  if (!DepDelta.is_subset(set2: Superset))
    return false;

  // Compute a set, which is used to analyze how values of
  // the domain are related to the map that describes the dependency.
  // Complement = (UpperBound - LowerBound) + DepDelta, so a dimension of
  // Complement is zero exactly when the dependency distance in that dimension
  // equals the negated extent of the domain, i.e., the index wraps from its
  // lexicographic maximum back to its lexicographic minimum.
  isl_pw_multi_aff *DepDeltaPW = isl_pw_multi_aff_from_set(set: DepDelta.copy());
  BoundDeltas = BoundDeltas.add(pma2: isl::manage(ptr: DepDeltaPW));
  isl_set *ComplementRawSet = isl_set_from_pw_multi_aff(pma: BoundDeltas.release());
  isl::set Complement = isl::manage(ptr: ComplementRawSet);

  // Inspect all dimensions above Dim.
  for (unsigned i : rangeIslSize(Begin: Dim + 1, End: DepDelta.dim(type: isl::dim::set))) {
    if (!IndexSet.count(V: i)) {
      // Check the difference between the image element and the domain element
      // in the case of indexes, which do not describe the dependency.
      if (DepDelta.plain_get_val_if_fixed(type: isl::dim::set, pos: i).is_zero())
        continue;
      return false;
    }

    // In the case of other indexes, which describe the dependency,
    // the difference between the image element and the domain element
    // should be equal to the difference between lexicographic minimum and
    // lexicographic maximum of the domain of the statement.
    if (!Complement.plain_get_val_if_fixed(type: isl::dim::set, pos: i).is_zero())
      return false;
  }

  return true;
}
1521 | |
1522 | /// Check whether dependencies are over the complete domain. |
1523 | /// |
1524 | /// In the case of the tensor contraction RAW, WAW, WAR dependencies |
1525 | /// have the form |
1526 | /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> |
1527 | /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP |
1528 | /// statement. Consequently, the domain of the dependencies |
1529 | /// can be described as |
1530 | /// Domain / Domain ∩ S(…, max(kn),…) ∩ S(…, max(k(i + 1)),…), |
1531 | /// where Domain is the domain of the statement S. |
1532 | /// |
1533 | /// For example, in the case of the following tensor contraction, |
1534 | /// corresponding domains will have the following form. |
1535 | /// |
1536 | /// An example of the tensor contraction: |
1537 | /// for (i = 0; i < 1024; i++) |
1538 | /// for (j = 0; j < 1024; j++) |
1539 | /// for (l = 0; l < 64; ++l) |
1540 | /// for (w = 0; w < 64; ++w) |
1541 | /// C[i][j] += A[i][l][w] * B[w][j][l]; |
1542 | /// |
1543 | /// The domain of the statement: |
1544 | /// { S[i0, i1, i2, i3] : i0 >= 0 and i0 <= 1023 and |
1545 | /// i1 >= 0 and i1 <= 1023 and |
1546 | /// i2 >= 0 and i2 <= 63 and |
1547 | /// i3 >= 0 and i3 <= 63 } |
1548 | /// |
1549 | /// The domain of the dependencies: |
1550 | /// { S[i0, i1, i2, i3] : (i0 >= 0 and i0 <= 1023 and |
1551 | /// i1 >= 0 and i1 <= 1023 and |
1552 | /// i2 >= 0 and i2 <= 63 and |
1553 | /// i3 >= 0 and i3 <= 62) or |
1554 | /// (i3 = 63 and i0 >= 0 and i0 <= 1023 and |
1555 | /// i1 >= 0 and i1 <= 1023 and |
1556 | /// i2 >= 0 and i2 <= 62) } |
1557 | /// |
1558 | /// @param Domain The domain of the statement. |
1559 | /// @param DepsForStmt RAW and RED dependencies for the statement. |
1560 | /// @param UpperBound The lexicographic maximum of the elements in |
1561 | /// the @p Domain. |
1562 | /// @param IndexSet Obtained indexes ki, which describe the dependencies. |
1563 | /// @return True if dependencies are over the complete domain |
1564 | /// and false, otherwise. |
static bool areDepsOverCompleteDomain(isl::set Domain, isl::map DepsForStmt,
                                      isl::pw_multi_aff UpperBound,
                                      SmallDenseSet<int> &IndexSet) {
  // Convert the lexicographic maximum into a set, so that the fixed value of
  // each dimension can be queried below.
  isl_set *UpperBoundRawSet = isl_set_from_pw_multi_aff(pma: UpperBound.copy());
  isl::set UpperBoundSet = isl::manage(ptr: UpperBoundRawSet);

  // Compute the subset of the domain in which every contraction index ki is
  // fixed at its upper bound; these are exactly the iterations that have no
  // outgoing dependency of the TC form.
  isl::set DomainRed = isl::manage(ptr: Domain.copy());
  for (const auto It : IndexSet) {
    isl::val FixedVal = UpperBoundSet.plain_get_val_if_fixed(type: isl::dim::set, pos: It);
    // Bail out if the upper bound of this dimension is not a single fixed
    // value (plain_get_val_if_fixed yields NaN in that case).
    if (FixedVal.is_nan())
      return false;
    DomainRed = isl::manage(
        ptr: isl_set_fix_val(set: DomainRed.copy(), type: isl_dim_set, pos: It, v: FixedVal.release()));
  }
  // The dependencies cover the complete domain iff their sources are exactly
  // the domain minus the upper-bound iterations computed above.
  return DepsForStmt.domain().intersect(set2: Domain).is_equal(
      set2: Domain.subtract(set2: DomainRed));
}
1582 | |
1583 | /// Check that dependencies correspond to the tensor contraction. |
1584 | /// |
1585 | /// Check that there are only true dependencies of the form |
1586 | /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> |
1587 | /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP |
1588 | /// statement represented by @p Schedule. Such dependencies are produced by |
1589 | /// the tensor contraction. Obtained indexes ki are stored into @p IndexSet. |
1590 | /// |
1591 | /// The form of anti and output dependencies is specified implicitly by |
1592 | /// the form the SCoP statement, which is checked by subsequent analysis. |
1593 | /// |
1594 | /// @param Schedule The schedule of the SCoP statement. |
1595 | /// @param D The SCoP dependencies. |
1596 | /// @param Domain The domain of the statement. |
1597 | /// @param IndexSet Obtained indexes ki, which describe the dependencies. |
1598 | /// @return True if dependencies correspond to the tensor contraction |
1599 | /// and false, otherwise. |
static bool containsOnlyTcDeps(isl::map Schedule, const Dependences *D,
                               SmallDenseSet<int> &IndexSet, isl::set Domain) {
  // Bound the amount of isl computation spent on this analysis.
  // NOTE(review): if the quota is exhausted, subsequent isl results become
  // invalid and the checks below are expected to fail conservatively --
  // confirm against the IslMaxOperationsGuard contract.
  IslMaxOperationsGuard MaxOpGuard(Schedule.ctx().get(), OptComputeOut);

  // Consider only true (RAW) and reduction dependencies.
  isl::union_map Dep =
      D->getDependences(Kinds: Dependences::TYPE_RAW | Dependences::TYPE_RED);

  // Restrict the dependencies to self-dependencies of the statement and
  // compute the distance vectors (image element minus domain element).
  isl::space DomainSpace = Schedule.get_space().domain();
  isl::space Space = DomainSpace.map_from_domain_and_range(range: DomainSpace);
  isl::map DepsForStmt = Dep.extract_map(space: Space);
  isl::set DepDeltas = DepsForStmt.deltas();
  isl::size DeltasDimNum = DepDeltas.dim(type: isl::dim::set);
  isl::pw_multi_aff LowerBound = Domain.lexmin_pw_multi_aff();
  isl::pw_multi_aff UpperBound = Domain.lexmax_pw_multi_aff();
  // BoundDeltas is the per-dimension extent of the domain (max - min).
  isl::pw_multi_aff BoundDeltas = UpperBound.sub(pma2: LowerBound);

  // Process the dimensions from the innermost to the outermost one.
  for (int i : reverse(C: rangeIslSize(Begin: 0, End: DeltasDimNum))) {
    // In the case of the tensor contraction, the difference between image
    // elements and domain elements lies on a hyperplane where a dimension
    // has the fixed value one.
    isl::set Intersection = DepDeltas.fix_si(type: isl::dim::set, pos: i, value: 1);
    if (Intersection.is_empty())
      continue;

    if (!isReductionCarriedOverDim(DepDelta: Intersection, Dim: i, BoundDeltas, IndexSet))
      return false;

    // Dimension i is a contraction index; remove its deltas from further
    // consideration.
    IndexSet.insert(V: i);
    DepDeltas = DepDeltas.subtract(set2: Intersection);
  }

  // In the case of the tensor contraction, all dependencies should have
  // the previously described form.
  if ((unsignedFromIslSize(Size: DeltasDimNum) == 0) || !DepDeltas.is_empty())
    return false;

  return areDepsOverCompleteDomain(Domain, DepsForStmt, UpperBound, IndexSet);
}
1638 | |
1639 | /// Check if the SCoP statement could probably be optimized with analytical |
1640 | /// modeling. |
1641 | /// |
1642 | /// containsTCInfoTy tries to determine whether the following conditions |
1643 | /// are true: |
1644 | /// |
1645 | /// 1. The last memory access modeling an array, MA1, represents writing to |
1646 | /// memory and has the form S(..., I, ..., J, ...) -> M(shuffle(I, J)), |
1647 | /// where S is the SCoP statement under consideration and shuffle(I, J) |
1648 | /// is a permutation of indexes of sets I and J. |
1649 | /// 2. There are only true dependencies of the form |
1650 | /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> |
1651 | /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP |
1652 | /// statement represented by @p Schedule and ki are indexes of the set P. |
1653 | /// 3. SCoP contains an arbitrary number of reads from constants and only three |
1654 | /// access relations, MA2, MA3, and MA4 that represent reading from memory |
1655 | /// and have the form |
1656 | /// S(..., I, ..., P, ...) -> M(shuffle(I, P)), |
1657 | /// S(..., P, ..., J, ...) -> M(shuffle(J, P)), |
1658 | /// S(...) -> M(shuffle(I, J)), respectively. |
1659 | /// |
1660 | /// @param PartialSchedule The PartialSchedule that contains a SCoP statement |
1661 | /// to check. |
1662 | /// @param D The SCoP dependencies. |
1663 | /// @param TCI Parameters of the tensor contraction operands. |
1664 | /// @param Domain The domain of the statement. |
1665 | /// @return True if dependencies and memory accesses correspond to the tensor |
1666 | /// contraction and false, otherwise. |
1667 | static bool containsTCInfoTy(isl::map PartialSchedule, const Dependences *D, |
1668 | TCInfoTy &TCI, isl::set Domain) { |
1669 | if (!containsOnlyTcDeps(Schedule: PartialSchedule, D, IndexSet&: TCI.P, Domain)) |
1670 | return false; |
1671 | |
1672 | // TODO: handle cases of scalar multiplication if needed. |
1673 | if (TCI.P.size() == 0) |
1674 | return false; |
1675 | |
1676 | if (!containsOnlyTCAcc(Domain, PartialSchedule, TCI)) |
1677 | return false; |
1678 | |
1679 | // TODO: handle cases of GEMV if needed. |
1680 | if ((TCI.I.size() == 0) || (TCI.J.size() == 0)) |
1681 | return false; |
1682 | |
1683 | return true; |
1684 | } |
1685 | |
1686 | /// Check if this node contains a partial schedule that could |
1687 | /// probably be optimized with analytical modeling. |
1688 | /// |
1689 | /// isTCPattern is used to determine whether the SCoP represents a TC-like |
1690 | /// kernel [1], which is a perfectly nested set of loops, with a data usage |
1691 | /// pattern that is similar to that produced by the tensor contraction. |
1692 | /// |
1693 | /// A TC-like kernel can be defined as follows: |
1694 | /// |
1695 | /// 1. It satisfies the requirements of the polyhedral model. |
1696 | /// 2. Without loss of generality, it contains three nonempty bundles of |
1697 | /// one-dimensional for-loops with induction variables that are grouped into |
1698 | /// bundles I = i0...i(r-1), J = j0..j(s-1), and P = p0...p(t-1), and they |
1699 | /// are incremented by one. |
1700 | /// 3. The innermost loop body can be represented as a statement of the form |
1701 | /// C(shuffle(I, J)) = E(A(shuffle(I, P)), B(shuffle(P, J)), |
1702 | /// C(shuffle(I, J))), where A(shuffle(I, P)), B(shuffle(P, J)), |
1703 | /// C(shuffle(I, J)) are accesses to tensors A, B, C, respectively, |
1704 | /// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of the |
1705 | /// enclosed indices, and E is an expression that contains reads from |
1706 | /// the tensors A, B, C, and an arbitrary number of reads from constants |
1707 | /// with respect to bundles I, J, and P. |
1708 | /// |
1709 | /// TC can be considered as a particular case of a TC-like kernel. |
1710 | /// |
1711 | /// The order of loops with indexes from P should be preserved. Otherwise, |
1712 | /// isTCPattern should check if a commutative operation is used. |
1713 | /// |
1714 | /// isTCPattern performs the following steps to check whether the SCoP |
1715 | /// corresponds to a definition of a TC-like kernel: |
1716 | /// |
1717 | /// 1. Checks that the node is the innermost band node. |
1718 | /// 2. Checks that the partial schedule contains only one statement. |
1719 | /// 3. Check that all ancestors of the node contain all band nodes for |
1720 | /// the statement and only mark nodes interleave such band nodes. This |
1721 | /// corresponds to a straightforward implementation of TC. |
1722 | /// 4. Analyses the dependencies to determine contraction dimensions. |
1723 | /// 5. Check that the last memory access modeling an array, represents writing |
1724 | /// to the result of the TC-like kernel. |
1725 | /// 6. Check that SCoP contains only three access relations that represent |
1726 | /// reading of the operands of the TC-like kernel and an arbitrary number of |
1727 | /// reads from constants. |
1728 | /// |
1729 | /// [1] - Gareev R., Grosser T., Kruse M. High-Performance Generalized Tensor |
1730 | /// Operations: A Compiler-Oriented Approach // ACM Transactions |
1731 | /// Architecture and Code Optimization (TACO). 2018. |
1732 | /// Vol. 15, no. 3. P. 34:1–34:27. DOI: 10.1145/3235029. |
1733 | /// |
1734 | /// If this is the case, we could logically represent tensors as matrices and |
1735 | /// apply algorithms, which are used to get close-to-peak performance of |
1736 | /// matrix multiplications in manually tuned BLAS libraries (e.g., BLIS). |
1737 | /// |
1738 | /// @param Node The node to check. |
1739 | /// @param D The SCoP dependencies. |
1740 | /// @param TCI Parameters of the tensor contraction operands. |
static bool isTCPattern(isl::schedule_node Node, const Dependences *D,
                        TCInfoTy &TCI) {
  // Take the prefix schedule and domain of the band's only child, so that
  // the band's own dimensions are part of PartialSchedule.
  Node = Node.child(pos: 0);
  isl::union_map PartialSchedule = Node.get_prefix_schedule_union_map();
  isl::union_set Domain = Node.domain();
  Node = Node.parent();

  // The partial schedule should contain only one statement.
  // TODO: This constraint should not be intrinsic to the algorithm.
  if (isl_union_set_n_set(uset: Domain.get()) != 1)
    return false;

  isl_schedule_node_type NodeType = isl_schedule_node_get_type(node: Node.get());

  // Check that all ancestors of the node contain all band nodes for
  // the statement, which represents the TC-like kernel, and only mark nodes
  // interleave such band nodes. This corresponds to a straightforward
  // implementation of TC with/without DeLICM applied.
  //
  // For example, this covers the matrix multiplication pattern after a full
  // run of -polly-optree and -polly-delicm, where the write access is not
  // through the original memory access, but through a PHI node that was
  // delicmed. Subsequently, such band nodes will be replaced by a single band
  // node.
  //
  // The corresponding schedule can be the following, where Stmt_for_body8
  // contains the matrix multiplication:
  //
  // domain: "{ Stmt_for_body8[i0, i1, i2] : 0 <= i0 <= 1599 and
  //                                         0 <= i1 <= 1799 and
  //                                         0 <= i2 <= 2199;
  //            Stmt_for_body3[i0, i1] : 0 <= i0 <= 1599 and
  //                                     0 <= i1 <= 1799;
  //            Stmt_for_body3_last[i0, i1] : 0 <= i0 <= 1599 and
  //                                          0 <= i1 <= 1799 }"
  // child:
  //  sequence:
  //  - filter: "{ Stmt_for_body3[i0, i1] }"
  //    child:
  //     schedule: "[{ Stmt_for_body3[i0, i1] -> [(i0)] },
  //                 { Stmt_for_body3[i0, i1] -> [(i1)] }]"
  //     permutable: 1
  //     coincident: [ 1, 1 ]
  //  - filter: "{ Stmt_for_body3_last[i0, i1] }"
  //    child:
  //     schedule: "[{ Stmt_for_body3_last[i0, i1] -> [(i0)] },
  //                 { Stmt_for_body3_last[i0, i1] -> [(i1)] }]"
  //     permutable: 1
  //     coincident: [ 1, 1 ]
  //  - filter: "{ Stmt_for_body8[i0, i1, i2] }"
  //    child:
  //     schedule: "[{ Stmt_for_body8[i0, i1, i2] -> [(i0)] },
  //                 { Stmt_for_body8[i0, i1, i2] -> [(i1)] },
  //                 { Stmt_for_body8[i0, i1, i2] -> [(i2)] }]"
  //     permutable: 1
  //     coincident: [ 1, 1, 0 ]
  //
  while (NodeType != isl_schedule_node_domain) {
    // A filter node is acceptable only directly below a sequence that is
    // itself directly below the domain node (see the example above).
    if (NodeType == isl_schedule_node_filter) {
      if (!Node.parent().isa<isl::schedule_node_sequence>() ||
          !Node.parent().parent().isa<isl::schedule_node_domain>())
        return false;
      break;
    }

    // Any other interleaving node besides band and mark nodes breaks the
    // perfect nesting required for a TC-like kernel.
    if ((NodeType != isl_schedule_node_band) &&
        (NodeType != isl_schedule_node_mark))
      return false;

    Node = Node.parent();
    NodeType = isl_schedule_node_get_type(node: Node.get());
  }

  // Finally, analyze the dependencies and the memory accesses of the
  // single statement.
  isl::map PartialScheduleMap = isl::map::from_union_map(umap: PartialSchedule);
  if (containsTCInfoTy(PartialSchedule: PartialScheduleMap, D, TCI, Domain: isl::set(Domain)))
    return true;

  return false;
}
1820 | |
1821 | } // namespace |
1822 | |
1823 | isl::schedule_node |
1824 | polly::tryOptimizeMatMulPattern(isl::schedule_node Node, |
1825 | const llvm::TargetTransformInfo *TTI, |
1826 | const Dependences *D) { |
1827 | TCInfoTy TCI; |
1828 | if (PMBasedTCOpts && isTCPattern(Node, D, TCI)) |
1829 | POLLY_DEBUG(dbgs() << "The tensor contraction pattern was detected\n" ); |
1830 | MatMulInfoTy MMI; |
1831 | if (PMBasedMMMOpts && isMatrMultPattern(Node, D, MMI)) { |
1832 | POLLY_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n" ); |
1833 | return optimizeMatMulPattern(Node, TTI, MMI); |
1834 | } |
1835 | return {}; |
1836 | } |
1837 | |