| 1 | //===- ShardingInterface.cpp -------------------------------------*- C++-*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" |
| 10 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" |
| 11 | |
| 12 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
| 13 | #include "mlir/IR/AffineMap.h" |
| 14 | #include "mlir/IR/IRMapping.h" |
| 15 | #include "mlir/Support/LLVM.h" |
| 16 | #include "llvm/ADT/ArrayRef.h" |
| 17 | #include "llvm/ADT/STLExtras.h" |
| 18 | #include "llvm/ADT/SmallSet.h" |
| 19 | #include "llvm/Support/Debug.h" |
| 20 | |
| 21 | #include <utility> |
| 22 | |
| 23 | #define DEBUG_TYPE "sharding-interface" |
| 24 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
| 25 | |
| 26 | using namespace mlir; |
| 27 | using namespace mlir::mesh; |
| 28 | |
| 29 | #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc" |
| 30 | |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | // common util functions |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | |
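// Check that `expr` is a sum of distinct loop dimensions, each optionally
// multiplied by a constant, e.g. `d0 * 4 + d1`. Dimensions that appear are
// marked in `seenIds`; expressions such as `d0 + d0` or `d0 * d1` are
// rejected.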
| 35 | static LogicalResult |
| 36 | checkOperandAffineExprRecursively(AffineExpr expr, |
| 37 | SmallVectorImpl<bool> &seenIds) { |
| 38 | switch (expr.getKind()) { |
| 39 | case AffineExprKind::Add: { |
| 40 | auto binOpExpr = cast<AffineBinaryOpExpr>(expr); |
| 41 | AffineExpr lhs = binOpExpr.getLHS(); |
| 42 | AffineExpr rhs = binOpExpr.getRHS(); |
    if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
      return failure();
    if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
| 46 | return failure(); |
| 47 | return success(); |
| 48 | } |
| 49 | case AffineExprKind::Mul: { |
| 50 | auto binOpExpr = cast<AffineBinaryOpExpr>(expr); |
| 51 | AffineExpr lhs = binOpExpr.getLHS(); |
| 52 | AffineExpr rhs = binOpExpr.getRHS(); |
| 53 | AffineExpr dimExpr; |
| 54 | if (lhs.getKind() == AffineExprKind::DimId && |
| 55 | rhs.getKind() == AffineExprKind::Constant) { |
| 56 | dimExpr = lhs; |
| 57 | } else if (rhs.getKind() == AffineExprKind::DimId && |
| 58 | lhs.getKind() == AffineExprKind::Constant) { |
| 59 | dimExpr = rhs; |
| 60 | } else { |
| 61 | return failure(); |
| 62 | } |
| 63 | unsigned position = cast<AffineDimExpr>(dimExpr).getPosition(); |
| 64 | if ((size_t)position >= seenIds.size() || seenIds[position]) |
| 65 | return failure(); |
| 66 | seenIds[position] = true; |
| 67 | return success(); |
| 68 | } |
| 69 | case AffineExprKind::DimId: { |
| 70 | unsigned position = cast<AffineDimExpr>(expr).getPosition(); |
| 71 | if ((size_t)position >= seenIds.size() || seenIds[position]) |
| 72 | return failure(); |
| 73 | seenIds[position] = true; |
| 74 | return success(); |
| 75 | } |
| 76 | default: |
| 77 | return failure(); |
| 78 | } |
| 79 | } |
| 80 | |
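// Return the set of loop-dimension positions that appear in `expr`, or failure
// if `expr` is not of the form accepted by
// `checkOperandAffineExprRecursively`.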
| 81 | static FailureOr<llvm::SmallSet<unsigned, 2>> |
| 82 | checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { |
| 83 | SmallVector<bool> seenIds(numDims, false); |
  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
| 85 | return failure(); |
| 86 | |
| 87 | llvm::SmallSet<unsigned, 2> positions; |
  for (auto it : llvm::enumerate(seenIds)) {
    if (it.value())
      positions.insert((unsigned)it.index());
| 91 | } |
| 92 | return positions; |
| 93 | } |
| 94 | |
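// Convert an array of mesh-axis vectors into the equivalent list of
// `MeshAxesAttr`.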
| 95 | template <typename T> |
| 96 | SmallVector<MeshAxesAttr> |
| 97 | fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) { |
| 98 | SmallVector<MeshAxesAttr> res; |
| 99 | for (const auto &v : vec) { |
| 100 | res.emplace_back(MeshAxesAttr::get(ctxt, v)); |
| 101 | } |
| 102 | return res; |
| 103 | } |
| 104 | |
| 105 | //===----------------------------------------------------------------------===// |
| 106 | // mesh::getMeshSharding |
| 107 | //===----------------------------------------------------------------------===// |
| 108 | |
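// Retrieve the sharding annotation attached to `result` via its `mesh.shard`
// users. The boolean in the returned pair is true when the annotation was made
// for the users of the value (unit attr `annotate_for_users`) and false when
// it was made for the defining op.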
| 109 | FailureOr<std::pair<bool, MeshSharding>> |
| 110 | mesh::getMeshSharding(OpResult result) { |
  Value val = cast<Value>(result);
  bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
| 113 | auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); |
| 114 | if (!shardOp) |
| 115 | return false; |
| 116 | return !shardOp.getAnnotateForUsers(); |
| 117 | }); |
| 118 | |
| 119 | if (anyShardedForDef) { |
    // Expected to have exactly one use if it has a use of `mesh.shard` without
    // unit attr annotate_for_users.
    if (!val.hasOneUse())
      return failure();
    auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
    return std::make_pair(false, MeshSharding(shardOp.getSharding()));
| 126 | } |
| 127 | |
  bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
| 129 | auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); |
| 130 | if (!shardOp) |
| 131 | return false; |
| 132 | return shardOp.getAnnotateForUsers(); |
| 133 | }); |
| 134 | if (anyShardedForUsers) { |
| 135 | SmallVector<ShardOp> shardOps; |
| 136 | for (Operation *user : val.getUsers()) { |
| 137 | ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); |
| 138 | if (shardOp) |
| 139 | shardOps.push_back(shardOp); |
| 140 | } |
| 141 | MeshSharding shardForDef = shardOps[0].getSharding(); |
| 142 | for (size_t i = 1; i < shardOps.size(); ++i) { |
| 143 | // TODO: Deduce a reasonable mesh sharding attr for def when they are |
| 144 | // different |
      assert(shardForDef == shardOps[i].getSharding() &&
             "only shard ops with the same mesh sharding attr are supported");
    }
    return std::make_pair(true, shardForDef);
| 149 | } |
| 150 | return failure(); |
| 151 | } |
| 152 | |
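// Retrieve the sharding annotation attached to `opOperand`. This succeeds only
// when the operand is produced by a `mesh.shard` op; the returned pair holds
// whether that op is annotated for users together with its sharding.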
| 153 | FailureOr<std::pair<bool, MeshSharding>> |
| 154 | mesh::getMeshSharding(OpOperand &opOperand) { |
| 155 | Value val = opOperand.get(); |
| 156 | if (ShardOp shardOp = val.getDefiningOp<ShardOp>()) |
| 157 | return std::make_pair(shardOp.getAnnotateForUsers(), |
| 158 | MeshSharding(shardOp.getSharding())); |
| 159 | |
| 160 | return failure(); |
| 161 | } |
| 162 | |
| 163 | //===----------------------------------------------------------------------===// |
| 164 | // ShardingInterface::verifyShardingInterfaceImpl |
| 165 | //===----------------------------------------------------------------------===// |
| 166 | |
| 167 | LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { |
| 168 | Operation *op = getOperation(); |
| 169 | |
  // Check operand and result types.
| 171 | for (Type type : op->getOperandTypes()) |
| 172 | if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat()) |
| 173 | return failure(); |
| 174 | for (Type type : op->getResultTypes()) |
| 175 | if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat()) |
| 176 | return failure(); |
| 177 | |
| 178 | // check maps |
| 179 | SmallVector<AffineMap> maps = getIndexingMaps(); |
| 180 | if (maps.empty()) |
| 181 | return failure(); |
| 182 | unsigned numOperands = op->getNumOperands(); |
| 183 | unsigned numResults = op->getNumResults(); |
| 184 | if (numOperands + numResults != maps.size()) |
| 185 | return failure(); |
| 186 | |
| 187 | for (OpResult result : op->getResults()) { |
| 188 | auto resultType = dyn_cast<RankedTensorType>(result.getType()); |
| 189 | if (!resultType) |
| 190 | return failure(); |
| 191 | AffineMap map = maps[numOperands + result.getResultNumber()]; |
| 192 | if (!map.isProjectedPermutation()) { |
| 193 | return failure(); |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | return success(); |
| 198 | } |
| 199 | |
| 200 | //===----------------------------------------------------------------------===// |
| 201 | // ShardingInterface::printLoopTypesAndIndexingMaps |
| 202 | //===----------------------------------------------------------------------===// |
| 203 | |
| 204 | void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { |
| 205 | os << "print loop types and indexing maps for: \n" ; |
| 206 | getOperation()->print(os); |
| 207 | os << "\n" ; |
| 208 | os << "loop types: [" ; |
| 209 | for (utils::IteratorType type : getLoopIteratorTypes()) { |
| 210 | os << stringifyEnum(type) << " " ; |
| 211 | } |
| 212 | os << "]\n" ; |
| 213 | os << "indexing maps: \n" ; |
| 214 | for (AffineMap map : getIndexingMaps()) |
| 215 | os << map << "\n" ; |
| 216 | os << "\n" ; |
| 217 | } |
| 218 | |
| 219 | //===----------------------------------------------------------------------===// |
| 220 | // detail::defaultGetShardingOption |
| 221 | //===----------------------------------------------------------------------===// |
| 222 | |
| 223 | namespace { |
| 224 | |
| 225 | // Update the given `shardingOption` according to `meshAxes` and `loopIdx` |
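// Fail when the assignment conflicts with what has already been recorded: a
// different mesh, different mesh axes for `loopIdx`, or a mesh axis that is
// already assigned to another loop iterator.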
| 226 | static LogicalResult fillShardingOption(Operation *op, |
| 227 | ShardingOption &shardingOption, |
| 228 | FlatSymbolRefAttr mesh, |
| 229 | ArrayRef<MeshAxis> meshAxes, |
| 230 | unsigned loopIdx) { |
| 231 | if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) || |
| 232 | (!shardingOption.shardingArray[loopIdx].empty() && |
| 233 | shardingOption.shardingArray[loopIdx] != meshAxes)) { |
| 234 | LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator " |
| 235 | << loopIdx << "\n" ); |
| 236 | return failure(); |
| 237 | } |
| 238 | for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) { |
| 239 | if (i == loopIdx) |
| 240 | continue; |
| 241 | |
| 242 | for (MeshAxis axis : meshAxes) { |
      if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
        LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
                          << axis << " duplicate\n");
| 246 | return failure(); |
| 247 | } |
| 248 | } |
| 249 | } |
| 250 | if (mesh) |
| 251 | shardingOption.mesh = mesh; |
| 252 | if (shardingOption.shardingArray[loopIdx].empty()) |
    shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
                                                 meshAxes.end());
| 255 | return success(); |
| 256 | } |
| 257 | |
| 258 | } // namespace |
| 259 | |
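// Derive a `ShardingOption` for `op` from the sharding annotations of its
// operands and results: result annotations are processed first, then operand
// annotations, and conflicts between them cause failure.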
| 260 | FailureOr<ShardingOption> |
| 261 | mesh::detail::defaultGetShardingOption(Operation *op, |
| 262 | ArrayRef<MeshSharding> operandShardings, |
| 263 | ArrayRef<MeshSharding> resultShardings) { |
| 264 | ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); |
| 265 | ShardingOption shardingOption; |
| 266 | |
| 267 | if (failed(shardingOp.verifyShardingInterfaceImpl())) |
| 268 | return op->emitOpError() << "invalid sharding interface implementation" ; |
| 269 | SmallVector<utils::IteratorType> loopTypes = |
| 270 | shardingOp.getLoopIteratorTypes(); |
| 271 | SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); |
| 272 | unsigned numOperands = op->getNumOperands(); |
| 273 | shardingOption.shardingArray.resize(loopTypes.size()); |
| 274 | llvm::SmallVector<MeshAxis> partialMeshAxes; |
| 275 | llvm::SmallSet<unsigned, 4> visitedLoopIndices; |
| 276 | bool anyShardingInResultsOrOperands = false; |
| 277 | |
| 278 | // 1. Fill sharding option based on op results |
  for (auto shardingIt : llvm::enumerate(resultShardings)) {
| 280 | MeshSharding shardAttr = shardingIt.value(); |
| 281 | if (!shardAttr) |
| 282 | continue; |
| 283 | AffineMap map = maps[numOperands + shardingIt.index()]; |
| 284 | anyShardingInResultsOrOperands = true; |
| 285 | if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) { |
| 286 | shardingOption.mesh = shardAttr.getMeshAttr(); |
| 287 | } else { |
| 288 | // Handle the split axes: calculate the corresponding loop index for each |
| 289 | // split axes sub-array, and then store the sub-array to |
| 290 | // shardingOption[index] |
| 291 | for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { |
| 292 | AffineExpr expr = std::get<0>(it); |
| 293 | ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); |
| 294 | auto dim = cast<AffineDimExpr>(expr); |
| 295 | unsigned index = dim.getPosition(); |
| 296 | visitedLoopIndices.insert(index); |
| 297 | if (failed(fillShardingOption(op, shardingOption, |
| 298 | shardAttr.getMeshAttr(), axes, index))) |
| 299 | return failure(); |
| 300 | } |
| 301 | } |
| 302 | |
| 303 | // Handle the partial axes: at this stage, the exact loop index/indices |
| 304 | // cannot be decided because there could be multiple reduction loops. |
| 305 | ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes(); |
| 306 | if (!partialAxes.empty()) { |
| 307 | if (!partialMeshAxes.empty()) |
| 308 | return op->emitOpError() << "at most one result with partial axes is " |
| 309 | "supported at present" ; |
| 310 | partialMeshAxes.append(in_start: partialAxes.begin(), in_end: partialAxes.end()); |
| 311 | // Add all the reduction loop indices to `visitedLoopIndices` if |
| 312 | // `partialAxes` is not empty |
| 313 | for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) { |
| 314 | if (isReductionLoop(loopTypes[loopIdx])) |
          visitedLoopIndices.insert(loopIdx);
| 316 | } |
| 317 | } |
| 318 | } |
| 319 | |
| 320 | // 2. Fill sharding option based on operands |
  for (auto shardingIt : llvm::enumerate(operandShardings)) {
| 322 | MeshSharding shardAttr = shardingIt.value(); |
| 323 | if (!shardAttr) |
| 324 | continue; |
| 325 | |
    anyShardingInResultsOrOperands |= !shardAttr.getSplitAxes().empty();
| 327 | AffineMap map = maps[shardingIt.index()]; |
| 328 | unsigned numDims = map.getNumDims(); |
| 329 | |
| 330 | // Handle the split axes. Partial axes don't need to be handled because they |
| 331 | // only affect the defining op of the operand. |
| 332 | // |
| 333 | // TODO: Change to process the operands with single loop index first and |
| 334 | // then the operands with multiple loop indices. |
| 335 | for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { |
| 336 | AffineExpr expr = std::get<0>(it); |
| 337 | ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); |
| 338 | FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices = |
| 339 | checkOperandAffineExpr(expr, numDims); |
| 340 | if (failed(loopIndices)) |
        return op->emitOpError()
               << "operand's affine expression is restricted to const_i * "
                  "dim_i + const_j * dim_j + ...";
| 344 | if (loopIndices->empty()) |
| 345 | continue; |
| 346 | if (loopIndices->size() == 1) { |
| 347 | unsigned loopIdx = *loopIndices->begin(); |
| 348 | visitedLoopIndices.insert(loopIdx); |
| 349 | if (failed(fillShardingOption(op, shardingOption, |
| 350 | shardAttr.getMeshAttr(), axes, loopIdx))) |
| 351 | return failure(); |
| 352 | } |
| 353 | // If multiple loop indices correspond to a dimension of an operand, it is |
| 354 | // difficult to infer which loop indices are responsible for sharding. |
| 355 | // Therefore, the exact loop index must be specified by others. |
| 356 | if (loopIndices->size() > 1) { |
| 357 | bool seenLoopIndices = false; |
| 358 | for (unsigned loopIdx : *loopIndices) { |
| 359 | if (visitedLoopIndices.contains(loopIdx)) { |
| 360 | seenLoopIndices = true; |
| 361 | break; |
| 362 | } |
| 363 | } |
| 364 | if (!seenLoopIndices) |
          return op->emitOpError()
                 << "the operand " << shardingIt.index()
                 << " has multiple loop indices in a dimension, but none of "
                    "them could be found in the exactly specified annotation "
                    "of op results or operands.";
| 370 | } |
| 371 | } |
| 372 | } |
| 373 | |
| 374 | // 3. Finalize sharding option |
| 375 | if (!partialMeshAxes.empty()) { |
    bool anyNonEmptyReductionLoop = llvm::any_of(
        llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
| 378 | SmallVector<MeshAxis> &subArray = it.value(); |
| 379 | int64_t idx = it.index(); |
| 380 | return isReductionLoop(loopTypes[idx]) && !subArray.empty(); |
| 381 | }); |
| 382 | if (!anyNonEmptyReductionLoop) { |
| 383 | bool filled = false; |
| 384 | for (size_t idx = 0; idx < loopTypes.size(); ++idx) { |
| 385 | if (isReductionLoop(loopTypes[idx])) { |
| 386 | std::ignore = fillShardingOption(op, shardingOption, nullptr, |
| 387 | partialMeshAxes, idx); |
| 388 | filled = true; |
| 389 | break; |
| 390 | } |
| 391 | } |
| 392 | if (!filled) |
| 393 | return op->emitOpError() << "no matched reduction loop found for the " |
| 394 | "result's partial type" ; |
| 395 | } |
| 396 | } |
  removeTrailingEmptySubArray(shardingOption.shardingArray);
| 398 | if (!anyShardingInResultsOrOperands) |
| 399 | shardingOption.empty = true; |
| 400 | return shardingOption; |
| 401 | } |
| 402 | |
// Get the sharding attribute for the given result and sharding option.
| 404 | MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, |
| 405 | AffineMap map, ArrayRef<utils::IteratorType> loopTypes, |
| 406 | ArrayRef<ReductionKind> reductionLoopKinds) { |
| 407 | auto resultType = cast<RankedTensorType>(result.getType()); |
| 408 | SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank()); |
| 409 | SmallVector<MeshAxis> partialAxes; |
| 410 | |
| 411 | // process the split axes |
  for (auto it : llvm::enumerate(map.getResults())) {
| 413 | AffineExpr expr = it.value(); |
| 414 | // `expr` must be an `AffineDimExpr` because `map` is verified by |
| 415 | // isProjectedPermutation |
    auto dim = cast<AffineDimExpr>(expr);
    unsigned loopIdx = dim.getPosition();
    if (loopIdx < shardingOption.shardingArray.size())
      splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
| 420 | } |
| 421 | |
| 422 | // process the partial axes |
| 423 | // partialType will be ignored if partialAxes is empty |
| 424 | ReductionKind partialType = ReductionKind::Sum; |
| 425 | size_t reductionLoopKindsIdx = 0; |
| 426 | for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) { |
| 427 | utils::IteratorType iType = std::get<0>(it); |
| 428 | if (isReductionLoop(iType)) { |
| 429 | ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx]; |
| 430 | ++reductionLoopKindsIdx; |
| 431 | if (!partialAxes.empty()) |
        assert(partialType == curPartialType &&
               "Only one reduction type is supported");
| 434 | partialType = curPartialType; |
| 435 | const SmallVector<MeshAxis> &axis = std::get<1>(it); |
| 436 | partialAxes.append(axis); |
| 437 | } |
| 438 | } |
| 439 | |
  removeTrailingEmptySubArray(splitAxes);
| 441 | return MeshSharding::get(shardingOption.mesh, |
| 442 | fromArrayOfVector(result.getContext(), splitAxes), |
| 443 | partialAxes, partialType); |
| 444 | } |
| 445 | |
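// Get the sharding for the given operand implied by `shardingOption` and the
// operand's indexing map. Scalar (int/index/float) operands get an empty
// sharding and 0-d ranked tensors are replicated on the whole mesh.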
| 446 | static FailureOr<MeshSharding> getSharding(OpOperand &opOperand, |
| 447 | const ShardingOption &shardingOption, |
| 448 | AffineMap map) { |
| 449 | Value operandValue = opOperand.get(); |
| 450 | auto operandType = dyn_cast<RankedTensorType>(operandValue.getType()); |
| 451 | if (!operandType) { |
| 452 | if (operandValue.getType().isIntOrIndexOrFloat()) |
| 453 | return MeshSharding(); |
| 454 | return failure(); |
| 455 | } |
| 456 | // 0d tensors cannot be sharded and must get replicated |
| 457 | if (operandType.getRank() == 0) { |
| 458 | return MeshSharding(shardingOption.mesh); |
| 459 | } |
| 460 | SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank()); |
| 461 | unsigned numDims = map.getNumDims(); |
  for (auto it : llvm::enumerate(map.getResults())) {
| 463 | int64_t idx = it.index(); |
| 464 | AffineExpr expr = it.value(); |
| 465 | FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices = |
| 466 | checkOperandAffineExpr(expr, numDims); |
    if (failed(loopIndices))
| 468 | return failure(); |
| 469 | SmallVector<unsigned> shardedLoopIndices; |
| 470 | for (unsigned loopIdx : *loopIndices) { |
| 471 | if ((size_t)loopIdx < shardingOption.shardingArray.size() && |
| 472 | !shardingOption.shardingArray[loopIdx].empty()) |
        shardedLoopIndices.push_back(loopIdx);
| 474 | } |
    // At most one sharded loop index is accepted.
| 476 | if (shardedLoopIndices.size() > 1) |
| 477 | return failure(); |
| 478 | if (shardedLoopIndices.size() == 1) { |
      splitAxes[idx].append(
          shardingOption.shardingArray[shardedLoopIndices[0]]);
| 481 | } |
| 482 | } |
| 483 | |
  removeTrailingEmptySubArray(splitAxes);
| 485 | return MeshSharding::get( |
| 486 | shardingOption.mesh, |
| 487 | fromArrayOfVector(opOperand.get().getContext(), splitAxes)); |
| 488 | } |
| 489 | |
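// Compute the sharding annotations implied by `shardingOption` for every
// operand and result of `op`, in operand-then-result order.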
| 490 | FailureOr<std::vector<MeshSharding>> |
| 491 | mesh::detail::defaultGetShardingAnnotations( |
| 492 | Operation *op, const ShardingOption &shardingOption) { |
| 493 | std::vector<MeshSharding> res; |
| 494 | |
| 495 | ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); |
| 496 | SmallVector<utils::IteratorType> loopTypes = |
| 497 | shardingOp.getLoopIteratorTypes(); |
| 498 | SmallVector<ReductionKind> reductionKinds = |
| 499 | shardingOp.getReductionLoopIteratorKinds(); |
| 500 | SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); |
| 501 | unsigned numOperands = op->getNumOperands(); |
| 502 | |
| 503 | for (OpOperand &opOperand : op->getOpOperands()) { |
| 504 | FailureOr<MeshSharding> shardingAttr = getSharding( |
| 505 | opOperand, shardingOption, maps[opOperand.getOperandNumber()]); |
    if (failed(shardingAttr))
| 507 | return failure(); |
| 508 | res.push_back(*shardingAttr); |
| 509 | } |
| 510 | |
| 511 | for (OpResult result : op->getResults()) { |
| 512 | res.push_back(getSharding(result, shardingOption, |
| 513 | maps[numOperands + result.getResultNumber()], |
| 514 | loopTypes, reductionKinds)); |
| 515 | } |
| 516 | |
| 517 | return res; |
| 518 | } |
| 519 | |
| 520 | //===----------------------------------------------------------------------===// |
| 521 | // detail::defaultAddShardingAnnotations |
| 522 | //===----------------------------------------------------------------------===// |
| 523 | |
// Add a `mesh.shard` op for the given result, based on the details provided in
// `shardingOption`, `map`, and `loopTypes`.
| 526 | static LogicalResult addShardOp(OpBuilder &b, OpResult result, |
| 527 | const ShardingOption &shardingOption, |
| 528 | AffineMap map, |
| 529 | ArrayRef<utils::IteratorType> loopTypes, |
| 530 | ArrayRef<ReductionKind> reductionLoopKinds) { |
| 531 | MeshSharding sharding = |
| 532 | getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds); |
  maybeInsertTargetShardingAnnotation(sharding, result, b);
| 534 | |
| 535 | return success(); |
| 536 | } |
| 537 | |
// Add a `mesh.shard` op for the given operand, based on the details provided
// in `shardingOption` and `map`.
| 540 | static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, |
| 541 | const ShardingOption &shardingOption, |
| 542 | AffineMap map) { |
| 543 | |
| 544 | FailureOr<MeshSharding> sharding = |
| 545 | getSharding(opOperand, shardingOption, map); |
  if (failed(sharding)) {
| 547 | return failure(); |
| 548 | } |
| 549 | OpBuilder::InsertionGuard guard(b); |
| 550 | maybeInsertSourceShardingAnnotation(sharding.value(), opOperand, b); |
| 551 | |
| 552 | return success(); |
| 553 | } |
| 554 | |
| 555 | LogicalResult mesh::detail::defaultAddShardingAnnotations( |
| 556 | Operation *op, OpBuilder &b, const ShardingOption &shardingOption) { |
| 557 | assert(!shardingOption.empty && shardingOption.mesh); |
| 558 | |
| 559 | ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); |
| 560 | SmallVector<utils::IteratorType> loopTypes = |
| 561 | shardingOp.getLoopIteratorTypes(); |
| 562 | SmallVector<ReductionKind> reductionKinds = |
| 563 | shardingOp.getReductionLoopIteratorKinds(); |
| 564 | SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); |
| 565 | unsigned numOperands = op->getNumOperands(); |
| 566 | |
| 567 | // 1. add mesh.shard ops for all op results |
| 568 | for (OpResult result : op->getResults()) { |
| 569 | if (failed(addShardOp(b, result, shardingOption, |
| 570 | maps[numOperands + result.getResultNumber()], |
| 571 | loopTypes, reductionKinds))) |
| 572 | return failure(); |
| 573 | } |
| 574 | |
| 575 | // 2. add mesh.shard ops for all operands |
| 576 | for (OpOperand &opOperand : op->getOpOperands()) { |
    if (failed(addShardOp(b, opOperand, shardingOption,
                          maps[opOperand.getOperandNumber()])))
| 579 | return failure(); |
| 580 | } |
| 581 | |
| 582 | return success(); |
| 583 | } |
| 584 | |
| 585 | #ifndef NDEBUG |
| 586 | static bool |
| 587 | isValueCompatibleWithFullReplicationSharding(Value value, |
| 588 | MeshSharding sharding) { |
  if (isa<RankedTensorType>(value.getType())) {
| 590 | return isFullReplication(sharding); |
| 591 | } |
| 592 | |
| 593 | return !sharding; |
| 594 | } |
| 595 | |
template <typename ValueRange, typename MeshShardingRange>
static bool
areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
                                                MeshShardingRange &&shardings) {
| 600 | if (std::size(values) != std::size(shardings)) { |
| 601 | return false; |
| 602 | } |
| 603 | return llvm::all_of( |
| 604 | llvm::zip_equal(std::forward<ValueRange>(values), |
                      std::forward<MeshShardingRange>(shardings)),
| 606 | [](auto valueAndSharding) { |
| 607 | return isValueCompatibleWithFullReplicationSharding( |
| 608 | std::get<0>(valueAndSharding), std::get<1>(valueAndSharding)); |
| 609 | }); |
| 610 | } |
| 611 | #endif // NDEBUG |
| 612 | |
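// Spmdize an operation that is fully replicated: every ranked-tensor operand
// and result must carry a full-replication sharding, so the op can simply be
// cloned.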
| 613 | void mesh::spmdizeFullyReplicatedOperation( |
| 614 | Operation &op, ArrayRef<Value> spmdizedOperands, |
| 615 | ArrayRef<MeshSharding> operandShardings, |
| 616 | ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, |
| 617 | SymbolTableCollection &symbolTable, OpBuilder &builder) { |
| 618 | assert(spmdizedOperands.size() == operandShardings.size()); |
| 619 | assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(), |
| 620 | operandShardings)); |
| 621 | assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(), |
| 622 | resultShardings)); |
| 623 | // `clone` will populate the mapping of old to new results. |
  builder.clone(op, spmdizationMap);
| 625 | } |
| 626 | |
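// Record `meshAxesAssignmentForTensorAxis` as the mesh-axis assignment of the
// loop iterator referenced by `indexingExpr`. If the iterator already has an
// assignment it must be identical (checked by assertion).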
| 627 | static void updateMeshAxisAssignmentForLoopIterators( |
| 628 | ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, |
| 629 | SmallVector<std::optional<SmallVector<MeshAxis>>> |
| 630 | &meshAxesAssignmentForLoopIterators) { |
  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
| 632 | unsigned loopIteratorIdx = affineDimExpr.getPosition(); |
| 633 | if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) { |
| 634 | assert(llvm::equal(meshAxesAssignmentForTensorAxis, |
| 635 | *meshAxesAssignmentForLoopIterators[loopIteratorIdx])); |
| 636 | } else { |
| 637 | meshAxesAssignmentForLoopIterators[loopIteratorIdx] = |
        llvm::to_vector(meshAxesAssignmentForTensorAxis);
| 639 | } |
| 640 | } |
| 641 | |
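// For each loop iterator, collect the mesh axes it is sharded along, derived
// from the shardings of the values whose indexing maps reference it. Iterators
// with no assignment get an empty axis list.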
| 642 | ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( |
| 643 | ArrayRef<MeshSharding> operandShardings, |
| 644 | ArrayRef<MeshSharding> resultShardings, |
| 645 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
| 646 | ArrayRef<AffineMap> indexingMaps) { |
| 647 | SmallVector<std::optional<SmallVector<MeshAxis>>> |
| 648 | meshAxisAssignmentForLoopIterators(loopIteratorTypes.size()); |
| 649 | std::vector<MeshSharding> operatorAndResultShardings; |
  operatorAndResultShardings.reserve(operandShardings.size() +
                                     resultShardings.size());
  llvm::append_range(operatorAndResultShardings, operandShardings);
  for (auto [sharding, affineMap] :
       llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
| 655 | if (!sharding) { |
| 656 | continue; |
| 657 | } |
| 658 | for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] : |
         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
| 660 | updateMeshAxisAssignmentForLoopIterators( |
| 661 | meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, |
| 662 | meshAxisAssignmentForLoopIterators); |
| 663 | } |
| 664 | // Missing trailing split axes means replication on those tensor dimensions. |
| 665 | for (unsigned i = sharding.getSplitAxes().size(); |
| 666 | i < affineMap.getNumResults(); ++i) { |
      updateMeshAxisAssignmentForLoopIterators(
          {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
| 669 | } |
| 670 | } |
| 671 | |
| 672 | ShardingArray res; |
  llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
                  [](std::optional<SmallVector<MeshAxis>> &axes) {
                    if (!axes) {
                      return SmallVector<MeshAxis>();
                    }
                    return std::move(*axes);
                  });
| 680 | return res; |
| 681 | } |
| 682 | |
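// Return true if any reduction loop iterator is assigned at least one mesh
// axis.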
| 683 | bool mesh::isAtLeastOneReductionIteratorSharded( |
| 684 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
| 685 | ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { |
| 686 | for (auto [loopIteratorType, meshAxisAssignment] : |
| 687 | llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { |
| 688 | if (loopIteratorType == utils::IteratorType::reduction && |
| 689 | !meshAxisAssignment.empty()) { |
| 690 | return true; |
| 691 | } |
| 692 | } |
| 693 | return false; |
| 694 | } |
| 695 | |
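// Collect the mesh axes assigned to reduction loop iterators, in iterator
// order.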
| 696 | SmallVector<MeshAxis> mesh::getReductionMeshAxes( |
| 697 | ArrayRef<utils::IteratorType> loopIteratorTypes, |
| 698 | ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { |
| 699 | SmallVector<MeshAxis> meshAxes; |
| 700 | for (auto [loopIteratorType, meshAxisAssignment] : |
| 701 | llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { |
| 702 | if (loopIteratorType == utils::IteratorType::reduction) { |
| 703 | llvm::append_range(meshAxes, meshAxisAssignment); |
| 704 | } |
| 705 | } |
| 706 | return meshAxes; |
| 707 | } |
| 708 | |
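// Spmdize an operation whose sharded form is the op itself: clone it and
// rewrite each result type to the corresponding sharded (local) tensor type.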
| 709 | void mesh::spmdizeTriviallyShardableOperation( |
| 710 | Operation &op, ArrayRef<Value> spmdizedOperands, |
| 711 | ArrayRef<MeshSharding> operandShardings, |
| 712 | ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, |
| 713 | SymbolTableCollection &symbolTable, OpBuilder &builder) { |
| 714 | // `clone` will populate the mapping of old to new results. |
  Operation *newOp = builder.clone(op, spmdizationMap);
| 716 | // Set the result types to the sharded counterparts. |
| 717 | for (auto [oldResult, newResult, sharding] : |
       llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
| 719 | newResult.setType(shardType( |
| 720 | newResult.getType(), |
| 721 | getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding)); |
| 722 | } |
| 723 | } |
| 724 | |