| 1 | //===- MatmulOptimizer.cpp -----------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "polly/MatmulOptimizer.h" |
| 10 | #include "polly/DependenceInfo.h" |
| 11 | #include "polly/Options.h" |
| 12 | #include "polly/ScheduleTreeTransform.h" |
| 13 | #include "polly/ScopInfo.h" |
| 14 | #include "polly/ScopPass.h" |
| 15 | #include "polly/Simplify.h" |
| 16 | #include "polly/Support/GICHelper.h" |
| 17 | #include "polly/Support/ISLTools.h" |
| 18 | #include "llvm/ADT/ArrayRef.h" |
| 19 | #include "llvm/ADT/DenseSet.h" |
| 20 | #include "llvm/ADT/Sequence.h" |
| 21 | #include "llvm/ADT/SetOperations.h" |
| 22 | #include "llvm/ADT/SmallVector.h" |
| 23 | #include "llvm/ADT/StringRef.h" |
| 24 | #include "llvm/ADT/iterator_range.h" |
| 25 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 26 | #include "llvm/IR/DataLayout.h" |
| 27 | #include "llvm/IR/Function.h" |
| 28 | #include "llvm/IR/Module.h" |
| 29 | #include "llvm/Support/CommandLine.h" |
| 30 | #include "llvm/Support/Debug.h" |
| 31 | #include "llvm/Support/TypeSize.h" |
| 32 | #include "llvm/Support/raw_ostream.h" |
| 33 | #include "isl/ctx.h" |
| 34 | #include "isl/schedule_node.h" |
| 35 | #include "isl/schedule_type.h" |
| 36 | #include "isl/union_map.h" |
| 37 | #include "isl/union_set.h" |
| 38 | #include <algorithm> |
| 39 | #include <cassert> |
| 40 | #include <cmath> |
| 41 | #include <cstdint> |
| 42 | #include <string> |
| 43 | #include <vector> |
| 44 | |
| 45 | #include "polly/Support/PollyDebug.h" |
| 46 | #define DEBUG_TYPE "polly-opt-isl" |
| 47 | |
| 48 | using namespace llvm; |
| 49 | using namespace polly; |
| 50 | |
| 51 | namespace llvm { |
| 52 | class Value; |
| 53 | } |
| 54 | |
| 55 | static cl::opt<int> LatencyVectorFma( |
| 56 | "polly-target-latency-vector-fma" , |
| 57 | cl::desc("The minimal number of cycles between issuing two " |
| 58 | "dependent consecutive vector fused multiply-add " |
| 59 | "instructions." ), |
| 60 | cl::Hidden, cl::init(Val: 8), cl::cat(PollyCategory)); |
| 61 | |
| 62 | static cl::opt<int> ThroughputVectorFma( |
| 63 | "polly-target-throughput-vector-fma" , |
| 64 | cl::desc("A throughput of the processor floating-point arithmetic units " |
| 65 | "expressed in the number of vector fused multiply-add " |
| 66 | "instructions per clock cycle." ), |
| 67 | cl::Hidden, cl::init(Val: 1), cl::cat(PollyCategory)); |
| 68 | |
| 69 | static cl::opt<int> FirstCacheLevelSize( |
| 70 | "polly-target-1st-cache-level-size" , |
| 71 | cl::desc("The size of the first cache level specified in bytes." ), |
| 72 | cl::Hidden, cl::init(Val: -1), cl::cat(PollyCategory)); |
| 73 | |
| 74 | static cl::opt<int> FirstCacheLevelDefaultSize( |
| 75 | "polly-target-1st-cache-level-default-size" , |
| 76 | cl::desc("The default size of the first cache level specified in bytes" |
| 77 | " (if not enough were provided by the TargetTransformInfo)." ), |
| 78 | cl::Hidden, cl::init(Val: 32768), cl::cat(PollyCategory)); |
| 79 | |
| 80 | static cl::opt<int> SecondCacheLevelSize( |
| 81 | "polly-target-2nd-cache-level-size" , |
| 82 | cl::desc("The size of the second level specified in bytes." ), cl::Hidden, |
| 83 | cl::init(Val: -1), cl::cat(PollyCategory)); |
| 84 | |
| 85 | static cl::opt<int> SecondCacheLevelDefaultSize( |
| 86 | "polly-target-2nd-cache-level-default-size" , |
| 87 | cl::desc("The default size of the second cache level specified in bytes" |
| 88 | " (if not enough were provided by the TargetTransformInfo)." ), |
| 89 | cl::Hidden, cl::init(Val: 262144), cl::cat(PollyCategory)); |
| 90 | |
// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size
// represent the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of Intel Core i7-3820
// SandyBridge in case the parameters are not specified or not provided by the
// TargetTransformInfo.
| 98 | static cl::opt<int> FirstCacheLevelAssociativity( |
| 99 | "polly-target-1st-cache-level-associativity" , |
| 100 | cl::desc("The associativity of the first cache level." ), cl::Hidden, |
| 101 | cl::init(Val: -1), cl::cat(PollyCategory)); |
| 102 | |
| 103 | static cl::opt<int> FirstCacheLevelDefaultAssociativity( |
| 104 | "polly-target-1st-cache-level-default-associativity" , |
| 105 | cl::desc("The default associativity of the first cache level" |
| 106 | " (if not enough were provided by the TargetTransformInfo)." ), |
| 107 | cl::Hidden, cl::init(Val: 8), cl::cat(PollyCategory)); |
| 108 | |
| 109 | static cl::opt<int> SecondCacheLevelAssociativity( |
| 110 | "polly-target-2nd-cache-level-associativity" , |
| 111 | cl::desc("The associativity of the second cache level." ), cl::Hidden, |
| 112 | cl::init(Val: -1), cl::cat(PollyCategory)); |
| 113 | |
| 114 | static cl::opt<int> SecondCacheLevelDefaultAssociativity( |
| 115 | "polly-target-2nd-cache-level-default-associativity" , |
| 116 | cl::desc("The default associativity of the second cache level" |
| 117 | " (if not enough were provided by the TargetTransformInfo)." ), |
| 118 | cl::Hidden, cl::init(Val: 8), cl::cat(PollyCategory)); |
| 119 | |
| 120 | static cl::opt<int> VectorRegisterBitwidth( |
| 121 | "polly-target-vector-register-bitwidth" , |
| 122 | cl::desc("The size in bits of a vector register (if not set, this " |
| 123 | "information is taken from LLVM's target information." ), |
| 124 | cl::Hidden, cl::init(Val: -1), cl::cat(PollyCategory)); |
| 125 | |
| 126 | static cl::opt<int> PollyPatternMatchingNcQuotient( |
| 127 | "polly-pattern-matching-nc-quotient" , |
| 128 | cl::desc("Quotient that is obtained by dividing Nc, the parameter of the" |
| 129 | "macro-kernel, by Nr, the parameter of the micro-kernel" ), |
| 130 | cl::Hidden, cl::init(Val: 256), cl::cat(PollyCategory)); |
| 131 | |
| 132 | static cl::opt<bool> |
| 133 | PMBasedTCOpts("polly-tc-opt" , |
| 134 | cl::desc("Perform optimizations of tensor contractions based " |
| 135 | "on pattern matching" ), |
| 136 | cl::init(Val: false), cl::ZeroOrMore, cl::cat(PollyCategory)); |
| 137 | |
| 138 | static cl::opt<bool> |
| 139 | PMBasedMMMOpts("polly-matmul-opt" , |
| 140 | cl::desc("Perform optimizations of matrix multiplications " |
| 141 | "based on pattern matching" ), |
| 142 | cl::init(Val: true), cl::ZeroOrMore, cl::cat(PollyCategory)); |
| 143 | |
| 144 | static cl::opt<int> OptComputeOut( |
| 145 | "polly-tc-dependences-computeout" , |
| 146 | cl::desc("Bound the dependence analysis by a maximal amount of " |
| 147 | "computational steps (0 means no bound)" ), |
| 148 | cl::Hidden, cl::init(Val: 500000), cl::ZeroOrMore, cl::cat(PollyCategory)); |
| 149 | |
| 150 | namespace { |
/// Parameters of the micro kernel.
///
/// Mr x Nr is the shape of the rank-1 (i.e., outer product) update used in
/// the optimized matrix multiplication.
struct MicroKernelParamsTy {
  int Mr; // Number of rows of the micro-kernel tile.
  int Nr; // Number of columns of the micro-kernel tile.
};
| 159 | |
/// Parameters of the macro kernel.
///
/// Mc, Nc, and Kc are the sizes of the blocks into which the operand
/// matrices are partitioned by the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc; // Block size along the i dimension.
  int Nc; // Block size along the j dimension.
  int Kc; // Block size along the k dimension.
};
| 169 | |
| 170 | /// Parameters of the matrix multiplication operands. |
| 171 | /// |
| 172 | /// Parameters, which describe access relations that represent operands of the |
| 173 | /// matrix multiplication. |
| 174 | struct MatMulInfoTy { |
| 175 | MemoryAccess *A = nullptr; |
| 176 | MemoryAccess *B = nullptr; |
| 177 | MemoryAccess *ReadFromC = nullptr; |
| 178 | MemoryAccess *WriteToC = nullptr; |
| 179 | int i = -1; |
| 180 | int j = -1; |
| 181 | int k = -1; |
| 182 | }; |
| 183 | |
| 184 | /// Parameters of the tensor contraction operands. |
| 185 | /// |
| 186 | /// A general d-dimensional tensor T ∈ R ^ Nu0 x ... x Nud−1 can be defined |
| 187 | /// as the set of scalar elements indexed by the set of indices u0 ... ud, |
| 188 | /// |
| 189 | /// T ≡ {Anu0...nud−1 ∈ R | (u0,...,ud−1) ∈ Nu0 x ... x Nud−1}. |
| 190 | /// |
| 191 | /// Let A, B, and C be dA, dB, and dC-dimensional tensors, respectively. |
| 192 | /// Let the free and the contracted indices of the tensor A be grouped into |
| 193 | /// two bundles I = i0...ir−1 and P = p0...pt−1, respectively. Similarly, |
| 194 | /// the free and the contracted indices of B are grouped into bundles |
| 195 | /// J = j0..js−1 and P and the free indices of C are grouped into |
| 196 | /// bundles I and J. |
| 197 | /// |
| 198 | /// Tensor contraction (TC) of tensors A, B into tensor C can be represented as |
| 199 | /// C(shuffle(I,J))=∑α·A(shuffle(I,P))·B(shuffle(P,J))+β·C(shuffle(I,J)), |
| 200 | /// where ∑ is a summation over all contracted indices of P, |
| 201 | /// α, β ∈ R, Npi is the length of the tensor dimension that corresponds |
| 202 | /// to the index pi, A(shuffle(I, P)), B(shuffle(P, J)), C(shuffle(I, J)) are |
| 203 | /// accesses to tensors A, B, C, respectively, |
| 204 | /// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of |
| 205 | /// the enclosed indices. |
| 206 | /// |
| 207 | /// Multiplication of C(shuffle(I,J)) by β can be moved into a different SCoP |
| 208 | /// statement by loop distribution, which is done by the isl scheduler. |
/// If β is not equal to one, the optimization of TC of Polly requires
| 210 | /// such a transformation. |
| 211 | /// |
| 212 | /// TCInfoTy contains parameters, which describe access relations that represent |
| 213 | /// operands of the tensor contraction. |
| 214 | struct TCInfoTy { |
| 215 | /// @{ |
| 216 | /// Memory accesses that represent reading from tensors, which are operands of |
| 217 | /// the tensor contraction. |
| 218 | MemoryAccess *A = nullptr; |
| 219 | MemoryAccess *B = nullptr; |
| 220 | /// @} |
| 221 | |
| 222 | /// @{ |
| 223 | /// Memory accesses that represent reading from and writing into the tensor, |
| 224 | /// which contains the result of the tensor contraction. |
| 225 | MemoryAccess *ReadFromC = nullptr; |
| 226 | MemoryAccess *WriteToC = nullptr; |
| 227 | /// @} |
| 228 | |
| 229 | /// @{ |
| 230 | /// Input dimensions of the schedule space, which represent free |
| 231 | /// indices of tensors. |
| 232 | SmallDenseSet<int> I; |
| 233 | SmallDenseSet<int> J; |
| 234 | /// @} |
| 235 | |
| 236 | /// Input dimension of the schedule space, which represents contracted |
| 237 | /// indices of tensors. |
| 238 | SmallDenseSet<int> P; |
| 239 | |
| 240 | /// @{ |
| 241 | /// Sizes of tensor dimensions for corresponding input dimensions of |
| 242 | /// the schedule space. The size of the tensor dimension can be larger than |
| 243 | /// the size of the corresponding input dimension of the schedule space. |
| 244 | /// This does not correspond to a tensor contraction. However, such a pattern |
| 245 | /// will be optimized by the transformation. |
| 246 | SmallVector<int> DimensionSizes; |
| 247 | SmallVector<int> ADimensions; |
| 248 | SmallVector<int> BDimensions; |
| 249 | SmallVector<int> CDimensions; |
| 250 | /// @} |
| 251 | |
| 252 | /// @{ |
| 253 | /// Permutations of indices of I, J, and P, which describe operands of |
| 254 | /// the tensor contraction and its result. |
| 255 | SmallVector<int> OrderedI; |
| 256 | SmallVector<int> OrderedJ; |
| 257 | SmallVector<int> OrderedP; |
| 258 | /// @} |
| 259 | }; |
| 260 | |
| 261 | /// Create an isl::union_set, which describes the option of the form |
| 262 | /// [isolate[] -> unroll[x]]. |
| 263 | /// |
| 264 | /// @param Ctx An isl::ctx, which is used to create the isl::union_set. |
| 265 | static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) { |
| 266 | isl::space Space = isl::space(Ctx, 0, 0, 1); |
| 267 | isl::map UnrollIsolatedSetOption = isl::map::universe(space: Space); |
| 268 | isl::id DimInId = isl::id::alloc(ctx: Ctx, name: "isolate" , user: nullptr); |
| 269 | isl::id DimOutId = isl::id::alloc(ctx: Ctx, name: "unroll" , user: nullptr); |
| 270 | UnrollIsolatedSetOption = |
| 271 | UnrollIsolatedSetOption.set_tuple_id(type: isl::dim::in, id: DimInId); |
| 272 | UnrollIsolatedSetOption = |
| 273 | UnrollIsolatedSetOption.set_tuple_id(type: isl::dim::out, id: DimOutId); |
| 274 | return UnrollIsolatedSetOption.wrap(); |
| 275 | } |
| 276 | |
| 277 | /// Permute the two dimensions of the isl map. |
| 278 | /// |
| 279 | /// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that |
| 280 | /// have type @p DimType. |
| 281 | /// |
| 282 | /// @param Map The isl map to be modified. |
| 283 | /// @param DimType The type of the dimensions. |
| 284 | /// @param DstPos The first dimension. |
| 285 | /// @param SrcPos The second dimension. |
| 286 | /// @return The modified map. |
| 287 | static isl::map permuteDimensions(isl::map Map, isl::dim DimType, |
| 288 | unsigned DstPos, unsigned SrcPos) { |
| 289 | assert(DstPos < unsignedFromIslSize(Map.dim(DimType)) && |
| 290 | SrcPos < unsignedFromIslSize(Map.dim(DimType))); |
| 291 | if (DstPos == SrcPos) |
| 292 | return Map; |
| 293 | isl::id DimId; |
| 294 | if (Map.has_tuple_id(type: DimType)) |
| 295 | DimId = Map.get_tuple_id(type: DimType); |
| 296 | auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in; |
| 297 | isl::id FreeDimId; |
| 298 | if (Map.has_tuple_id(type: FreeDim)) |
| 299 | FreeDimId = Map.get_tuple_id(type: FreeDim); |
| 300 | auto MaxDim = std::max(a: DstPos, b: SrcPos); |
| 301 | auto MinDim = std::min(a: DstPos, b: SrcPos); |
| 302 | Map = Map.move_dims(dst_type: FreeDim, dst_pos: 0, src_type: DimType, src_pos: MaxDim, n: 1); |
| 303 | Map = Map.move_dims(dst_type: FreeDim, dst_pos: 0, src_type: DimType, src_pos: MinDim, n: 1); |
| 304 | Map = Map.move_dims(dst_type: DimType, dst_pos: MinDim, src_type: FreeDim, src_pos: 1, n: 1); |
| 305 | Map = Map.move_dims(dst_type: DimType, dst_pos: MaxDim, src_type: FreeDim, src_pos: 0, n: 1); |
| 306 | if (!DimId.is_null()) |
| 307 | Map = Map.set_tuple_id(type: DimType, id: DimId); |
| 308 | if (!FreeDimId.is_null()) |
| 309 | Map = Map.set_tuple_id(type: FreeDim, id: FreeDimId); |
| 310 | return Map; |
| 311 | } |
| 312 | |
| 313 | /// Check the form of the access relation. |
| 314 | /// |
| 315 | /// Check that the access relation @p AccMap has the form M[i][j], where i |
| 316 | /// is a @p FirstPos and j is a @p SecondPos. |
| 317 | /// |
| 318 | /// @param AccMap The access relation to be checked. |
| 319 | /// @param FirstPos The index of the input dimension that is mapped to |
| 320 | /// the first output dimension. |
| 321 | /// @param SecondPos The index of the input dimension that is mapped to the |
| 322 | /// second output dimension. |
| 323 | /// @return True in case @p AccMap has the expected form and false, |
| 324 | /// otherwise. |
| 325 | static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos, |
| 326 | int &SecondPos) { |
| 327 | isl::space Space = AccMap.get_space(); |
| 328 | isl::map Universe = isl::map::universe(space: Space); |
| 329 | |
| 330 | if (unsignedFromIslSize(Size: Space.dim(type: isl::dim::out)) != 2) |
| 331 | return false; |
| 332 | |
| 333 | // MatMul has the form: |
| 334 | // for (i = 0; i < N; i++) |
| 335 | // for (j = 0; j < M; j++) |
| 336 | // for (k = 0; k < P; k++) |
| 337 | // C[i, j] += A[i, k] * B[k, j] |
| 338 | // |
| 339 | // Permutation of three outer loops: 3! = 6 possibilities. |
| 340 | int FirstDims[] = {0, 0, 1, 1, 2, 2}; |
| 341 | int SecondDims[] = {1, 2, 2, 0, 0, 1}; |
| 342 | for (int i = 0; i < 6; i += 1) { |
| 343 | auto PossibleMatMul = |
| 344 | Universe.equate(type1: isl::dim::in, pos1: FirstDims[i], type2: isl::dim::out, pos2: 0) |
| 345 | .equate(type1: isl::dim::in, pos1: SecondDims[i], type2: isl::dim::out, pos2: 1); |
| 346 | |
| 347 | AccMap = AccMap.intersect_domain(set: Domain); |
| 348 | PossibleMatMul = PossibleMatMul.intersect_domain(set: Domain); |
| 349 | |
| 350 | // If AccMap spans entire domain (Non-partial write), |
| 351 | // compute FirstPos and SecondPos. |
| 352 | // If AccMap != PossibleMatMul here (the two maps have been gisted at |
| 353 | // this point), it means that the writes are not complete, or in other |
| 354 | // words, it is a Partial write and Partial writes must be rejected. |
| 355 | if (AccMap.is_equal(map2: PossibleMatMul)) { |
| 356 | if (FirstPos != -1 && FirstPos != FirstDims[i]) |
| 357 | continue; |
| 358 | FirstPos = FirstDims[i]; |
| 359 | if (SecondPos != -1 && SecondPos != SecondDims[i]) |
| 360 | continue; |
| 361 | SecondPos = SecondDims[i]; |
| 362 | return true; |
| 363 | } |
| 364 | } |
| 365 | |
| 366 | return false; |
| 367 | } |
| 368 | |
| 369 | /// Does the memory access represent a non-scalar operand of the matrix |
| 370 | /// multiplication. |
| 371 | /// |
| 372 | /// Check that the memory access @p MemAccess is the read access to a non-scalar |
| 373 | /// operand of the matrix multiplication or its result. |
| 374 | /// |
| 375 | /// @param MemAccess The memory access to be checked. |
| 376 | /// @param MMI Parameters of the matrix multiplication operands. |
| 377 | /// @return True in case the memory access represents the read access |
| 378 | /// to a non-scalar operand of the matrix multiplication and |
| 379 | /// false, otherwise. |
| 380 | static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess, |
| 381 | MatMulInfoTy &MMI) { |
| 382 | if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead()) |
| 383 | return false; |
| 384 | auto AccMap = MemAccess->getLatestAccessRelation(); |
| 385 | isl::set StmtDomain = MemAccess->getStatement()->getDomain(); |
| 386 | if (isMatMulOperandAcc(Domain: StmtDomain, AccMap, FirstPos&: MMI.i, SecondPos&: MMI.j) && !MMI.ReadFromC) { |
| 387 | MMI.ReadFromC = MemAccess; |
| 388 | return true; |
| 389 | } |
| 390 | if (isMatMulOperandAcc(Domain: StmtDomain, AccMap, FirstPos&: MMI.i, SecondPos&: MMI.k) && !MMI.A) { |
| 391 | MMI.A = MemAccess; |
| 392 | return true; |
| 393 | } |
| 394 | if (isMatMulOperandAcc(Domain: StmtDomain, AccMap, FirstPos&: MMI.k, SecondPos&: MMI.j) && !MMI.B) { |
| 395 | MMI.B = MemAccess; |
| 396 | return true; |
| 397 | } |
| 398 | return false; |
| 399 | } |
| 400 | |
| 401 | /// Check accesses to operands of the matrix multiplication. |
| 402 | /// |
| 403 | /// Check that accesses of the SCoP statement, which corresponds to |
| 404 | /// the partial schedule @p PartialSchedule, are scalar in terms of loops |
| 405 | /// containing the matrix multiplication, in case they do not represent |
| 406 | /// accesses to the non-scalar operands of the matrix multiplication or |
| 407 | /// its result. |
| 408 | /// |
| 409 | /// @param PartialSchedule The partial schedule of the SCoP statement. |
| 410 | /// @param MMI Parameters of the matrix multiplication operands. |
| 411 | /// @return True in case the corresponding SCoP statement |
| 412 | /// represents matrix multiplication and false, |
| 413 | /// otherwise. |
| 414 | static bool containsOnlyMatrMultAcc(isl::map PartialSchedule, |
| 415 | MatMulInfoTy &MMI) { |
| 416 | auto InputDimId = PartialSchedule.get_tuple_id(type: isl::dim::in); |
| 417 | auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user()); |
| 418 | unsigned OutDimNum = unsignedFromIslSize(Size: PartialSchedule.range_tuple_dim()); |
| 419 | assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest " |
| 420 | "and, consequently, the corresponding scheduling " |
| 421 | "functions have at least three dimensions." ); |
| 422 | auto MapI = |
| 423 | permuteDimensions(Map: PartialSchedule, DimType: isl::dim::out, DstPos: MMI.i, SrcPos: OutDimNum - 1); |
| 424 | auto MapJ = |
| 425 | permuteDimensions(Map: PartialSchedule, DimType: isl::dim::out, DstPos: MMI.j, SrcPos: OutDimNum - 1); |
| 426 | auto MapK = |
| 427 | permuteDimensions(Map: PartialSchedule, DimType: isl::dim::out, DstPos: MMI.k, SrcPos: OutDimNum - 1); |
| 428 | |
| 429 | auto Accesses = getAccessesInOrder(Stmt&: *Stmt); |
| 430 | for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) { |
| 431 | auto *MemAccessPtr = *MemA; |
| 432 | if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC && |
| 433 | !isMatMulNonScalarReadAccess(MemAccess: MemAccessPtr, MMI) && |
| 434 | !(MemAccessPtr->isStrideZero(Schedule: MapI) && |
| 435 | MemAccessPtr->isStrideZero(Schedule: MapJ) && MemAccessPtr->isStrideZero(Schedule: MapK))) |
| 436 | return false; |
| 437 | } |
| 438 | return true; |
| 439 | } |
| 440 | |
| 441 | /// Check for dependencies corresponding to the matrix multiplication. |
| 442 | /// |
| 443 | /// Check that there is only true dependence of the form |
| 444 | /// S(..., k, ...) -> S(..., k + 1, …), where S is the SCoP statement |
| 445 | /// represented by @p Schedule and k is @p Pos. Such a dependence corresponds |
| 446 | /// to the dependency produced by the matrix multiplication. |
| 447 | /// |
| 448 | /// @param Schedule The schedule of the SCoP statement. |
| 449 | /// @param D The SCoP dependencies. |
| 450 | /// @param Pos The parameter to describe an acceptable true dependence. |
| 451 | /// In case it has a negative value, try to determine its |
| 452 | /// acceptable value. |
| 453 | /// @return True in case dependencies correspond to the matrix multiplication |
| 454 | /// and false, otherwise. |
| 455 | static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D, |
| 456 | int &Pos) { |
| 457 | isl::union_map Dep = D->getDependences(Kinds: Dependences::TYPE_RAW); |
| 458 | isl::union_map Red = D->getDependences(Kinds: Dependences::TYPE_RED); |
| 459 | if (!Red.is_null()) |
| 460 | Dep = Dep.unite(umap2: Red); |
| 461 | auto DomainSpace = Schedule.get_space().domain(); |
| 462 | auto Space = DomainSpace.map_from_domain_and_range(range: DomainSpace); |
| 463 | auto Deltas = Dep.extract_map(space: Space).deltas(); |
| 464 | int DeltasDimNum = unsignedFromIslSize(Size: Deltas.dim(type: isl::dim::set)); |
| 465 | for (int i = 0; i < DeltasDimNum; i++) { |
| 466 | auto Val = Deltas.plain_get_val_if_fixed(type: isl::dim::set, pos: i); |
| 467 | Pos = Pos < 0 && Val.is_one() ? i : Pos; |
| 468 | if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one()))) |
| 469 | return false; |
| 470 | } |
| 471 | if (DeltasDimNum == 0 || Pos < 0) |
| 472 | return false; |
| 473 | return true; |
| 474 | } |
| 475 | |
| 476 | /// Check if the SCoP statement could probably be optimized with analytical |
| 477 | /// modeling. |
| 478 | /// |
| 479 | /// containsMatrMult tries to determine whether the following conditions |
| 480 | /// are true: |
| 481 | /// 1. The last memory access modeling an array, MA1, represents writing to |
| 482 | /// memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or |
| 483 | /// S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement |
| 484 | /// under consideration. |
| 485 | /// 2. There is only one loop-carried true dependency, and it has the |
| 486 | /// form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no |
| 487 | /// loop-carried or anti dependencies. |
| 488 | /// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent |
| 489 | /// reading from memory and have the form S(..., i3, ...) -> M(i1, i3), |
| 490 | /// S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively, |
| 491 | /// and all memory accesses of the SCoP that are different from MA1, MA2, |
| 492 | /// MA3, and MA4 have stride 0, if the innermost loop is exchanged with any |
| 493 | /// of loops i1, i2 and i3. |
| 494 | /// |
| 495 | /// @param PartialSchedule The PartialSchedule that contains a SCoP statement |
| 496 | /// to check. |
| 497 | /// @D The SCoP dependencies. |
| 498 | /// @MMI Parameters of the matrix multiplication operands. |
| 499 | static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D, |
| 500 | MatMulInfoTy &MMI) { |
| 501 | auto InputDimsId = PartialSchedule.get_tuple_id(type: isl::dim::in); |
| 502 | auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); |
| 503 | if (Stmt->size() <= 1) |
| 504 | return false; |
| 505 | |
| 506 | auto Accesses = getAccessesInOrder(Stmt&: *Stmt); |
| 507 | for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) { |
| 508 | auto *MemAccessPtr = *MemA; |
| 509 | if (!MemAccessPtr->isLatestArrayKind()) |
| 510 | continue; |
| 511 | if (!MemAccessPtr->isWrite()) |
| 512 | return false; |
| 513 | auto AccMap = MemAccessPtr->getLatestAccessRelation(); |
| 514 | if (!isMatMulOperandAcc(Domain: Stmt->getDomain(), AccMap, FirstPos&: MMI.i, SecondPos&: MMI.j)) |
| 515 | return false; |
| 516 | MMI.WriteToC = MemAccessPtr; |
| 517 | break; |
| 518 | } |
| 519 | |
| 520 | if (!containsOnlyMatMulDep(Schedule: PartialSchedule, D, Pos&: MMI.k)) |
| 521 | return false; |
| 522 | |
| 523 | if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI)) |
| 524 | return false; |
| 525 | |
| 526 | if (!MMI.A || !MMI.B || !MMI.ReadFromC) |
| 527 | return false; |
| 528 | return true; |
| 529 | } |
| 530 | |
| 531 | /// Permute two dimensions of the band node. |
| 532 | /// |
| 533 | /// Permute FirstDim and SecondDim dimensions of the Node. |
| 534 | /// |
| 535 | /// @param Node The band node to be modified. |
| 536 | /// @param FirstDim The first dimension to be permuted. |
| 537 | /// @param SecondDim The second dimension to be permuted. |
| 538 | static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node, |
| 539 | unsigned FirstDim, |
| 540 | unsigned SecondDim) { |
| 541 | assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band && |
| 542 | (unsigned)isl_schedule_node_band_n_member(Node.get()) > |
| 543 | std::max(FirstDim, SecondDim)); |
| 544 | auto PartialSchedule = |
| 545 | isl::manage(ptr: isl_schedule_node_band_get_partial_schedule(node: Node.get())); |
| 546 | auto PartialScheduleFirstDim = PartialSchedule.at(pos: FirstDim); |
| 547 | auto PartialScheduleSecondDim = PartialSchedule.at(pos: SecondDim); |
| 548 | PartialSchedule = |
| 549 | PartialSchedule.set_union_pw_aff(pos: SecondDim, el: PartialScheduleFirstDim); |
| 550 | PartialSchedule = |
| 551 | PartialSchedule.set_union_pw_aff(pos: FirstDim, el: PartialScheduleSecondDim); |
| 552 | Node = isl::manage(ptr: isl_schedule_node_delete(node: Node.release())); |
| 553 | return Node.insert_partial_schedule(schedule: PartialSchedule); |
| 554 | } |
| 555 | |
| 556 | static isl::schedule_node |
| 557 | createMicroKernel(isl::schedule_node Node, |
| 558 | MicroKernelParamsTy MicroKernelParams) { |
| 559 | Node = applyRegisterTiling(Node, TileSizes: {MicroKernelParams.Mr, MicroKernelParams.Nr}, |
| 560 | DefaultTileSize: 1); |
| 561 | Node = Node.parent().parent(); |
| 562 | return permuteBandNodeDimensions(Node, FirstDim: 0, SecondDim: 1).child(pos: 0).child(pos: 0); |
| 563 | } |
| 564 | |
| 565 | /// Create the BLIS macro-kernel. |
| 566 | /// |
| 567 | /// We create the BLIS macro-kernel by applying a combination of tiling |
| 568 | /// of dimensions of the band node and interchanging of two innermost |
| 569 | /// modified dimensions. The values of MacroKernelParams's fields are used |
| 570 | /// as tile sizes. |
| 571 | /// |
| 572 | /// @param Node The schedule node to be modified. |
| 573 | /// @param MacroKernelParams Parameters of the macro kernel |
| 574 | /// to be used as tile sizes. |
| 575 | static isl::schedule_node |
| 576 | createMacroKernel(isl::schedule_node Node, |
| 577 | MacroKernelParamsTy MacroKernelParams) { |
| 578 | assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band); |
| 579 | if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 && |
| 580 | MacroKernelParams.Kc == 1) |
| 581 | return Node; |
| 582 | int DimOutNum = isl_schedule_node_band_n_member(node: Node.get()); |
| 583 | std::vector<int> TileSizes(DimOutNum, 1); |
| 584 | TileSizes[DimOutNum - 3] = MacroKernelParams.Mc; |
| 585 | TileSizes[DimOutNum - 2] = MacroKernelParams.Nc; |
| 586 | TileSizes[DimOutNum - 1] = MacroKernelParams.Kc; |
| 587 | Node = tileNode(Node, Identifier: "1st level tiling" , TileSizes, DefaultTileSize: 1); |
| 588 | Node = Node.parent().parent(); |
| 589 | Node = permuteBandNodeDimensions(Node, FirstDim: DimOutNum - 2, SecondDim: DimOutNum - 1); |
| 590 | Node = permuteBandNodeDimensions(Node, FirstDim: DimOutNum - 3, SecondDim: DimOutNum - 1); |
| 591 | |
| 592 | return Node.child(pos: 0).child(pos: 0); |
| 593 | } |
| 594 | |
| 595 | /// Get the size of the widest type of the matrix multiplication operands |
| 596 | /// in bytes, including alignment padding. |
| 597 | /// |
| 598 | /// @param MMI Parameters of the matrix multiplication operands. |
| 599 | /// @return The size of the widest type of the matrix multiplication operands |
| 600 | /// in bytes, including alignment padding. |
| 601 | static uint64_t getMatMulAlignTypeSize(const MatMulInfoTy &MMI) { |
| 602 | auto *S = MMI.A->getStatement()->getParent(); |
| 603 | auto &DL = S->getFunction().getParent()->getDataLayout(); |
| 604 | auto ElementSizeA = DL.getTypeAllocSize(Ty: MMI.A->getElementType()); |
| 605 | auto ElementSizeB = DL.getTypeAllocSize(Ty: MMI.B->getElementType()); |
| 606 | auto ElementSizeC = DL.getTypeAllocSize(Ty: MMI.WriteToC->getElementType()); |
| 607 | return std::max(l: {ElementSizeA, ElementSizeB, ElementSizeC}); |
| 608 | } |
| 609 | |
| 610 | /// Get the size of the widest type of the matrix multiplication operands |
| 611 | /// in bits. |
| 612 | /// |
| 613 | /// @param MMI Parameters of the matrix multiplication operands. |
| 614 | /// @return The size of the widest type of the matrix multiplication operands |
| 615 | /// in bits. |
| 616 | static uint64_t getMatMulTypeSize(const MatMulInfoTy &MMI) { |
| 617 | auto *S = MMI.A->getStatement()->getParent(); |
| 618 | auto &DL = S->getFunction().getParent()->getDataLayout(); |
| 619 | auto ElementSizeA = DL.getTypeSizeInBits(Ty: MMI.A->getElementType()); |
| 620 | auto ElementSizeB = DL.getTypeSizeInBits(Ty: MMI.B->getElementType()); |
| 621 | auto ElementSizeC = DL.getTypeSizeInBits(Ty: MMI.WriteToC->getElementType()); |
| 622 | return std::max(l: {ElementSizeA, ElementSizeB, ElementSizeC}); |
| 623 | } |
| 624 | |
/// Get parameters of the BLIS micro kernel.
///
/// We choose the Mr and Nr parameters of the micro kernel to be large enough
/// such that no stalls caused by the combination of latencies and dependencies
/// are introduced during the updates of the resulting matrix of the matrix
/// multiplication. However, they should also be as small as possible to
/// release more registers for entries of multiplied matrices.
///
/// @param TTI Target Transform Info.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MicroKernelParamsTy.
/// @see MicroKernelParamsTy
static MicroKernelParamsTy getMicroKernelParams(const TargetTransformInfo *TTI,
                                                const MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided." );

  // Nvec - Number of double-precision floating-point numbers that can be hold
  // by a vector register. Use 2 by default.
  long RegisterBitwidth = VectorRegisterBitwidth;

  // A value of -1 means the register width was not overridden on the command
  // line; query the target instead.
  if (RegisterBitwidth == -1)
    RegisterBitwidth =
        TTI->getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector);
  auto ElementSize = getMatMulTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero." );
  // Number of matrix elements that fit into one vector register; fall back to
  // 2 when the reported register is narrower than a single element.
  auto Nvec = RegisterBitwidth / ElementSize;
  if (Nvec == 0)
    Nvec = 2;
  // Nr is the smallest multiple of Nvec that is at least
  // sqrt(Nvec * LatencyVectorFma * ThroughputVectorFma).
  int Nr = ceil(x: sqrt(x: (double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) /
                Nvec) *
           Nvec;
  // Mr completes the Mr x Nr register block.
  // NOTE(review): the division by Nr happens in integer arithmetic before the
  // cast to double, so ceil() here is a no-op; confirm whether a
  // floating-point division (and a real ceil) was intended.
  int Mr = ceil(x: (double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr));
  return {.Mr: Mr, .Nr: Nr};
}
| 660 | |
| 661 | /// Determine parameters of the target cache. |
| 662 | /// |
| 663 | /// @param TTI Target Transform Info. |
| 664 | static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) { |
| 665 | auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D; |
| 666 | auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D; |
| 667 | if (FirstCacheLevelSize == -1) { |
| 668 | if (TTI->getCacheSize(Level: L1DCache)) |
| 669 | FirstCacheLevelSize = TTI->getCacheSize(Level: L1DCache).value(); |
| 670 | else |
| 671 | FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize); |
| 672 | } |
| 673 | if (SecondCacheLevelSize == -1) { |
| 674 | if (TTI->getCacheSize(Level: L2DCache)) |
| 675 | SecondCacheLevelSize = TTI->getCacheSize(Level: L2DCache).value(); |
| 676 | else |
| 677 | SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize); |
| 678 | } |
| 679 | if (FirstCacheLevelAssociativity == -1) { |
| 680 | if (TTI->getCacheAssociativity(Level: L1DCache)) |
| 681 | FirstCacheLevelAssociativity = |
| 682 | TTI->getCacheAssociativity(Level: L1DCache).value(); |
| 683 | else |
| 684 | FirstCacheLevelAssociativity = |
| 685 | static_cast<int>(FirstCacheLevelDefaultAssociativity); |
| 686 | } |
| 687 | if (SecondCacheLevelAssociativity == -1) { |
| 688 | if (TTI->getCacheAssociativity(Level: L2DCache)) |
| 689 | SecondCacheLevelAssociativity = |
| 690 | TTI->getCacheAssociativity(Level: L2DCache).value(); |
| 691 | else |
| 692 | SecondCacheLevelAssociativity = |
| 693 | static_cast<int>(SecondCacheLevelDefaultAssociativity); |
| 694 | } |
| 695 | } |
| 696 | |
/// Get parameters of the BLIS macro kernel.
///
/// During the computation of matrix multiplication, blocks of partitioned
/// matrices are mapped to different layers of the memory hierarchy.
/// To optimize data reuse, blocks should be ideally kept in cache between
/// iterations. Since parameters of the macro kernel determine sizes of these
/// blocks, there are upper and lower bounds on these parameters.
///
/// @param TTI Target Transform Info.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MacroKernelParamsTy.
/// @see MacroKernelParamsTy
/// @see MicroKernelParamsTy
static MacroKernelParamsTy
getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
                     const MicroKernelParamsTy &MicroKernelParams,
                     const MatMulInfoTy &MMI) {
  // Fill in any cache parameters that were not set on the command line.
  getTargetCacheParameters(TTI);
  // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
  // it requires information about the first two levels of a cache to determine
  // all the parameters of a macro-kernel. It also checks that an associativity
  // degree of a cache level is greater than two. Otherwise, another algorithm
  // for determination of the parameters should be used.
  if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
        FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 &&
        FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2))
    return {.Mc: 1, .Nc: 1, .Kc: 1};
  // The quotient should be greater than zero.
  if (PollyPatternMatchingNcQuotient <= 0)
    return {.Mc: 1, .Nc: 1, .Kc: 1};
  // Car approximates the share of the L1 associativity that is devoted to the
  // tile of A, following the analytical model of the paper referenced above.
  int Car = floor(
      x: (FirstCacheLevelAssociativity - 1) /
      (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));

  // Car can be computed to be zero since it is floor to int.
  // On Mac OS, division by 0 does not raise a signal. This causes negative
  // tile sizes to be computed. Prevent division by Cac==0 by early returning
  // if this happens.
  if (Car == 0)
    return {.Mc: 1, .Nc: 1, .Kc: 1};

  auto ElementSize = getMatMulAlignTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero." );
  // Kc bounds the panel of A to the part of the L1 cache reserved for it.
  int Kc = (Car * FirstCacheLevelSize) /
           (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
  // Cac is the fraction of the L2 cache occupied by a Kc-wide panel; Mc is
  // chosen so the tile fits into the remaining associativity ways.
  double Cac =
      static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
      SecondCacheLevelSize;
  int Mc = floor(x: (SecondCacheLevelAssociativity - 2) / Cac);
  // Nc is not derived from the cache model; it is a user-controlled multiple
  // of Nr.
  int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;

  assert(Mc > 0 && Nc > 0 && Kc > 0 &&
         "Matrix block sizes should be greater than zero" );
  return {.Mc: Mc, .Nc: Nc, .Kc: Kc};
}
| 755 | |
| 756 | /// Create an access relation that is specific to |
| 757 | /// the matrix multiplication pattern. |
| 758 | /// |
| 759 | /// Create an access relation of the following form: |
| 760 | /// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ] |
| 761 | /// where I is @p FirstDim, J is @p SecondDim. |
| 762 | /// |
| 763 | /// It can be used, for example, to create relations that helps to consequently |
| 764 | /// access elements of operands of a matrix multiplication after creation of |
| 765 | /// the BLIS micro and macro kernels. |
| 766 | /// |
| 767 | /// @see ScheduleTreeOptimizer::createMicroKernel |
| 768 | /// @see ScheduleTreeOptimizer::createMacroKernel |
| 769 | /// |
| 770 | /// Subsequently, the described access relation is applied to the range of |
| 771 | /// @p MapOldIndVar, that is used to map original induction variables to |
| 772 | /// the ones, which are produced by schedule transformations. It helps to |
| 773 | /// define relations using a new space and, at the same time, keep them |
| 774 | /// in the original one. |
| 775 | /// |
| 776 | /// @param MapOldIndVar The relation, which maps original induction variables |
| 777 | /// to the ones, which are produced by schedule |
| 778 | /// transformations. |
| 779 | /// @param FirstDim, SecondDim The input dimensions that are used to define |
| 780 | /// the specified access relation. |
| 781 | /// @return The specified access relation. |
| 782 | static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim, |
| 783 | unsigned SecondDim) { |
| 784 | auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3); |
| 785 | auto AccessRel = isl::map::universe(space: AccessRelSpace); |
| 786 | AccessRel = AccessRel.equate(type1: isl::dim::in, pos1: FirstDim, type2: isl::dim::out, pos2: 0); |
| 787 | AccessRel = AccessRel.equate(type1: isl::dim::in, pos1: 5, type2: isl::dim::out, pos2: 1); |
| 788 | AccessRel = AccessRel.equate(type1: isl::dim::in, pos1: SecondDim, type2: isl::dim::out, pos2: 2); |
| 789 | return MapOldIndVar.apply_range(map2: AccessRel); |
| 790 | } |
| 791 | |
| 792 | static isl::schedule_node createExtensionNode(isl::schedule_node Node, |
| 793 | isl::map ExtensionMap) { |
| 794 | auto Extension = isl::union_map(ExtensionMap); |
| 795 | auto NewNode = isl::schedule_node::from_extension(extension: Extension); |
| 796 | return Node.graft_before(graft: NewNode); |
| 797 | } |
| 798 | |
/// Apply the packing transformation to the array B.
///
/// Create the packed array Packed_B with dimensions (Nc / Nr) x Kc x Nr, add
/// a statement that copies the values of B into it, redirect the accesses of
/// @p Stmt from B to Packed_B, and graft an extension node that executes the
/// copy statement into the schedule tree before @p Node.
///
/// @param Node         The schedule node to graft the copy-in before.
/// @param Stmt         The statement that accesses B.
/// @param MapOldIndVar The relation mapping original induction variables to
///                     the ones produced by schedule transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernels.
/// @param MMI          Parameters of the matrix multiplication operands.
/// @return The modified schedule node.
static isl::schedule_node optimizePackedB(isl::schedule_node Node,
                                          ScopStmt *Stmt, isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  Scop *S = Stmt->getParent();
  isl::set Domain = Stmt->getDomain();

  // Create packed array.
  unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Nr;
  ScopArrayInfo *PackedB =
      S->createScopArrayInfo(ElementType: MMI.B->getElementType(), BaseName: "Packed_B" ,
                             Sizes: {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from B to PackedB.
  // Schedule dimensions 3, 5, and 7 index the packed layout (see
  // getMatMulAccRel for the shape of the created relation).
  isl::map AccRelB = MMI.B->getLatestAccessRelation();
  isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, FirstDim: 3, SecondDim: 7);
  AccRelPackedB =
      AccRelPackedB.set_tuple_id(type: isl::dim::out, id: PackedB->getBasePtrId());

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt = S->addScopStmt(SourceRel: AccRelB, TargetRel: AccRelPackedB, Domain);
  MMI.B->setNewAccessRelation(AccRelPackedB);

  unsigned Dim = unsignedFromIslSize(Size: MapOldIndVar.range_tuple_dim());
  assert(Dim >= 2);
  // Insert into the schedule tree.
  // Keep only the two outermost schedule dimensions and invert the mapping so
  // the extension maps the outer schedule to copy-statement instances.
  isl::map ExtMap = MapOldIndVar.project_out(type: isl::dim::out, first: 2, n: Dim - 2);
  ExtMap = ExtMap.reverse();
  // Fix the i dimension to 0 so B is copied only once per tile rather than on
  // every iteration of i.
  ExtMap = ExtMap.fix_si(type: isl::dim::out, pos: MMI.i, value: 0);
  ExtMap = ExtMap.intersect_range(set: Domain);
  ExtMap = ExtMap.set_tuple_id(type: isl::dim::out, id: CopyStmt->getDomainId());
  return createExtensionNode(Node, ExtensionMap: ExtMap);
}
| 835 | |
/// Apply the packing transformation to the array A.
///
/// Create the packed array Packed_A with dimensions (Mc / Mr) x Kc x Mr, add
/// a statement copying the values of A into it, redirect the accesses of the
/// matrix multiplication statement from A to Packed_A, and graft an extension
/// node executing the copy statement into the schedule tree before @p Node.
///
/// @param Node         The schedule node to graft the copy-in before.
/// @param MapOldIndVar The relation mapping original induction variables to
///                     the ones produced by schedule transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernels.
/// @param MMI          Parameters of the matrix multiplication operands.
/// @return The modified schedule node.
static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *,
                                          isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  // Recover the statement from the input tuple id of the map; the unnamed
  // ScopStmt parameter is unused.
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(type: isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  isl::set Domain = Stmt->getDomain();
  // NOTE(review): DomainId is never read below; candidate for removal.
  isl::id DomainId = Domain.get_tuple_id();

  // Create the packed array.
  unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Mr;
  ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo(
      ElementType: MMI.A->getElementType(), BaseName: "Packed_A" ,
      Sizes: {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from A to PackedA.
  // Schedule dimensions 4, 5, and 6 index the packed layout (see
  // getMatMulAccRel).
  isl::map AccRelA = MMI.A->getLatestAccessRelation();
  isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, FirstDim: 4, SecondDim: 6);
  AccRelPackedA =
      AccRelPackedA.set_tuple_id(type: isl::dim::out, id: PackedA->getBasePtrId());
  // { MemrefA[] -> PackedA[] }
  isl::map PackedATranslator = AccRelPackedA.apply_domain(map2: AccRelA);

  // Compute the domain for the copy statement.
  // Construct the copy statement domain out of the 3 outermost scatter
  // dimensions (to match the 3 band nodes surrounding the extension node) and
  // the array elements to copy (one statement instance per array element).
  // { Scatter[] }
  isl::set ScatterDomain = MapOldIndVar.intersect_domain(set: Domain).range();
  // { Scatter[] -> OutermostScatter[] }
  isl::map OuterDomainMap =
      makeIdentityMap(Set: ScatterDomain, RestrictDomain: true).project_out(type: isl::dim::out, first: 3, n: 6);
  // { Scatter[] -> MemrefA[] }
  isl::map CopyFrom = MapOldIndVar.reverse().apply_range(map2: AccRelA);
  // { Scatter[] -> CopyStmt[] }
  isl::map DomainTranslator = OuterDomainMap.range_product(map2: CopyFrom);
  // { CopyStmt[] }
  isl::set CopyDomain = DomainTranslator.range();

  // Translate the access relations to the new domain.
  // { CopyStmt[] -> MemrefA[] }
  CopyFrom = CopyFrom.apply_domain(map2: DomainTranslator);
  // { CopyStmt[] -> PackedA[] }
  isl::map CopyTo = CopyFrom.apply_range(map2: PackedATranslator);

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt =
      Stmt->getParent()->addScopStmt(SourceRel: CopyFrom, TargetRel: CopyTo, Domain: CopyDomain);
  MMI.A->setNewAccessRelation(AccRelPackedA);

  // Insert into the schedule tree.
  // { Scatter[] -> CopyStmt[] }
  isl::map ExtScatterCopy = makeIdentityMap(Set: CopyStmt->getDomain(), RestrictDomain: true);
  ExtScatterCopy = ExtScatterCopy.project_out(type: isl::dim::in, first: 3, n: 2);
  return createExtensionNode(Node, ExtensionMap: ExtScatterCopy);
}
| 895 | |
/// Apply the packing transformation.
///
/// The packing transformation can be described as a data-layout
/// transformation that requires to introduce a new array, copy data
/// to the array, and change memory access locations to reference the array.
/// It can be used to ensure that elements of the new array are read in-stride
/// access, aligned to cache lines boundaries, and preloaded into certain cache
/// levels.
///
/// As an example let us consider the packing of the array A that would help
/// to read its elements with in-stride access. An access to the array A
/// is represented by an access relation that has the form
/// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has
/// the form S[i,j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
/// k mod Kc, j mod Nr, i mod Mr].
///
/// To ensure that elements of the array A are read in-stride access, we add
/// a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
/// Scop::createScopArrayInfo, change the access relation
/// S[i, j, k] -> A[i, k] to
/// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
/// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
/// the copy statement created by Scop::addScopStmt.
///
/// @param Node The schedule node to be optimized.
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
///                                 to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The optimized schedule node.
static isl::schedule_node
optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
                                 MicroKernelParamsTy MicroParams,
                                 MacroKernelParamsTy MacroParams,
                                 MatMulInfoTy &MMI) {
  // Recover the statement from the input tuple id of the map.
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(type: isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());

  // Walk six levels up to the band where the copy-in statements belong.
  // NOTE(review): this relies on the exact tree shape produced by the
  // preceding createMacroKernel/createMicroKernel steps.
  Node = Node.parent().parent().parent().parent().parent().parent();
  // Split off the two outermost band dimensions so the packing copies can be
  // grafted between them and the remaining dimensions.
  Node = isl::manage(ptr: isl_schedule_node_band_split(node: Node.release(), pos: 2));

  Node = Node.child(pos: 0);
  Node =
      optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  Node = Node.child(pos: 0);
  Node =
      optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  // Descend back to the position the caller expects.
  return Node.child(pos: 0).child(pos: 0).child(pos: 0).child(pos: 0).child(pos: 0);
}
| 949 | |
| 950 | /// Get a relation mapping induction variables produced by schedule |
| 951 | /// transformations to the original ones. |
| 952 | /// |
| 953 | /// @param Node The schedule node produced as the result of creation |
| 954 | /// of the BLIS kernels. |
| 955 | /// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel |
| 956 | /// to be taken into account. |
| 957 | /// @return The relation mapping original induction variables to the ones |
| 958 | /// produced by schedule transformation. |
| 959 | /// @see ScheduleTreeOptimizer::createMicroKernel |
| 960 | /// @see ScheduleTreeOptimizer::createMacroKernel |
| 961 | /// @see getMacroKernelParams |
| 962 | static isl::map |
| 963 | getInductionVariablesSubstitution(isl::schedule_node Node, |
| 964 | MicroKernelParamsTy MicroKernelParams, |
| 965 | MacroKernelParamsTy MacroKernelParams) { |
| 966 | auto Child = Node.child(pos: 0); |
| 967 | auto UnMapOldIndVar = Child.get_prefix_schedule_union_map(); |
| 968 | auto MapOldIndVar = isl::map::from_union_map(umap: UnMapOldIndVar); |
| 969 | unsigned Dim = unsignedFromIslSize(Size: MapOldIndVar.range_tuple_dim()); |
| 970 | if (Dim > 9u) |
| 971 | return MapOldIndVar.project_out(type: isl::dim::out, first: 0, n: Dim - 9); |
| 972 | return MapOldIndVar; |
| 973 | } |
| 974 | |
/// Isolate a set of partial tile prefixes and unroll the isolated part.
///
/// The set should ensure that it contains only partial tile prefixes that have
/// exactly Mr x Nr iterations of the two innermost loops produced by
/// the optimization of the matrix multiplication. Mr and Nr are parameters of
/// the micro-kernel.
///
/// In case of parametric bounds, this helps to auto-vectorize the unrolled
/// innermost loops, using the SLP vectorizer.
///
/// @param Node The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @return The modified isl_schedule_node.
static isl::schedule_node
isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
                                 MicroKernelParamsTy MicroKernelParams) {
  isl::schedule_node Child = Node.child(pos: 0);
  isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
  isl::set Prefix = isl::map::from_union_map(umap: UnMapOldIndVar).range();
  unsigned Dims = unsignedFromIslSize(Size: Prefix.tuple_dim());
  assert(Dims >= 1);
  // Drop the innermost dimension and keep only the prefixes that contain full
  // Nr and Mr iterations of the two innermost loops.
  Prefix = Prefix.project_out(type: isl::dim::set, first: Dims - 1, n: 1);
  Prefix = getPartialTilePrefixes(ScheduleRange: Prefix, VectorWidth: MicroKernelParams.Nr);
  Prefix = getPartialTilePrefixes(ScheduleRange: Prefix, VectorWidth: MicroKernelParams.Mr);

  // Isolate the computed prefixes on the current band and request unrolling
  // of the isolated part.
  isl::union_set IsolateOption =
      getIsolateOptions(IsolateDomain: Prefix.add_dims(type: isl::dim::set, n: 3), OutDimsNum: 3);
  isl::ctx Ctx = Node.ctx();
  auto Options = IsolateOption.unite(uset2: getDimOptions(Ctx, Option: "unroll" ));
  Options = Options.unite(uset2: getUnrollIsolatedSetOptions(Ctx));
  Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
  // On the band three levels up, request "separate" code generation so the
  // isolated full tiles are generated apart from the partial ones.
  Node = Node.parent().parent().parent();
  IsolateOption = getIsolateOptions(IsolateDomain: Prefix, OutDimsNum: 3);
  Options = IsolateOption.unite(uset2: getDimOptions(Ctx, Option: "separate" ));
  Node = Node.as<isl::schedule_node_band>().set_ast_build_options(Options);
  // Return to the original position in the tree.
  Node = Node.child(pos: 0).child(pos: 0).child(pos: 0);
  return Node;
}
| 1014 | |
| 1015 | /// Insert "Loop Vectorizer Disabled" mark node. |
| 1016 | /// |
| 1017 | /// @param Node The child of the mark node to be inserted. |
| 1018 | /// @return The modified isl_schedule_node. |
| 1019 | static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) { |
| 1020 | auto Id = isl::id::alloc(ctx: Node.ctx(), name: "Loop Vectorizer Disabled" , user: nullptr); |
| 1021 | return Node.insert_mark(mark: Id).child(pos: 0); |
| 1022 | } |
| 1023 | |
/// Restore the initial ordering of dimensions of the band node
///
/// In case the band node represents all the dimensions of the iteration
/// domain, recreate the band node to restore the initial ordering of the
/// dimensions.
///
/// @param Node The band node to be modified.
/// @return The modified schedule node.
static isl::schedule_node
getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  // Only handle the innermost band, i.e. a band whose only child is a leaf.
  if (isl_schedule_node_get_type(node: Node.child(pos: 0).get()) != isl_schedule_node_leaf)
    return Node;
  auto Domain = Node.get_universe_domain();
  assert(isl_union_set_n_set(Domain.get()) == 1);
  // Bail out unless the band is at the root of the schedule and covers every
  // dimension of the (single-statement) iteration domain.
  if (Node.get_schedule_depth().release() != 0 ||
      (unsignedFromIslSize(Size: isl::set(Domain).tuple_dim()) !=
       unsignedFromIslSize(Size: Node.as<isl::schedule_node_band>().n_member())))
    return Node;
  // Delete the band and insert one whose partial schedule is the identity
  // over the domain, restoring the original dimension order.
  Node = isl::manage(ptr: isl_schedule_node_delete(node: Node.copy()));
  auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
  auto PartialScheduleMultiPwAff =
      isl::multi_union_pw_aff(PartialSchedulePwAff);
  // Drop the tuple id so the schedule is expressed in an anonymous space.
  PartialScheduleMultiPwAff =
      PartialScheduleMultiPwAff.reset_tuple_id(type: isl::dim::set);
  return Node.insert_partial_schedule(schedule: PartialScheduleMultiPwAff);
}
| 1051 | |
/// Optimize the matrix multiplication pattern rooted at @p Node.
///
/// Permute the dimensions of the band node such that the i, j, and k loops of
/// the matrix multiplication become the three innermost dimensions, create
/// the BLIS macro and micro kernels, and, unless the macro-kernel parameters
/// degenerate to one, isolate and unroll the innermost loops and apply the
/// data-layout (packing) transformation.
///
/// @param Node The schedule node to be optimized.
/// @param TTI  Target Transform Info.
/// @param MMI  Parameters of the matrix multiplication operands.
/// @return The optimized schedule node.
static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node,
                                                const TargetTransformInfo *TTI,
                                                MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided." );
  int DimOutNum = isl_schedule_node_band_n_member(node: Node.get());
  assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions." );
  Node = getBandNodeWithOriginDimOrder(Node);
  // Move i into position DimOutNum - 3. The swap may have displaced j or k;
  // track their new positions before the following permutations.
  Node = permuteBandNodeDimensions(Node, FirstDim: MMI.i, SecondDim: DimOutNum - 3);
  int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
  int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
  // Move j into position DimOutNum - 2 and k into position DimOutNum - 1.
  Node = permuteBandNodeDimensions(Node, FirstDim: NewJ, SecondDim: DimOutNum - 2);
  NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
  Node = permuteBandNodeDimensions(Node, FirstDim: NewK, SecondDim: DimOutNum - 1);
  auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
  auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
  Node = createMacroKernel(Node, MacroKernelParams);
  Node = createMicroKernel(Node, MicroKernelParams);
  // Degenerate macro-kernel parameters mean the cache model could not be
  // applied; skip the isolation and packing transformations.
  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
      MacroKernelParams.Kc == 1)
    return Node;
  auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
                                                        MacroKernelParams);
  if (MapOldIndVar.is_null())
    return Node;
  Node = markLoopVectorizerDisabled(Node: Node.parent()).child(pos: 0);
  Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroParams: MicroKernelParams,
                                          MacroParams: MacroKernelParams, MMI);
}
| 1083 | |
| 1084 | /// Check if this node contains a partial schedule that could |
| 1085 | /// probably be optimized with analytical modeling. |
| 1086 | /// |
| 1087 | /// isMatrMultPattern tries to determine whether the following conditions |
| 1088 | /// are true: |
| 1089 | /// 1. the partial schedule contains only one statement. |
| 1090 | /// 2. there are exactly three input dimensions. |
| 1091 | /// 3. all memory accesses of the statement will have stride 0 or 1, if we |
| 1092 | /// interchange loops (switch the variable used in the inner loop to |
| 1093 | /// the outer loop). |
| 1094 | /// 4. all memory accesses of the statement except from the last one, are |
| 1095 | /// read memory access and the last one is write memory access. |
| 1096 | /// 5. all subscripts of the last memory access of the statement don't |
| 1097 | /// contain the variable used in the inner loop. |
| 1098 | /// If this is the case, we could try to use an approach that is similar to |
| 1099 | /// the one used to get close-to-peak performance of matrix multiplications. |
| 1100 | /// |
| 1101 | /// @param Node The node to check. |
| 1102 | /// @param D The SCoP dependencies. |
| 1103 | /// @param MMI Parameters of the matrix multiplication operands. |
| 1104 | static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D, |
| 1105 | MatMulInfoTy &MMI) { |
| 1106 | auto PartialSchedule = isl::manage( |
| 1107 | ptr: isl_schedule_node_band_get_partial_schedule_union_map(node: Node.get())); |
| 1108 | if (isl_schedule_node_band_n_member(node: Node.get()) < 3 || |
| 1109 | Node.get_schedule_depth().release() != 0 || |
| 1110 | isl_union_map_n_map(umap: PartialSchedule.get()) != 1) |
| 1111 | return false; |
| 1112 | auto NewPartialSchedule = isl::map::from_union_map(umap: PartialSchedule); |
| 1113 | if (containsMatrMult(PartialSchedule: NewPartialSchedule, D, MMI)) |
| 1114 | return true; |
| 1115 | return false; |
| 1116 | } |
| 1117 | |
| 1118 | /// Get the dimension size. |
| 1119 | /// |
| 1120 | /// Return the size of the dimension @p Pos, which is obtained from @p SAI. |
| 1121 | /// Return -1 in the case of the first dimension of a multi-dimensional array, |
| 1122 | /// since the ScopArrayInfo class does not carry size information. |
| 1123 | /// |
| 1124 | /// @param SAI The information about the array. |
| 1125 | /// @param Pos The position of the dimension. |
| 1126 | /// @return The size of the dimension. |
| 1127 | static int getDimSize(const ScopArrayInfo *SAI, unsigned Pos) { |
| 1128 | if (Pos == 0) |
| 1129 | return -1; |
| 1130 | const llvm::SCEV *SCEVDimSize = SAI->getDimensionSize(Dim: Pos); |
| 1131 | assert(SCEVDimSize); |
| 1132 | auto *ConstantDimSize = dyn_cast<const SCEVConstant>(Val: SCEVDimSize); |
| 1133 | assert(ConstantDimSize); |
| 1134 | auto *IntDimSize = dyn_cast<ConstantInt>(Val: ConstantDimSize->getValue()); |
| 1135 | assert(IntDimSize); |
| 1136 | return IntDimSize->getSExtValue(); |
| 1137 | } |
| 1138 | |
| 1139 | /// Check whether the access relation has the specified form. |
| 1140 | /// |
| 1141 | /// Check that the access relation @p AccMap has the form T[I0, …, In], where |
| 1142 | /// indexes I0, …, In are specified by @p Dimensions. |
| 1143 | /// |
| 1144 | /// @param Domain The domain of the access relation. |
| 1145 | /// @param AccMap The access relation to be checked. |
| 1146 | /// @param Dimensions The permutation of the subset of the input dimensions. |
| 1147 | /// @return True if @p AccMap has the expected form and false, |
| 1148 | /// otherwise. |
| 1149 | static bool isCorrectAccessMap(isl::set Domain, isl::map AccMap, |
| 1150 | ArrayRef<int> Dimensions) { |
| 1151 | isl::space Space = AccMap.get_space(); |
| 1152 | if (unsignedFromIslSize(Size: Space.dim(type: isl::dim::out)) != Dimensions.size()) |
| 1153 | return false; |
| 1154 | |
| 1155 | // Create an access relation of the following form: |
| 1156 | // [I0, …, Im] -> [Il, …, In], where indexes |
| 1157 | // Il, …, In are specified by @p Dimensions. |
| 1158 | isl::map PossibleTensor = isl::map::universe(space: Space); |
| 1159 | unsigned DimInSize = unsignedFromIslSize(Size: Space.dim(type: isl::dim::in)); |
| 1160 | for (unsigned i = 0; i < Dimensions.size(); i++) { |
| 1161 | const int InPos = Dimensions[i]; |
| 1162 | if ((InPos >= static_cast<int>(DimInSize)) || (InPos < 0)) |
| 1163 | return false; |
| 1164 | PossibleTensor = |
| 1165 | PossibleTensor.equate(type1: isl::dim::in, pos1: InPos, type2: isl::dim::out, pos2: i); |
| 1166 | } |
| 1167 | |
| 1168 | AccMap = AccMap.intersect_domain(set: Domain); |
| 1169 | PossibleTensor = PossibleTensor.intersect_domain(set: Domain); |
| 1170 | |
| 1171 | // If AccMap != PossibleTensor here (the two maps have been gisted at |
| 1172 | // this point), it means that the writes are not complete, or in other |
| 1173 | // words, it is a Partial write and Partial writes must be rejected. |
| 1174 | return AccMap.is_equal(map2: PossibleTensor); |
| 1175 | } |
| 1176 | |
/// Check whether the access represents the tensor contraction operand.
///
/// Check that the access relation @p AccMap has the form T[i1, …, in].
/// Obtained indexes i1, …, in, their sizes and their permutation are stored
/// into @p IndexSet, @p DimensionSizes, and @p Dimensions, respectively.
///
/// @param Domain The domain of the access relation.
/// @param AccMap The access relation to be checked.
/// @param IndexSet The subset of the input dimensions.
/// @param DimensionSizes Sizes of the input dimensions of @p Dimensions.
/// @param Dimensions The permutation of the subset of the input dimensions.
/// @return True if @p AccMap has the expected form and false,
///         otherwise.
static bool isTCOperandAcc(isl::set Domain, isl::map AccMap,
                           SmallDenseSet<int> &IndexSet,
                           SmallVectorImpl<int> &DimensionSizes,
                           SmallVectorImpl<int> &Dimensions) {
  isl::id Id = AccMap.get_tuple_id(type: isl::dim::out);
  const ScopArrayInfo *SAI = ScopArrayInfo::getFromId(Id);
  assert(SAI && "AccMap should represent memory access" );

  // Fix values of output dimensions with respect to their positions.
  // In the case of the tensor contraction, values of output dimensions are
  // fixed and form a permutation of a subset of values of input dimensions.
  //
  // For example, in the case of Stmt[i][j][k] -> A[k][i], which represents
  // the operand of the tensor contraction, we get the following map by fixing
  // the output dimensions Stmt[1][j][0] -> A[0][1].
  //
  // We store the permutation of the subset of the input dimensions {2, 0} into
  // @p Dimensions.
  //
  // The obtained permutation and the isCorrectAccessMap function are used to
  // check whether the access relation @p AccMap represents the tensor
  // contraction operand. For example, in the case of
  // Stmt[i][j][k] -> A[i-1][j+1], we get Stmt[1][0][k] -> A[0][1] and,
  // consequently, {1, 0}, which is rejected by isCorrectAccessMap,
  // since it corresponds to Stmt[i][j][k] -> A[j][i].
  isl::map CheckMap = isl::manage(ptr: AccMap.copy());
  unsigned OutDimNum = unsignedFromIslSize(Size: CheckMap.dim(type: isl::dim::out));
  for (unsigned i = 0; i < OutDimNum; i++)
    CheckMap = CheckMap.fix_si(type: isl::dim::out, pos: i, value: i);

  // Try to obtain the permutation and sizes of corresponding input dimensions.
  Dimensions.assign(NumElts: OutDimNum, Elt: -1);
  for (unsigned i : rangeIslSize(Begin: 0, End: CheckMap.dim(type: isl::dim::in))) {
    // Input dimensions that did not become constant after the fixing above do
    // not participate in this operand's subscript.
    isl::val Val = getConstant(Map: CheckMap, Dim: isl::dim::in, Pos: i);
    if (!Val.is_int())
      continue;
    int OutPos = -1;
    llvm::APInt ValAPInt = APIntFromVal(V: Val);
    if (ValAPInt.isSignedIntN(N: 32))
      OutPos = ValAPInt.getSExtValue();
    // Reject out-of-range positions and input dimensions that appear more
    // than once.
    if ((OutPos < 0) || (OutPos >= static_cast<int>(OutDimNum)) ||
        IndexSet.count(V: i))
      return false;
    IndexSet.insert(V: i);
    Dimensions[OutPos] = i;
    // Record the dimension size the first time this input dimension is seen.
    if (DimensionSizes[i] <= 0)
      DimensionSizes[i] = getDimSize(SAI, Pos: OutPos);
  }

  return isCorrectAccessMap(Domain, AccMap, Dimensions);
}
| 1241 | |
| 1242 | /// Find the intersection of two sets. |
| 1243 | /// |
| 1244 | /// Find the intersection of the set @p A and the set @p B. |
| 1245 | /// |
| 1246 | /// @param A, B Sets to intersect. |
| 1247 | /// @return The set intersection. |
| 1248 | static SmallDenseSet<int> intersect(const SmallDenseSet<int> &A, |
| 1249 | const SmallDenseSet<int> &B) { |
| 1250 | SmallDenseSet<int> Intersection = A; |
| 1251 | set_intersect(S1&: Intersection, S2: B); |
| 1252 | return Intersection; |
| 1253 | } |
| 1254 | |
| 1255 | /// Check whether the set is a superset. |
| 1256 | /// |
| 1257 | /// Check that the set @p A is a superset of @p B. |
| 1258 | /// |
| 1259 | /// @param A, B Sets to be checked. |
| 1260 | /// @return True if the set A is a superset of B. |
| 1261 | static bool isSuperset(const SmallDenseSet<int> &A, |
| 1262 | const SmallDenseSet<int> &B) { |
| 1263 | return intersect(A, B).size() == B.size(); |
| 1264 | } |
| 1265 | |
| 1266 | /// Find the union of two sets. |
| 1267 | /// |
| 1268 | /// Find the union of the set @p A and the set @p B. |
| 1269 | /// |
| 1270 | /// @param A, B Sets to unite. |
| 1271 | /// @return The set union. |
| 1272 | static SmallDenseSet<int> unite(const SmallDenseSet<int> &A, |
| 1273 | const SmallDenseSet<int> &B) { |
| 1274 | SmallDenseSet<int> Union = A; |
| 1275 | set_union(S1&: Union, S2: B); |
| 1276 | return Union; |
| 1277 | } |
| 1278 | |
| 1279 | /// Determine the access that writes to the tensor, which contains |
| 1280 | /// the result of the tensor contraction. |
| 1281 | /// |
| 1282 | /// @param Domain The domain of the statement. |
| 1283 | /// @param Stmt The statement, which writes to memory. |
| 1284 | /// @param TCI The information about the tensor contraction. |
| 1285 | /// @param IandJIndexSet The set, which contains free indexes of tensors. |
| 1286 | /// @return The determined MemoryAccess, or nullptr if there is no necessary |
| 1287 | /// access within the SCoP. |
| 1288 | static MemoryAccess *getWriteAccess(isl::set Domain, ScopStmt *Stmt, |
| 1289 | TCInfoTy &TCI, |
| 1290 | SmallDenseSet<int> &IandJIndexSet) { |
| 1291 | TCI.WriteToC = nullptr; |
| 1292 | SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(Stmt&: *Stmt); |
| 1293 | for (MemoryAccess *MemA : reverse(C&: Accesses)) { |
| 1294 | // A TC-like does not contain write scalar memory accesses |
| 1295 | if (!MemA->isLatestArrayKind()) |
| 1296 | return nullptr; |
| 1297 | // The last memory access should be a write memory access. |
| 1298 | if (!MemA->isWrite()) |
| 1299 | return nullptr; |
| 1300 | |
| 1301 | isl::map AccMap = MemA->getLatestAccessRelation(); |
| 1302 | if (!isTCOperandAcc(Domain, AccMap, IndexSet&: IandJIndexSet, DimensionSizes&: TCI.DimensionSizes, |
| 1303 | Dimensions&: TCI.CDimensions)) |
| 1304 | return nullptr; |
| 1305 | |
| 1306 | return MemA; |
| 1307 | } |
| 1308 | return nullptr; |
| 1309 | } |
| 1310 | |
| 1311 | /// Determine an access, which reads elements of an operand of the tensor |
| 1312 | /// contraction |
| 1313 | /// |
| 1314 | /// @param MemAccessPtr The access, which reads elements of the tensor. |
| 1315 | /// @param IndexSet The set, which contains indexes of the tensors. |
| 1316 | /// @param IandJIndexSet The set, which contains free indexes of tensors. |
| 1317 | /// @param Dimensions The permutation of the subset of the input dimensions. |
| 1318 | /// @param TCI The information about the tensor contraction. |
| 1319 | /// @return True if the memory access @p MemAccessPtr corresponds |
| 1320 | /// to the tensor contraction. |
| 1321 | static bool setReadAccess(MemoryAccess *MemAccessPtr, |
| 1322 | const SmallDenseSet<int> &IndexSet, |
| 1323 | const SmallDenseSet<int> &IandJIndexSet, |
| 1324 | ArrayRef<int> Dimensions, TCInfoTy &TCI) { |
| 1325 | if (!TCI.A) { |
| 1326 | // Probably IndexSet is a union of I and P sets. |
| 1327 | if (!isSuperset(A: IndexSet, B: TCI.P)) |
| 1328 | return false; |
| 1329 | |
| 1330 | // Obtain the set I. |
| 1331 | TCI.I = set_difference(S1: IndexSet, S2: TCI.P); |
| 1332 | if (!isSuperset(A: IandJIndexSet, B: TCI.I)) |
| 1333 | return false; |
| 1334 | |
| 1335 | // Obtain the set J. |
| 1336 | TCI.J = set_difference(S1: IandJIndexSet, S2: TCI.I); |
| 1337 | |
| 1338 | // Set the first operand of the tensor contraction. |
| 1339 | TCI.A = MemAccessPtr; |
| 1340 | llvm::replace(Cont&: TCI.ADimensions, ContIt: TCI.ADimensions.begin(), |
| 1341 | ContEnd: TCI.ADimensions.end(), ValIt: Dimensions.begin(), ValEnd: Dimensions.end()); |
| 1342 | return true; |
| 1343 | } |
| 1344 | |
| 1345 | if (!TCI.B) { |
| 1346 | // IndexSet should be a union of J and P sets. |
| 1347 | if (unite(A: TCI.P, B: TCI.J) != IndexSet) |
| 1348 | return false; |
| 1349 | |
| 1350 | // Set the second operand of the tensor contraction. |
| 1351 | TCI.B = MemAccessPtr; |
| 1352 | llvm::replace(Cont&: TCI.BDimensions, ContIt: TCI.BDimensions.begin(), |
| 1353 | ContEnd: TCI.BDimensions.end(), ValIt: Dimensions.begin(), ValEnd: Dimensions.end()); |
| 1354 | return true; |
| 1355 | } |
| 1356 | |
| 1357 | return false; |
| 1358 | } |
| 1359 | |
| 1360 | /// Check that all memory accesses of the statement, except from the last |
| 1361 | /// one, are read memory accesses, which read elements of operands of the tensor |
| 1362 | /// contraction and its result. |
| 1363 | /// |
| 1364 | /// @param Domain The domain of the statement. |
| 1365 | /// @param Stmt The statement, which writes to memory. |
| 1366 | /// @param TCI The information about the tensor contraction. |
| 1367 | /// @param IandJIndexSet The set, which contains free indexes of tensors. |
| 1368 | /// @return True if all read memory accesses of the statement @p Stmt correspond |
| 1369 | /// to the tensor contraction. |
| 1370 | static bool setReadAccesses(isl::set Domain, ScopStmt *Stmt, TCInfoTy &TCI, |
| 1371 | SmallDenseSet<int> &IandJIndexSet) { |
| 1372 | TCI.A = nullptr; |
| 1373 | TCI.B = nullptr; |
| 1374 | TCI.ReadFromC = nullptr; |
| 1375 | SmallVector<MemoryAccess *, 32> Accesses = getAccessesInOrder(Stmt&: *Stmt); |
| 1376 | for (auto *MemA = Accesses.begin(); *MemA != TCI.WriteToC; MemA++) { |
| 1377 | MemoryAccess *MemAccessPtr = *MemA; |
| 1378 | |
| 1379 | // All memory accesses, except from the last one, should be read memory |
| 1380 | // accesses. |
| 1381 | if (MemAccessPtr->isWrite()) |
| 1382 | return false; |
| 1383 | |
| 1384 | isl::map AccMap = MemAccessPtr->getLatestAccessRelation(); |
| 1385 | |
| 1386 | if (!MemAccessPtr->isLatestArrayKind()) { |
| 1387 | // Check whether the scalar read memory access is not partial. |
| 1388 | if (!Domain.is_subset(set2: AccMap.domain())) |
| 1389 | return false; |
| 1390 | continue; |
| 1391 | return false; |
| 1392 | } |
| 1393 | |
| 1394 | // There is only one memory access, which reads elements of the result of |
| 1395 | // the tensor contraction. |
| 1396 | if (AccMap.is_equal(map2: TCI.WriteToC->getLatestAccessRelation())) { |
| 1397 | if (TCI.ReadFromC) |
| 1398 | return false; |
| 1399 | TCI.ReadFromC = MemAccessPtr; |
| 1400 | continue; |
| 1401 | } |
| 1402 | |
| 1403 | SmallVector<int> Dimensions; |
| 1404 | SmallDenseSet<int> IndexSet; |
| 1405 | if (!isTCOperandAcc(Domain, AccMap, IndexSet, DimensionSizes&: TCI.DimensionSizes, |
| 1406 | Dimensions)) |
| 1407 | return false; |
| 1408 | |
| 1409 | if (!setReadAccess(MemAccessPtr, IndexSet, IandJIndexSet, Dimensions, TCI)) |
| 1410 | return false; |
| 1411 | } |
| 1412 | |
| 1413 | // Check that there are read memory accesses, which read elements of operands |
| 1414 | // of the tensor contraction and its result. |
| 1415 | return TCI.ReadFromC && TCI.A && TCI.B; |
| 1416 | } |
| 1417 | |
| 1418 | /// Check accesses to operands of the tensor contraction. |
| 1419 | /// |
| 1420 | /// Check that accesses of the SCoP statement, which corresponds to |
| 1421 | /// the partial schedule @p PartialSchedule, represent accesses |
| 1422 | /// to the non-scalar operands of the tensor contraction. |
| 1423 | /// |
| 1424 | /// @param Domain The domain of the SCoP statement. |
| 1425 | /// @param PartialSchedule The partial schedule of the SCoP statement. |
| 1426 | /// @param TCI Parameters of the tensor contraction operands. |
| 1427 | /// @return True if the corresponding SCoP statement |
| 1428 | /// represents tensor contraction and false, |
| 1429 | /// otherwise. |
| 1430 | static bool containsOnlyTCAcc(isl::set Domain, isl::map PartialSchedule, |
| 1431 | TCInfoTy &TCI) { |
| 1432 | isl::id InputDimsId = PartialSchedule.get_tuple_id(type: isl::dim::in); |
| 1433 | ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user()); |
| 1434 | |
| 1435 | // In region statements, the order of memory accesses execution is not |
| 1436 | // predictable at compile-time. |
| 1437 | if ((Stmt->size() <= 1) || Stmt->isRegionStmt()) |
| 1438 | return false; |
| 1439 | |
| 1440 | unsigned DimNum = unsignedFromIslSize(Size: PartialSchedule.dim(type: isl::dim::in)); |
| 1441 | TCI.DimensionSizes.resize(N: DimNum); |
| 1442 | SmallDenseSet<int> IandJIndexSet; |
| 1443 | |
| 1444 | TCI.WriteToC = getWriteAccess(Domain, Stmt, TCI, IandJIndexSet); |
| 1445 | if (!TCI.WriteToC) |
| 1446 | return false; |
| 1447 | |
| 1448 | if (intersect(A: IandJIndexSet, B: TCI.P).size() != 0) |
| 1449 | return false; |
| 1450 | |
| 1451 | if (!setReadAccesses(Domain, Stmt, TCI, IandJIndexSet)) |
| 1452 | return false; |
| 1453 | |
| 1454 | return true; |
| 1455 | } |
| 1456 | |
/// Check that dependency corresponds to the tensor contraction carried over
/// loop dimension @p Dim.
///
/// Check that the dependency has the form
/// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
/// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
/// statement. For this purpose, we analyze the set @p DepDelta, which
/// represents the differences between image elements and domain elements of
/// the corresponding map.
///
/// @param DepDelta    The set contains the differences between image elements
///                    and corresponding domain elements of the map, which
///                    represents the dependency.
/// @param Dim         The position of the index ki.
/// @param BoundDeltas In the case of indexes of ki, the difference between
///                    image elements and corresponding domain elements
///                    corresponds to the difference between lexicographic
///                    minimum and lexicographic maximum of the corresponding
///                    dimension of the domain of the statement.
/// @param IndexSet    Obtained indexes ki, which describe the dependency.
/// @return True if dependencies correspond to the tensor contraction
///         and false, otherwise.
static bool isReductionCarriedOverDim(isl::set DepDelta, unsigned Dim,
                                      isl::pw_multi_aff BoundDeltas,
                                      const SmallDenseSet<int> &IndexSet) {
  // Build the superset { [0, ..., 0, 1, *, ..., *] }: zero deltas in all
  // positions before @p Dim and a delta of exactly one at @p Dim itself.
  isl::space Space = DepDelta.get_space();
  isl::set Superset = isl::set::universe(space: Space);
  for (unsigned i = 0; i < Dim; i += 1)
    Superset = Superset.fix_si(type: isl::dim::set, pos: i, value: 0);
  Superset = Superset.fix_si(type: isl::dim::set, pos: Dim, value: 1);

  // Check that the difference between the image element and the domain element
  // is equal to one in the case of the index ki. Image elements and
  // corresponding domain elements should be equal in the case of positions,
  // which are lower than the specified position.
  if (!DepDelta.is_subset(set2: Superset))
    return false;

  // Compute a set, which is used to analyze how values of
  // the domain are related to the map that describes the dependency.
  // Complement = BoundDeltas + DepDelta; for a wrap-around dependency the
  // delta of a carried dimension should cancel against the (max - min)
  // bound difference, leaving zero.
  isl_pw_multi_aff *DepDeltaPW = isl_pw_multi_aff_from_set(set: DepDelta.copy());
  BoundDeltas = BoundDeltas.add(pma2: isl::manage(ptr: DepDeltaPW));
  isl_set *ComplementRawSet = isl_set_from_pw_multi_aff(pma: BoundDeltas.release());
  isl::set Complement = isl::manage(ptr: ComplementRawSet);

  // Inspect every dimension after @p Dim.
  for (unsigned i : rangeIslSize(Begin: Dim + 1, End: DepDelta.dim(type: isl::dim::set))) {
    if (!IndexSet.count(V: i)) {
      // Check the difference between the image element and the domain element
      // in the case of indexes, which do not describe the dependency.
      if (DepDelta.plain_get_val_if_fixed(type: isl::dim::set, pos: i).is_zero())
        continue;
      return false;
    }

    // In the case of other indexes, which describe the dependency,
    // the difference between the image element and the domain element
    // should be equal to the difference between lexicographic minimum and
    // lexicographic maximum of the domain of the statement.
    if (!Complement.plain_get_val_if_fixed(type: isl::dim::set, pos: i).is_zero())
      return false;
  }

  return true;
}
| 1521 | |
/// Check whether dependencies are over the complete domain.
///
/// In the case of the tensor contraction RAW, WAW, WAR dependencies
/// have the form
/// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
/// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
/// statement. Consequently, the domain of the dependencies
/// can be described as
/// Domain / Domain ∩ S(…, max(kn),…) ∩ S(…, max(k(i + 1)),…),
/// where Domain is the domain of the statement S.
///
/// For example, in the case of the following tensor contraction,
/// corresponding domains will have the following form.
///
/// An example of the tensor contraction:
/// for (i = 0; i < 1024; i++)
///   for (j = 0; j < 1024; j++)
///     for (l = 0; l < 64; ++l)
///       for (w = 0; w < 64; ++w)
///         C[i][j] += A[i][l][w] * B[w][j][l];
///
/// The domain of the statement:
/// { S[i0, i1, i2, i3] : i0 >= 0 and i0 <= 1023 and
///                       i1 >= 0 and i1 <= 1023 and
///                       i2 >= 0 and i2 <= 63 and
///                       i3 >= 0 and i3 <= 63 }
///
/// The domain of the dependencies:
/// { S[i0, i1, i2, i3] : (i0 >= 0 and i0 <= 1023 and
///                        i1 >= 0 and i1 <= 1023 and
///                        i2 >= 0 and i2 <= 63 and
///                        i3 >= 0 and i3 <= 62) or
///                       (i3 = 63 and i0 >= 0 and i0 <= 1023 and
///                        i1 >= 0 and i1 <= 1023 and
///                        i2 >= 0 and i2 <= 62) }
///
/// @param Domain      The domain of the statement.
/// @param DepsForStmt RAW and RED dependencies for the statement.
/// @param UpperBound  The lexicographic maximum of the elements in
///                    the @p Domain.
/// @param IndexSet    Obtained indexes ki, which describe the dependencies.
/// @return True if dependencies are over the complete domain
///         and false, otherwise.
static bool areDepsOverCompleteDomain(isl::set Domain, isl::map DepsForStmt,
                                      isl::pw_multi_aff UpperBound,
                                      SmallDenseSet<int> &IndexSet) {
  // Materialize the lexicographic maximum as a set so that per-dimension
  // values can be queried below.
  isl_set *UpperBoundRawSet = isl_set_from_pw_multi_aff(pma: UpperBound.copy());
  isl::set UpperBoundSet = isl::manage(ptr: UpperBoundRawSet);

  // Fix every contraction dimension ki of the domain to its upper bound;
  // the resulting set contains exactly the points without an outgoing
  // dependency over ki.
  isl::set DomainRed = isl::manage(ptr: Domain.copy());
  for (const auto It : IndexSet) {
    // The upper bound must be a fixed (constant) value per dimension;
    // plain_get_val_if_fixed yields NaN otherwise.
    isl::val FixedVal = UpperBoundSet.plain_get_val_if_fixed(type: isl::dim::set, pos: It);
    if (FixedVal.is_nan())
      return false;
    DomainRed = isl::manage(
        ptr: isl_set_fix_val(set: DomainRed.copy(), type: isl_dim_set, pos: It, v: FixedVal.release()));
  }
  // The dependencies cover the complete domain iff their domain equals the
  // statement domain minus the points fixed above.
  return DepsForStmt.domain().intersect(set2: Domain).is_equal(
      set2: Domain.subtract(set2: DomainRed));
}
| 1582 | |
/// Check that dependencies correspond to the tensor contraction.
///
/// Check that there are only true dependencies of the form
/// S(..., ki, max(k(i + 1)), ..., max(kn), ...) ->
/// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP
/// statement represented by @p Schedule. Such dependencies are produced by
/// the tensor contraction. Obtained indexes ki are stored into @p IndexSet.
///
/// The form of anti and output dependencies is specified implicitly by
/// the form the SCoP statement, which is checked by subsequent analysis.
///
/// @param Schedule The schedule of the SCoP statement.
/// @param D        The SCoP dependencies.
/// @param IndexSet Obtained indexes ki, which describe the dependencies.
/// @param Domain   The domain of the statement.
/// @return True if dependencies correspond to the tensor contraction
///         and false, otherwise.
static bool containsOnlyTcDeps(isl::map Schedule, const Dependences *D,
                               SmallDenseSet<int> &IndexSet, isl::set Domain) {
  // Bound the amount of isl computation performed below (budget taken from
  // the OptComputeOut option).
  IslMaxOperationsGuard MaxOpGuard(Schedule.ctx().get(), OptComputeOut);

  // Only true (RAW) and reduction dependencies matter here.
  isl::union_map Dep =
      D->getDependences(Kinds: Dependences::TYPE_RAW | Dependences::TYPE_RED);

  // Restrict the dependencies to self-dependencies of this statement and
  // compute the per-dimension differences between image and domain elements.
  isl::space DomainSpace = Schedule.get_space().domain();
  isl::space Space = DomainSpace.map_from_domain_and_range(range: DomainSpace);
  isl::map DepsForStmt = Dep.extract_map(space: Space);
  isl::set DepDeltas = DepsForStmt.deltas();
  isl::size DeltasDimNum = DepDeltas.dim(type: isl::dim::set);
  // BoundDeltas = lexmax(Domain) - lexmin(Domain), used to recognize the
  // wrap-around part of carried dependencies.
  isl::pw_multi_aff LowerBound = Domain.lexmin_pw_multi_aff();
  isl::pw_multi_aff UpperBound = Domain.lexmax_pw_multi_aff();
  isl::pw_multi_aff BoundDeltas = UpperBound.sub(pma2: LowerBound);

  // Peel off, innermost first, the deltas carried over each dimension.
  for (int i : reverse(C: rangeIslSize(Begin: 0, End: DeltasDimNum))) {
    // In the case of the tensor contraction, the difference between image
    // elements and domain elements lies on a hyperplane where a dimension
    // has the fixed value one.
    isl::set Intersection = DepDeltas.fix_si(type: isl::dim::set, pos: i, value: 1);
    if (Intersection.is_empty())
      continue;

    if (!isReductionCarriedOverDim(DepDelta: Intersection, Dim: i, BoundDeltas, IndexSet))
      return false;

    IndexSet.insert(V: i);
    DepDeltas = DepDeltas.subtract(set2: Intersection);
  }

  // In the case of the tensor contraction, all dependencies should have
  // the previously described form (i.e., nothing may remain after peeling,
  // and there has to be at least one loop dimension).
  if ((unsignedFromIslSize(Size: DeltasDimNum) == 0) || !DepDeltas.is_empty())
    return false;

  return areDepsOverCompleteDomain(Domain, DepsForStmt, UpperBound, IndexSet);
}
| 1638 | |
| 1639 | /// Check if the SCoP statement could probably be optimized with analytical |
| 1640 | /// modeling. |
| 1641 | /// |
| 1642 | /// containsTCInfoTy tries to determine whether the following conditions |
| 1643 | /// are true: |
| 1644 | /// |
| 1645 | /// 1. The last memory access modeling an array, MA1, represents writing to |
| 1646 | /// memory and has the form S(..., I, ..., J, ...) -> M(shuffle(I, J)), |
| 1647 | /// where S is the SCoP statement under consideration and shuffle(I, J) |
| 1648 | /// is a permutation of indexes of sets I and J. |
| 1649 | /// 2. There are only true dependencies of the form |
| 1650 | /// S(..., ki, max(k(i + 1)), ..., max(kn), ...) -> |
| 1651 | /// S(..., ki + 1, min(k(i + 1)), ..., min(kn), ...), where S is the SCoP |
| 1652 | /// statement represented by @p Schedule and ki are indexes of the set P. |
| 1653 | /// 3. SCoP contains an arbitrary number of reads from constants and only three |
| 1654 | /// access relations, MA2, MA3, and MA4 that represent reading from memory |
| 1655 | /// and have the form |
| 1656 | /// S(..., I, ..., P, ...) -> M(shuffle(I, P)), |
| 1657 | /// S(..., P, ..., J, ...) -> M(shuffle(J, P)), |
| 1658 | /// S(...) -> M(shuffle(I, J)), respectively. |
| 1659 | /// |
| 1660 | /// @param PartialSchedule The PartialSchedule that contains a SCoP statement |
| 1661 | /// to check. |
| 1662 | /// @param D The SCoP dependencies. |
| 1663 | /// @param TCI Parameters of the tensor contraction operands. |
| 1664 | /// @param Domain The domain of the statement. |
| 1665 | /// @return True if dependencies and memory accesses correspond to the tensor |
| 1666 | /// contraction and false, otherwise. |
| 1667 | static bool containsTCInfoTy(isl::map PartialSchedule, const Dependences *D, |
| 1668 | TCInfoTy &TCI, isl::set Domain) { |
| 1669 | if (!containsOnlyTcDeps(Schedule: PartialSchedule, D, IndexSet&: TCI.P, Domain)) |
| 1670 | return false; |
| 1671 | |
| 1672 | // TODO: handle cases of scalar multiplication if needed. |
| 1673 | if (TCI.P.size() == 0) |
| 1674 | return false; |
| 1675 | |
| 1676 | if (!containsOnlyTCAcc(Domain, PartialSchedule, TCI)) |
| 1677 | return false; |
| 1678 | |
| 1679 | // TODO: handle cases of GEMV if needed. |
| 1680 | if ((TCI.I.size() == 0) || (TCI.J.size() == 0)) |
| 1681 | return false; |
| 1682 | |
| 1683 | return true; |
| 1684 | } |
| 1685 | |
| 1686 | /// Check if this node contains a partial schedule that could |
| 1687 | /// probably be optimized with analytical modeling. |
| 1688 | /// |
| 1689 | /// isTCPattern is used to determine whether the SCoP represents a TC-like |
| 1690 | /// kernel [1], which is a perfectly nested set of loops, with a data usage |
| 1691 | /// pattern that is similar to that produced by the tensor contraction. |
| 1692 | /// |
| 1693 | /// A TC-like kernel can be defined as follows: |
| 1694 | /// |
| 1695 | /// 1. It satisfies the requirements of the polyhedral model. |
| 1696 | /// 2. Without loss of generality, it contains three nonempty bundles of |
| 1697 | /// one-dimensional for-loops with induction variables that are grouped into |
| 1698 | /// bundles I = i0...i(r-1), J = j0..j(s-1), and P = p0...p(t-1), and they |
| 1699 | /// are incremented by one. |
| 1700 | /// 3. The innermost loop body can be represented as a statement of the form |
| 1701 | /// C(shuffle(I, J)) = E(A(shuffle(I, P)), B(shuffle(P, J)), |
| 1702 | /// C(shuffle(I, J))), where A(shuffle(I, P)), B(shuffle(P, J)), |
| 1703 | /// C(shuffle(I, J)) are accesses to tensors A, B, C, respectively, |
| 1704 | /// shuffle(I, J), shuffle(I, P), and shuffle(P, J) are permutations of the |
| 1705 | /// enclosed indices, and E is an expression that contains reads from |
| 1706 | /// the tensors A, B, C, and an arbitrary number of reads from constants |
| 1707 | /// with respect to bundles I, J, and P. |
| 1708 | /// |
| 1709 | /// TC can be considered as a particular case of a TC-like kernel. |
| 1710 | /// |
| 1711 | /// The order of loops with indexes from P should be preserved. Otherwise, |
| 1712 | /// isTCPattern should check if a commutative operation is used. |
| 1713 | /// |
| 1714 | /// isTCPattern performs the following steps to check whether the SCoP |
| 1715 | /// corresponds to a definition of a TC-like kernel: |
| 1716 | /// |
| 1717 | /// 1. Checks that the node is the innermost band node. |
| 1718 | /// 2. Checks that the partial schedule contains only one statement. |
| 1719 | /// 3. Check that all ancestors of the node contain all band nodes for |
| 1720 | /// the statement and only mark nodes interleave such band nodes. This |
| 1721 | /// corresponds to a straightforward implementation of TC. |
| 1722 | /// 4. Analyses the dependencies to determine contraction dimensions. |
| 1723 | /// 5. Check that the last memory access modeling an array, represents writing |
| 1724 | /// to the result of the TC-like kernel. |
| 1725 | /// 6. Check that SCoP contains only three access relations that represent |
| 1726 | /// reading of the operands of the TC-like kernel and an arbitrary number of |
| 1727 | /// reads from constants. |
| 1728 | /// |
| 1729 | /// [1] - Gareev R., Grosser T., Kruse M. High-Performance Generalized Tensor |
| 1730 | /// Operations: A Compiler-Oriented Approach // ACM Transactions |
| 1731 | /// Architecture and Code Optimization (TACO). 2018. |
| 1732 | /// Vol. 15, no. 3. P. 34:1–34:27. DOI: 10.1145/3235029. |
| 1733 | /// |
| 1734 | /// If this is the case, we could logically represent tensors as matrices and |
| 1735 | /// apply algorithms, which are used to get close-to-peak performance of |
| 1736 | /// matrix multiplications in manually tuned BLAS libraries (e.g., BLIS). |
| 1737 | /// |
| 1738 | /// @param Node The node to check. |
| 1739 | /// @param D The SCoP dependencies. |
| 1740 | /// @param TCI Parameters of the tensor contraction operands. |
| 1741 | static bool isTCPattern(isl::schedule_node Node, const Dependences *D, |
| 1742 | TCInfoTy &TCI) { |
| 1743 | Node = Node.child(pos: 0); |
| 1744 | isl::union_map PartialSchedule = Node.get_prefix_schedule_union_map(); |
| 1745 | isl::union_set Domain = Node.domain(); |
| 1746 | Node = Node.parent(); |
| 1747 | |
| 1748 | // The partial schedule should contain only one statement. |
| 1749 | // TODO: This constraint should not be intrinsic to the algorithm. |
| 1750 | if (isl_union_set_n_set(uset: Domain.get()) != 1) |
| 1751 | return false; |
| 1752 | |
| 1753 | isl_schedule_node_type NodeType = isl_schedule_node_get_type(node: Node.get()); |
| 1754 | |
| 1755 | // Check that all ancestors of the node contain all band nodes for |
| 1756 | // the statement, which represents the TC-like kernel, and only mark nodes |
| 1757 | // interleave such band nodes. This corresponds to a straightforward |
| 1758 | // implementation of TC with/without DeLICM applied. |
| 1759 | // |
| 1760 | // For example, this covers the matrix multiplication pattern after a full |
| 1761 | // run of -polly-optree and -polly-delicm, where the write access is not |
| 1762 | // through the original memory access, but through a PHI node that was |
| 1763 | // delicmed. Subsequently, such band nodes will be replaced by a single band |
| 1764 | // node. |
| 1765 | // |
| 1766 | // The corresponding schedule can be the following, where Stmt_for_body8 |
| 1767 | // contains the matrix multiplication: |
| 1768 | // |
| 1769 | // domain: "{ Stmt_for_body8[i0, i1, i2] : 0 <= i0 <= 1599 and |
| 1770 | // 0 <= i1 <= 1799 and |
| 1771 | // 0 <= i2 <= 2199; |
| 1772 | // Stmt_for_body3[i0, i1] : 0 <= i0 <= 1599 and |
| 1773 | // 0 <= i1 <= 1799; |
| 1774 | // Stmt_for_body3_last[i0, i1] : 0 <= i0 <= 1599 and |
| 1775 | // 0 <= i1 <= 1799 }" |
| 1776 | // child: |
| 1777 | // sequence: |
| 1778 | // - filter: "{ Stmt_for_body3[i0, i1] }" |
| 1779 | // child: |
| 1780 | // schedule: "[{ Stmt_for_body3[i0, i1] -> [(i0)] }, |
| 1781 | // { Stmt_for_body3[i0, i1] -> [(i1)] }]" |
| 1782 | // permutable: 1 |
| 1783 | // coincident: [ 1, 1 ] |
| 1784 | // - filter: "{ Stmt_for_body3_last[i0, i1] }" |
| 1785 | // child: |
| 1786 | // schedule: "[{ Stmt_for_body3_last[i0, i1] -> [(i0)] }, |
| 1787 | // { Stmt_for_body3_last[i0, i1] -> [(i1)] }]" |
| 1788 | // permutable: 1 |
| 1789 | // coincident: [ 1, 1 ] |
| 1790 | // - filter: "{ Stmt_for_body8[i0, i1, i2] }" |
| 1791 | // child: |
| 1792 | // schedule: "[{ Stmt_for_body8[i0, i1, i2] -> [(i0)] }, |
| 1793 | // { Stmt_for_body8[i0, i1, i2] -> [(i1)] }, |
| 1794 | // { Stmt_for_body8[i0, i1, i2] -> [(i2)] }]" |
| 1795 | // permutable: 1 |
| 1796 | // coincident: [ 1, 1, 0 ] |
| 1797 | // |
| 1798 | while (NodeType != isl_schedule_node_domain) { |
| 1799 | if (NodeType == isl_schedule_node_filter) { |
| 1800 | if (!Node.parent().isa<isl::schedule_node_sequence>() || |
| 1801 | !Node.parent().parent().isa<isl::schedule_node_domain>()) |
| 1802 | return false; |
| 1803 | break; |
| 1804 | } |
| 1805 | |
| 1806 | if ((NodeType != isl_schedule_node_band) && |
| 1807 | (NodeType != isl_schedule_node_mark)) |
| 1808 | return false; |
| 1809 | |
| 1810 | Node = Node.parent(); |
| 1811 | NodeType = isl_schedule_node_get_type(node: Node.get()); |
| 1812 | } |
| 1813 | |
| 1814 | isl::map PartialScheduleMap = isl::map::from_union_map(umap: PartialSchedule); |
| 1815 | if (containsTCInfoTy(PartialSchedule: PartialScheduleMap, D, TCI, Domain: isl::set(Domain))) |
| 1816 | return true; |
| 1817 | |
| 1818 | return false; |
| 1819 | } |
| 1820 | |
| 1821 | } // namespace |
| 1822 | |
| 1823 | isl::schedule_node |
| 1824 | polly::tryOptimizeMatMulPattern(isl::schedule_node Node, |
| 1825 | const llvm::TargetTransformInfo *TTI, |
| 1826 | const Dependences *D) { |
| 1827 | TCInfoTy TCI; |
| 1828 | if (PMBasedTCOpts && isTCPattern(Node, D, TCI)) |
| 1829 | POLLY_DEBUG(dbgs() << "The tensor contraction pattern was detected\n" ); |
| 1830 | MatMulInfoTy MMI; |
| 1831 | if (PMBasedMMMOpts && isMatrMultPattern(Node, D, MMI)) { |
| 1832 | POLLY_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n" ); |
| 1833 | return optimizeMatMulPattern(Node, TTI, MMI); |
| 1834 | } |
| 1835 | return {}; |
| 1836 | } |
| 1837 | |