| 1 | //===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements pattern rewrites to lower sequences of |
| 10 | // `vector.to_elements` and `vector.from_elements` operations into a tree of |
| 11 | // `vector.shuffle` operations. |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 16 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 17 | #include "mlir/Dialect/Vector/Transforms/Passes.h" |
| 18 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 20 | #include "llvm/ADT/DenseMap.h" |
| 21 | #include "llvm/Support/Debug.h" |
| 22 | #include "llvm/Support/MathExtras.h" |
| 23 | #include "llvm/Support/raw_ostream.h" |
| 24 | |
| 25 | namespace mlir { |
| 26 | namespace vector { |
| 27 | |
| 28 | #define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE |
| 29 | #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
| 30 | |
| 31 | } // namespace vector |
| 32 | } // namespace mlir |
| 33 | |
| 34 | #define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree" |
| 35 | |
| 36 | using namespace mlir; |
| 37 | using namespace mlir::vector; |
| 38 | |
| 39 | namespace { |
| 40 | |
// Number of spaces per indentation level in the debug output.
[[maybe_unused]] constexpr unsigned kIndScale = 2;

/// Closed interval of element positions stored as (first, last); e.g.,
/// [0, 7] covers 8 elements.
using Interval = std::pair<unsigned, unsigned>;
// Sentinel bound marking an interval that has not been initialized yet.
constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
| 48 | |
| 49 | /// The VectorShuffleTreeBuilder builds a balanced binary tree of |
| 50 | /// `vector.shuffle` operations from one or more `vector.to_elements` |
| 51 | /// operations feeding a single `vector.from_elements` operation. |
| 52 | /// |
| 53 | /// The implementation generates hardware-agnostic `vector.shuffle` operations |
| 54 | /// that minimize both the number of shuffle operations and the length of |
| 55 | /// intermediate vectors (to the extent possible). The tree has the |
| 56 | /// following properties: |
| 57 | /// |
| 58 | /// 1. Vectors are shuffled in pairs by order of appearance in |
| 59 | /// the `vector.from_elements` operand list. |
| 60 | /// 2. Each vector at each level is used only once. |
| 61 | /// 3. The number of levels in the tree is: |
| 62 | /// 1 (input vectors) + ceil(max(1,log2(# `vector.to_elements` ops))). |
| 63 | /// 4. Vectors at each level of the tree have the same vector length. |
| 64 | /// 5. Vector positions that do not need to be shuffled are represented with |
| 65 | /// poison in the shuffle mask. |
| 66 | /// |
| 67 | /// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>: |
| 68 | /// |
| 69 | /// %0:4 = vector.to_elements %a : vector<4xf32> |
| 70 | /// %1:4 = vector.to_elements %b : vector<4xf32> |
| 71 | /// %2:4 = vector.to_elements %c : vector<4xf32> |
| 72 | /// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, |
| 73 | /// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 |
| 74 | /// : vector<12xf32> |
| 75 | /// => |
| 76 | /// |
| 77 | /// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] |
| 78 | /// : vector<4xf32>, vector<4xf32> |
| 79 | /// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] |
| 80 | /// : vector<4xf32>, vector<4xf32> |
| 81 | /// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5, |
| 82 | /// 6, 7, 8, 9, 10, 11] |
| 83 | /// : vector<8xf32>, vector<8xf32> |
| 84 | /// |
| 85 | /// Comments: |
| 86 | /// * The shuffle tree has three levels: |
| 87 | /// - Level 0 = (%a, %b, %c, %c) |
| 88 | /// - Level 1 = (%shuffle0, %shuffle1) |
| 89 | /// - Level 2 = (%result) |
| 90 | /// * `%a` and `%b` are shuffled first because they appear first in the |
| 91 | /// `vector.from_elements` operand list (`%0#0` and `%1#0`). |
| 92 | /// * `%c` is shuffled with itself because the number of |
| 93 | /// `vector.from_elements` operands is odd. |
| 94 | /// * The vector length for level 1 and level 2 are 8 and 16, respectively. |
| 95 | /// * `%shuffle1` uses poison values to match the vector length of its |
| 96 | /// tree level (8). |
| 97 | /// |
| 98 | /// |
| 99 | /// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| 100 | /// |
| 101 | /// %0:5 = vector.to_elements %a : vector<5xf32> |
| 102 | /// %1:5 = vector.to_elements %b : vector<5xf32> |
| 103 | /// %2:5 = vector.to_elements %c : vector<5xf32> |
| 104 | /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| 105 | /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| 106 | /// => |
| 107 | /// |
| 108 | /// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] |
| 109 | /// : vector<5xf32>, vector<5xf32> |
| 110 | /// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] |
| 111 | /// : vector<5xf32>, vector<5xf32> |
| 112 | /// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14] |
| 113 | /// : vector<8xf32>, vector<8xf32> |
| 114 | /// |
| 115 | /// Comments: |
| 116 | /// * `%c` and `%b` are shuffled first because they appear first in the |
| 117 | /// `vector.from_elements` operand list (`%2#2` and `%1#1`). |
| 118 | /// * `%a` is shuffled with itself because the number of |
| 119 | /// `vector.from_elements` operands is odd. |
| 120 | /// * The vector length for level 1 and level 2 are 8 and 9, respectively. |
| 121 | /// * `%shuffle0` uses poison values to mark unused vector positions and |
| 122 | /// match the vector length of its tree level (8). |
| 123 | /// |
| 124 | /// TODO: Implement mask compression to reduce the number of intermediate poison |
| 125 | /// values. |
| 126 | class VectorShuffleTreeBuilder { |
| 127 | public: |
| 128 | VectorShuffleTreeBuilder() = delete; |
| 129 | VectorShuffleTreeBuilder(FromElementsOp fromElemOp, |
| 130 | ArrayRef<ToElementsOp> toElemDefs); |
| 131 | |
| 132 | /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence |
| 133 | /// and compute the shuffle tree configuration. This method does not generate |
| 134 | /// any IR. |
| 135 | LogicalResult computeShuffleTree(); |
| 136 | |
| 137 | /// Materialize the shuffle tree configuration computed by |
| 138 | /// `computeShuffleTree` in the IR. |
| 139 | Value generateShuffleTree(PatternRewriter &rewriter); |
| 140 | |
| 141 | private: |
| 142 | // IR input information. |
| 143 | FromElementsOp fromElemsOp; |
| 144 | SmallVector<ToElementsOp> toElemsDefs; |
| 145 | |
| 146 | // Shuffle tree configuration. |
| 147 | unsigned numLevels; |
| 148 | SmallVector<unsigned> vectorSizePerLevel; |
| 149 | /// Holds the range of positions each vector in the tree contributes to in the |
| 150 | /// final output vector. |
| 151 | SmallVector<SmallVector<Interval>> intervalsPerLevel; |
| 152 | |
| 153 | // Utility methods to compute the shuffle tree configuration. |
| 154 | void computeShuffleTreeIntervals(); |
| 155 | void computeShuffleTreeVectorSizes(); |
| 156 | |
| 157 | /// Dump the shuffle tree configuration. |
| 158 | void dump(); |
| 159 | }; |
| 160 | |
| 161 | VectorShuffleTreeBuilder::VectorShuffleTreeBuilder( |
| 162 | FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs) |
| 163 | : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) { |
| 164 | assert(fromElemsOp && "from_elements op is required" ); |
| 165 | assert(!toElemsDefs.empty() && "At least one to_elements op is required" ); |
| 166 | } |
| 167 | |
| 168 | /// Duplicate the last operation, value or interval if the total number of them |
| 169 | /// is odd. This is useful to simplify the shuffle tree algorithm given that |
| 170 | /// vectors are shuffled in pairs and duplication would lead to the last shuffle |
| 171 | /// to have a single (duplicated) input vector. |
| 172 | template <typename T> |
| 173 | static void duplicateLastIfOdd(SmallVectorImpl<T> &values) { |
| 174 | if (values.size() % 2 != 0) |
| 175 | values.push_back(values.back()); |
| 176 | } |
| 177 | |
| 178 | // ===---------------------------------------------------------------------===// |
| 179 | // Shuffle Tree Analysis Utilities. |
| 180 | // ===---------------------------------------------------------------------===// |
| 181 | |
| 182 | /// Compute the intervals for all the vectors in the shuffle tree. The interval |
| 183 | /// of a vector is the range of positions that the vector contributes to in the |
| 184 | /// final output vector. |
| 185 | /// |
| 186 | /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| 187 | /// |
| 188 | /// %0:5 = vector.to_elements %a : vector<5xf32> |
| 189 | /// %1:5 = vector.to_elements %b : vector<5xf32> |
| 190 | /// %2:5 = vector.to_elements %c : vector<5xf32> |
| 191 | /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| 192 | /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| 193 | /// |
| 194 | /// The shuffle tree has 3 levels. Level 0 has 4 vectors (%2, %1, %0, %0, the |
| 195 | /// last one is duplicated to make the number of inputs even) so we compute the |
| 196 | /// interval for each vector: |
| 197 | /// |
| 198 | /// * intervalsPerLevel[0][0] = interval(%2) = [0,6] |
| 199 | /// * intervalsPerLevel[0][1] = interval(%1) = [1,7] |
| 200 | /// * intervalsPerLevel[0][2] = interval(%0) = [2,8] |
| 201 | /// * intervalsPerLevel[0][3] = interval(%0) = [2,8] |
| 202 | /// |
| 203 | /// Level 1 has 2 vectors, resulting from the shuffling of %2 + %1 and %0 + %0 |
| 204 | /// so we compute the intervals for each vector at level 1 as: |
| 205 | /// * intervalsPerLevel[1][0] = intervalsPerLevel[0][0] U |
| 206 | /// intervalsPerLevel[0][1] = [0,7] |
| 207 | /// * intervalsPerLevel[1][1] = intervalsPerLevel[0][2] U |
| 208 | /// intervalsPerLevel[0][3] = [2,8] |
| 209 | /// |
| 210 | /// Level 2 is the last level and only contains the output vector so the |
| 211 | /// interval should be the whole output vector: |
| 212 | /// * intervalsPerLevel[2][0] = intervalsPerLevel[1][0] U |
| 213 | /// intervalsPerLevel[1][1] = [0,8] |
| 214 | /// |
| 215 | void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() { |
| 216 | // Map `vector.to_elements` ops to their ordinal position in the |
| 217 | // `vector.from_elements` operand list. Make sure duplicated |
| 218 | // `vector.to_elements` ops are mapped to the its first occurrence. |
| 219 | DenseMap<ToElementsOp, unsigned> toElemsToInputOrdinal; |
| 220 | for (const auto &[idx, toElemsOp] : llvm::enumerate(First&: toElemsDefs)) |
| 221 | toElemsToInputOrdinal.insert(KV: {toElemsOp, idx}); |
| 222 | |
| 223 | // Compute intervals for each vector in the shuffle tree. The first |
| 224 | // level computation is special-cased to keep the implementation simpler. |
| 225 | |
| 226 | SmallVector<Interval> firstLevelIntervals(toElemsDefs.size(), |
| 227 | {kMaxUnsigned, kMaxUnsigned}); |
| 228 | |
| 229 | for (const auto &[idx, element] : |
| 230 | llvm::enumerate(First: fromElemsOp.getElements())) { |
| 231 | auto toElemsOp = cast<ToElementsOp>(Val: element.getDefiningOp()); |
| 232 | unsigned inputIdx = toElemsToInputOrdinal[toElemsOp]; |
| 233 | Interval ¤tInterval = firstLevelIntervals[inputIdx]; |
| 234 | |
| 235 | // Set lower bound to the first occurrence of the `vector.to_elements`. |
| 236 | if (currentInterval.first == kMaxUnsigned) |
| 237 | currentInterval.first = idx; |
| 238 | |
| 239 | // Set upper bound to the last occurrence of the `vector.to_elements`. |
| 240 | currentInterval.second = idx; |
| 241 | } |
| 242 | |
| 243 | duplicateLastIfOdd(values&: toElemsDefs); |
| 244 | duplicateLastIfOdd(values&: firstLevelIntervals); |
| 245 | intervalsPerLevel.push_back(Elt: std::move(firstLevelIntervals)); |
| 246 | |
| 247 | // Compute intervals for the remaining levels. |
| 248 | for (unsigned level = 1; level < numLevels; ++level) { |
| 249 | bool isLastLevel = level == numLevels - 1; |
| 250 | const auto &prevLevelIntervals = intervalsPerLevel[level - 1]; |
| 251 | SmallVector<Interval> currentLevelIntervals( |
| 252 | llvm::divideCeil(Numerator: prevLevelIntervals.size(), Denominator: 2), |
| 253 | {kMaxUnsigned, kMaxUnsigned}); |
| 254 | |
| 255 | size_t currentNumLevels = currentLevelIntervals.size(); |
| 256 | for (size_t inputIdx = 0; inputIdx < currentNumLevels; ++inputIdx) { |
| 257 | auto &interval = currentLevelIntervals[inputIdx]; |
| 258 | const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; |
| 259 | const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; |
| 260 | |
| 261 | // The interval of a vector at the current level is the union of the |
| 262 | // intervals of the two vectors from the previous level being shuffled at |
| 263 | // this level. |
| 264 | interval.first = prevLhsInterval.first; |
| 265 | interval.second = |
| 266 | std::max(a: prevLhsInterval.second, b: prevRhsInterval.second); |
| 267 | } |
| 268 | |
| 269 | // Duplicate the last interval if the number of intervals is odd, except for |
| 270 | // the last level as it only contains the output vector, which doesn't have |
| 271 | // to be shuffled. |
| 272 | if (!isLastLevel) |
| 273 | duplicateLastIfOdd(values&: currentLevelIntervals); |
| 274 | |
| 275 | intervalsPerLevel.push_back(Elt: std::move(currentLevelIntervals)); |
| 276 | } |
| 277 | } |
| 278 | |
| 279 | /// Compute the uniform vector size for each level of the shuffle tree, given |
| 280 | /// the intervals of the vectors at each level. The vector size of a level is |
| 281 | /// the size of the widest interval at that level. |
| 282 | /// |
| 283 | /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| 284 | /// |
| 285 | /// Intervals: |
| 286 | /// * Level 0: [0,6], [1,7], [2,8], [2,8] |
| 287 | /// * Level 1: [0,7], [2,8] |
| 288 | /// * Level 2: [0,8] |
| 289 | /// |
| 290 | /// Vector sizes: |
| 291 | /// * Level 0: Arbitrary sizes from input vectors. |
| 292 | /// * Level 1: max(size_of([0,7]) = 8, size_of([2,8]) = 7) = 8 |
| 293 | /// * Level 2: max(size_of([0,8]) = 9) = 9 |
| 294 | /// |
| 295 | void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() { |
| 296 | // Compute vector size for each level. There are two direct cases: |
| 297 | // * First level: the vector size depends on the actual size of the input |
| 298 | // vectors and it's allowed to be non-uniform. We set it to 0. |
| 299 | // * Last level: the vector size is the output vector size so it doesn't |
| 300 | // have to be computed using intervals. |
| 301 | vectorSizePerLevel.front() = 0; |
| 302 | vectorSizePerLevel.back() = |
| 303 | cast<VectorType>(Val: fromElemsOp.getResult().getType()).getNumElements(); |
| 304 | |
| 305 | for (unsigned level = 1; level < numLevels - 1; ++level) { |
| 306 | const auto ¤tLevelIntervals = intervalsPerLevel[level]; |
| 307 | unsigned currentVectorSize = 1; |
| 308 | size_t numIntervals = currentLevelIntervals.size(); |
| 309 | for (size_t i = 0; i < numIntervals; ++i) { |
| 310 | const auto &interval = currentLevelIntervals[i]; |
| 311 | unsigned intervalSize = interval.second - interval.first + 1; |
| 312 | currentVectorSize = std::max(a: currentVectorSize, b: intervalSize); |
| 313 | } |
| 314 | assert(currentVectorSize > 0 && "vector size must be positive" ); |
| 315 | vectorSizePerLevel[level] = currentVectorSize; |
| 316 | } |
| 317 | } |
| 318 | |
| 319 | void VectorShuffleTreeBuilder::dump() { |
| 320 | LLVM_DEBUG({ |
| 321 | unsigned indLv = 0; |
| 322 | |
| 323 | llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n" ; |
| 324 | ++indLv; |
| 325 | llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n" ; |
| 326 | ++indLv; |
| 327 | for (const auto &toElemsOp : toElemsDefs) |
| 328 | llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp << "\n" ; |
| 329 | llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp << "\n\n" ; |
| 330 | --indLv; |
| 331 | |
| 332 | llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| 333 | << "* Total levels: " << numLevels << "\n" ; |
| 334 | llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| 335 | << "* Vector sizes per level: " ; |
| 336 | llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs()); |
| 337 | llvm::dbgs() << "\n" ; |
| 338 | llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| 339 | << "* Input intervals per level:\n" ; |
| 340 | ++indLv; |
| 341 | for (const auto &[level, intervals] : llvm::enumerate(intervalsPerLevel)) { |
| 342 | llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level |
| 343 | << ": " ; |
| 344 | llvm::interleaveComma(intervals, llvm::dbgs(), |
| 345 | [](const Interval &interval) { |
| 346 | llvm::dbgs() << "[" << interval.first << "," |
| 347 | << interval.second << "]" ; |
| 348 | }); |
| 349 | llvm::dbgs() << "\n" ; |
| 350 | } |
| 351 | }); |
| 352 | } |
| 353 | |
| 354 | /// Compute the shuffle tree configuration for the given `vector.to_elements` + |
| 355 | /// `vector.from_elements` input sequence. This method builds a balanced binary |
| 356 | /// shuffle tree that combines pairs of vectors at each level. |
| 357 | /// |
| 358 | /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| 359 | /// |
| 360 | /// %0:5 = vector.to_elements %a : vector<5xf32> |
| 361 | /// %1:5 = vector.to_elements %b : vector<5xf32> |
| 362 | /// %2:5 = vector.to_elements %c : vector<5xf32> |
| 363 | /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| 364 | /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| 365 | /// |
| 366 | /// build a tree that looks like: |
| 367 | /// |
| 368 | /// %2 %1 %0 %0 |
| 369 | /// \ / \ / |
| 370 | /// %2_1 = vector.shuffle %0_0 = vector.shuffle |
| 371 | /// \ / |
| 372 | /// %2_1_0_0 =vector.shuffle |
| 373 | /// |
| 374 | /// The actual representation of the shuffle tree configuration is based on |
| 375 | /// intervals of each vector at each level of the shuffle tree (i.e., %2, %1, |
| 376 | /// %0, %0, %2_1, %0_0 and %2_1_0_0) and the vector size for each level. For |
| 377 | /// further details on intervals and vector size computation, please, take a |
| 378 | /// look at the corresponding utility functions. |
| 379 | LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { |
| 380 | // Initialize shuffle tree information based on its size. For the number of |
| 381 | // levels, we add one to account for the input `vector.to_elements` as one |
| 382 | // tree level. We need the std::max(1) to account for a single element input. |
| 383 | numLevels = 1u + std::max(a: 1u, b: llvm::Log2_64_Ceil(Value: toElemsDefs.size())); |
| 384 | vectorSizePerLevel.resize(N: numLevels, NV: 0); |
| 385 | intervalsPerLevel.reserve(N: numLevels); |
| 386 | |
| 387 | computeShuffleTreeIntervals(); |
| 388 | computeShuffleTreeVectorSizes(); |
| 389 | dump(); |
| 390 | |
| 391 | return success(); |
| 392 | } |
| 393 | |
| 394 | // ===---------------------------------------------------------------------===// |
| 395 | // Shuffle Tree Code Generation Utilities. |
| 396 | // ===---------------------------------------------------------------------===// |
| 397 | |
| 398 | /// Compute the permutation mask for shuffling two input `vector.to_elements` |
| 399 | /// ops. The permutation mask is the mapping of the vector elements to their |
| 400 | /// final position in the output vector, relative to the intermediate output |
| 401 | /// vector of the `vector.shuffle` operation combining the two inputs. |
| 402 | /// |
| 403 | /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| 404 | /// |
| 405 | /// %0:5 = vector.to_elements %a : vector<5xf32> |
| 406 | /// %1:5 = vector.to_elements %b : vector<5xf32> |
| 407 | /// %2:5 = vector.to_elements %c : vector<5xf32> |
| 408 | /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| 409 | /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| 410 | /// |
| 411 | /// => |
| 412 | /// |
| 413 | /// // Level 1, vector length = 8 |
| 414 | /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] |
| 415 | /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] |
| 416 | /// |
| 417 | /// TODO: Implement mask compression to reduce the number of intermediate poison |
| 418 | /// values. |
| 419 | static SmallVector<int64_t> computePermutationShuffleMask( |
| 420 | ToElementsOp toElementOp0, const Interval &interval0, |
| 421 | ToElementsOp toElementOp1, const Interval &interval1, |
| 422 | FromElementsOp fromElemsOp, unsigned outputVectorSize) { |
| 423 | SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex); |
| 424 | unsigned inputVectorSize = |
| 425 | toElementOp0.getSource().getType().getNumElements(); |
| 426 | |
| 427 | for (const auto &[inputIdx, element] : |
| 428 | llvm::enumerate(First: fromElemsOp.getElements())) { |
| 429 | auto currentToElemOp = cast<ToElementsOp>(Val: element.getDefiningOp()); |
| 430 | // Match `vector.from_elements` operands to the two input ops. |
| 431 | if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1) |
| 432 | continue; |
| 433 | |
| 434 | // The permutation value for a particular operand is the ordinal position of |
| 435 | // the operand in the `vector.to_elements` list of results. |
| 436 | unsigned permVal = cast<OpResult>(Val&: element).getResultNumber(); |
| 437 | unsigned maskIdx = inputIdx; |
| 438 | |
| 439 | // The mask index is the ordinal position of the operand in |
| 440 | // `vector.from_elements` operand list. We make this position relative to |
| 441 | // the output interval resulting from combining the two input intervals. |
| 442 | if (currentToElemOp == toElementOp0) { |
| 443 | maskIdx -= interval0.first; |
| 444 | } else { |
| 445 | // currentToElemOp == toElementOp1 |
| 446 | unsigned intervalOffset = interval1.first - interval0.first; |
| 447 | maskIdx += intervalOffset - interval1.first; |
| 448 | permVal += inputVectorSize; |
| 449 | } |
| 450 | |
| 451 | mask[maskIdx] = permVal; |
| 452 | } |
| 453 | |
| 454 | LLVM_DEBUG({ |
| 455 | unsigned indLv = 1; |
| 456 | llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: [" ; |
| 457 | llvm::interleaveComma(mask, llvm::dbgs()); |
| 458 | llvm::dbgs() << "]\n" ; |
| 459 | ++indLv; |
| 460 | llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| 461 | << "* Combining: " << toElementOp0 << " and " << toElementOp1 |
| 462 | << "\n" ; |
| 463 | }); |
| 464 | |
| 465 | return mask; |
| 466 | } |
| 467 | |
| 468 | /// Compute the propagation shuffle mask for combining two intermediate shuffle |
| 469 | /// operations of the tree. The propagation shuffle mask is the mapping of the |
| 470 | /// intermediate vector elements, which have already been shuffled to their |
| 471 | /// relative output position using the mask generated by |
| 472 | /// `computePermutationShuffleMask`, to their next position in the tree. |
| 473 | /// |
| 474 | /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| 475 | /// |
| 476 | /// %0:5 = vector.to_elements %a : vector<5xf32> |
| 477 | /// %1:5 = vector.to_elements %b : vector<5xf32> |
| 478 | /// %2:5 = vector.to_elements %c : vector<5xf32> |
| 479 | /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| 480 | /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| 481 | /// |
| 482 | /// // Level 1, vector length = 8 |
| 483 | /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] |
| 484 | /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] |
| 485 | /// |
| 486 | /// => |
| 487 | /// |
| 488 | /// // Level 2, vector length = 9 |
| 489 | /// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14] |
| 490 | /// |
| 491 | /// TODO: Implement mask compression to reduce the number of intermediate poison |
| 492 | /// values. |
| 493 | static SmallVector<int64_t> computePropagationShuffleMask( |
| 494 | ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp, |
| 495 | const Interval &rhsInterval, unsigned outputVectorSize) { |
| 496 | ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask(); |
| 497 | ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask(); |
| 498 | unsigned inputVectorSize = lhsShuffleMask.size(); |
| 499 | assert(inputVectorSize == rhsShuffleMask.size() && |
| 500 | "Expected both shuffle masks to have the same size" ); |
| 501 | |
| 502 | bool hasSameInput = lhsShuffleOp == rhsShuffleOp; |
| 503 | unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first; |
| 504 | SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex); |
| 505 | |
| 506 | // Propagate any element from the input mask that is not poison. For the RHS |
| 507 | // vector, offset mask index by the distance between the intervals. |
| 508 | for (unsigned i = 0; i < inputVectorSize; ++i) { |
| 509 | if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex) |
| 510 | mask[i] = i; |
| 511 | |
| 512 | if (hasSameInput) |
| 513 | continue; |
| 514 | |
| 515 | unsigned rhsIdx = i + lhsRhsOffset; |
| 516 | if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) { |
| 517 | assert(rhsIdx < outputVectorSize && "RHS index out of bounds" ); |
| 518 | assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set" ); |
| 519 | mask[rhsIdx] = i + inputVectorSize; |
| 520 | } |
| 521 | } |
| 522 | |
| 523 | LLVM_DEBUG({ |
| 524 | unsigned indLv = 1; |
| 525 | llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| 526 | << "* Propagation shuffle mask computation:\n" ; |
| 527 | ++indLv; |
| 528 | llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| 529 | << "* LHS shuffle op: " << lhsShuffleOp << "\n" ; |
| 530 | llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| 531 | << "* RHS shuffle op: " << rhsShuffleOp << "\n" ; |
| 532 | llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: [" ; |
| 533 | llvm::interleaveComma(mask, llvm::dbgs()); |
| 534 | llvm::dbgs() << "]\n" ; |
| 535 | }); |
| 536 | |
| 537 | return mask; |
| 538 | } |
| 539 | |
| 540 | /// Materialize the pre-computed shuffle tree configuration in the IR by |
| 541 | /// generating the corresponding `vector.shuffle` ops. |
| 542 | /// |
| 543 | /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| 544 | /// |
| 545 | /// %0:5 = vector.to_elements %a : vector<5xf32> |
| 546 | /// %1:5 = vector.to_elements %b : vector<5xf32> |
| 547 | /// %2:5 = vector.to_elements %c : vector<5xf32> |
| 548 | /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| 549 | /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| 550 | /// |
| 551 | /// with the pre-computed shuffle tree configuration: |
| 552 | /// |
| 553 | /// * Vector sizes per level: 0, 8, 9 |
| 554 | /// * Input intervals per level: |
| 555 | /// * Level 0: [0,6], [1,7], [2,8], [2,8] |
| 556 | /// * Level 1: [0,7], [2,8] |
| 557 | /// * Level 2: [0,8] |
| 558 | /// |
| 559 | /// => |
| 560 | /// |
| 561 | /// %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6] |
| 562 | /// : vector<5xf32>, vector<5xf32> |
| 563 | /// %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1] |
| 564 | /// : vector<5xf32>, vector<5xf32> |
| 565 | /// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14] |
| 566 | /// : vector<8xf32>, vector<8xf32> |
| 567 | /// |
| 568 | /// The code generation consists of combining pairs of vectors at each level of |
| 569 | /// the tree, using the pre-computed tree intervals and vector sizes. The |
| 570 | /// algorithm generates two kinds of shuffle masks: |
| 571 | /// * Permutation masks: computed for the first level of the tree and permute |
| 572 | /// the input vector elements to their relative position in the final |
| 573 | /// output. |
| 574 | /// * Propagation masks: computed for subsequent levels and propagate the |
| 575 | /// elements to the next level without permutation. |
| 576 | /// |
| 577 | /// For further details on the shuffle mask computation, please, take a look at |
| 578 | /// the corresponding `computePermutationShuffleMask` and |
| 579 | /// `computePropagationShuffleMask` functions. |
| 580 | /// |
| 581 | Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { |
| 582 | LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n" ); |
| 583 | |
| 584 | // Initialize work list with the `vector.to_elements` sources. |
| 585 | SmallVector<Value> levelInputs; |
| 586 | llvm::transform(Range&: toElemsDefs, d_first: std::back_inserter(x&: levelInputs), |
| 587 | F: [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); }); |
| 588 | |
| 589 | // Build shuffle tree by combining pairs of vectors (represented by their |
| 590 | // corresponding intervals) in one level and producing a new vector with the |
| 591 | // next level's vector length. Skip the interval from the last tree level |
| 592 | // (actual shuffle tree output) as it doesn't have to be combined with |
| 593 | // anything else. |
| 594 | Location loc = fromElemsOp.getLoc(); |
| 595 | unsigned currentLevel = 0; |
| 596 | for (const auto &[nextLevelVectorSize, intervals] : |
| 597 | llvm::zip_equal(t: ArrayRef(vectorSizePerLevel).drop_front(), |
| 598 | u: ArrayRef(intervalsPerLevel).drop_back())) { |
| 599 | |
| 600 | duplicateLastIfOdd(values&: levelInputs); |
| 601 | |
| 602 | LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale) |
| 603 | << "* Processing level " << currentLevel |
| 604 | << " (output vector size: " << nextLevelVectorSize |
| 605 | << ", # inputs: " << levelInputs.size() << ")\n" ); |
| 606 | |
| 607 | // Process level input vectors in pairs. |
| 608 | SmallVector<Value> levelOutputs; |
| 609 | for (size_t i = 0, numLevelInputs = levelInputs.size(); i < numLevelInputs; |
| 610 | i += 2) { |
| 611 | Value lhsVector = levelInputs[i]; |
| 612 | Value rhsVector = levelInputs[i + 1]; |
| 613 | const Interval &lhsInterval = intervals[i]; |
| 614 | const Interval &rhsInterval = intervals[i + 1]; |
| 615 | |
| 616 | // For the first level of the tree, permute the vector elements to their |
| 617 | // relative position in the final output. For subsequent levels, we |
| 618 | // propagate the elements to the next level without permutation. |
| 619 | SmallVector<int64_t> shuffleMask; |
| 620 | if (currentLevel == 0) { |
| 621 | shuffleMask = computePermutationShuffleMask( |
| 622 | toElementOp0: toElemsDefs[i], interval0: lhsInterval, toElementOp1: toElemsDefs[i + 1], interval1: rhsInterval, |
| 623 | fromElemsOp, outputVectorSize: nextLevelVectorSize); |
| 624 | } else { |
| 625 | auto lhsShuffleOp = cast<ShuffleOp>(Val: lhsVector.getDefiningOp()); |
| 626 | auto rhsShuffleOp = cast<ShuffleOp>(Val: rhsVector.getDefiningOp()); |
| 627 | shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval, |
| 628 | rhsShuffleOp, rhsInterval, |
| 629 | outputVectorSize: nextLevelVectorSize); |
| 630 | } |
| 631 | |
| 632 | Value shuffleVal = rewriter.create<vector::ShuffleOp>( |
| 633 | location: loc, args&: lhsVector, args&: rhsVector, args&: shuffleMask); |
| 634 | levelOutputs.push_back(Elt: shuffleVal); |
| 635 | } |
| 636 | |
| 637 | levelInputs = std::move(levelOutputs); |
| 638 | ++currentLevel; |
| 639 | } |
| 640 | |
| 641 | assert(levelInputs.size() == 1 && "Should have exactly one result" ); |
| 642 | return levelInputs.front(); |
| 643 | } |
| 644 | |
| 645 | /// Gather and unique all the `vector.to_elements` operations that feed the |
| 646 | /// `vector.from_elements` operation. The `vector.to_elements` operations are |
| 647 | /// returned in order of appearance in the `vector.from_elements`'s operand |
| 648 | /// list. |
| 649 | static LogicalResult |
| 650 | getToElementsDefiningOps(FromElementsOp fromElemsOp, |
| 651 | SmallVectorImpl<ToElementsOp> &toElemsDefs) { |
| 652 | SetVector<ToElementsOp> toElemsDefsSet; |
| 653 | for (Value element : fromElemsOp.getElements()) { |
| 654 | auto toElemsOp = element.getDefiningOp<ToElementsOp>(); |
| 655 | if (!toElemsOp) |
| 656 | return failure(); |
| 657 | toElemsDefsSet.insert(X: toElemsOp); |
| 658 | } |
| 659 | |
| 660 | toElemsDefs.assign(in_start: toElemsDefsSet.begin(), in_end: toElemsDefsSet.end()); |
| 661 | return success(); |
| 662 | } |
| 663 | |
| 664 | /// Pass to rewrite `vector.to_elements` + `vector.from_elements` sequences into |
| 665 | /// a tree of `vector.shuffle` operations. Only 1-D input vectors are supported |
| 666 | /// for now. |
| 667 | struct ToFromElementsToShuffleTreeRewrite final |
| 668 | : OpRewritePattern<vector::FromElementsOp> { |
| 669 | |
| 670 | using OpRewritePattern::OpRewritePattern; |
| 671 | |
| 672 | LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, |
| 673 | PatternRewriter &rewriter) const override { |
| 674 | VectorType resultType = fromElemsOp.getType(); |
| 675 | if (resultType.getRank() != 1) |
| 676 | return rewriter.notifyMatchFailure( |
| 677 | arg&: fromElemsOp, |
| 678 | msg: "multi-dimensional output vectors are not supported yet" ); |
| 679 | if (resultType.isScalable()) |
| 680 | return rewriter.notifyMatchFailure( |
| 681 | arg&: fromElemsOp, |
| 682 | msg: "'vector.from_elements' does not support scalable vectors" ); |
| 683 | |
| 684 | // Gather all the `vector.to_elements` operations that feed the |
| 685 | // `vector.from_elements` operation. Other op definitions are not supported. |
| 686 | SmallVector<ToElementsOp> toElemsDefs; |
| 687 | if (failed(Result: getToElementsDefiningOps(fromElemsOp, toElemsDefs))) |
| 688 | return rewriter.notifyMatchFailure(arg&: fromElemsOp, msg: "unsupported sources" ); |
| 689 | |
| 690 | if (llvm::any_of(Range&: toElemsDefs, P: [](ToElementsOp toElemsOp) { |
| 691 | return toElemsOp.getSource().getType().getRank() != 1; |
| 692 | })) { |
| 693 | return rewriter.notifyMatchFailure( |
| 694 | arg&: fromElemsOp, msg: "multi-dimensional input vectors are not supported yet" ); |
| 695 | } |
| 696 | |
| 697 | if (llvm::any_of(Range&: toElemsDefs, P: [](ToElementsOp toElemsOp) { |
| 698 | return !toElemsOp.getSource().getType().hasRank(); |
| 699 | })) { |
| 700 | return rewriter.notifyMatchFailure(arg&: fromElemsOp, |
| 701 | msg: "0-D vectors are not supported" ); |
| 702 | } |
| 703 | |
| 704 | // Avoid generating a shuffle tree for trivial `vector.to_elements` -> |
| 705 | // `vector.from_elements` forwarding cases that do not require shuffling. |
| 706 | if (toElemsDefs.size() == 1) { |
| 707 | ToElementsOp toElemsOp0 = toElemsDefs.front(); |
| 708 | if (llvm::equal(LRange: fromElemsOp.getElements(), RRange: toElemsOp0.getResults())) { |
| 709 | return rewriter.notifyMatchFailure( |
| 710 | arg&: fromElemsOp, msg: "trivial forwarding case does not require shuffling" ); |
| 711 | } |
| 712 | } |
| 713 | |
| 714 | VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs); |
| 715 | if (failed(Result: shuffleTreeBuilder.computeShuffleTree())) |
| 716 | return rewriter.notifyMatchFailure(arg&: fromElemsOp, |
| 717 | msg: "failed to compute shuffle tree" ); |
| 718 | |
| 719 | Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter); |
| 720 | rewriter.replaceOp(op: fromElemsOp, newValues: finalShuffle); |
| 721 | return success(); |
| 722 | } |
| 723 | }; |
| 724 | |
| 725 | struct LowerVectorToFromElementsToShuffleTreePass |
| 726 | : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase< |
| 727 | LowerVectorToFromElementsToShuffleTreePass> { |
| 728 | |
| 729 | void runOnOperation() override { |
| 730 | RewritePatternSet patterns(&getContext()); |
| 731 | populateVectorToFromElementsToShuffleTreePatterns(patterns); |
| 732 | |
| 733 | if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) |
| 734 | return signalPassFailure(); |
| 735 | } |
| 736 | }; |
| 737 | |
| 738 | } // namespace |
| 739 | |
| 740 | void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( |
| 741 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 742 | patterns.add<ToFromElementsToShuffleTreeRewrite>(arg: patterns.getContext(), |
| 743 | args&: benefit); |
| 744 | } |
| 745 | |