//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/Mesh/IR/MeshDialect.h" |
13 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
14 | #include "mlir/IR/Attributes.h" |
15 | #include "mlir/IR/BuiltinAttributes.h" |
16 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/DialectImplementation.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/Location.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | #include "mlir/IR/TypeUtilities.h" |
24 | #include "mlir/IR/Value.h" |
25 | #include "mlir/Interfaces/ViewLikeInterface.h" |
26 | #include "mlir/Support/LLVM.h" |
27 | #include "mlir/Transforms/InliningUtils.h" |
28 | #include "llvm/ADT/ArrayRef.h" |
29 | #include "llvm/ADT/STLExtras.h" |
30 | #include "llvm/ADT/SmallSet.h" |
31 | #include "llvm/ADT/SmallVector.h" |
32 | #include "llvm/ADT/TypeSwitch.h" |
33 | #include "llvm/Support/Casting.h" |
34 | #include <algorithm> |
35 | #include <functional> |
36 | #include <iterator> |
37 | #include <numeric> |
38 | #include <optional> |
39 | #include <utility> |
40 | |
41 | #define DEBUG_TYPE "mesh-ops" |
42 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
43 | |
44 | using namespace mlir; |
45 | using namespace mlir::mesh; |
46 | |
47 | #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc" |
48 | |
49 | namespace { |
50 | |
51 | struct DimensionSize { |
52 | static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); } |
53 | DimensionSize(int64_t val) : val(val) {} |
54 | int64_t value() const { return val; } |
55 | operator int64_t() const { return val; } |
56 | bool isDynamic() const { return ShapedType::isDynamic(val); } |
57 | |
58 | private: |
59 | int64_t val; |
60 | }; |
61 | |
62 | } // namespace |
63 | |
64 | static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) { |
65 | if (lhs.isDynamic() || rhs.isDynamic()) { |
66 | return DimensionSize::dynamic(); |
67 | } |
68 | return lhs.value() / rhs.value(); |
69 | } |
70 | |
71 | static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) { |
72 | if (lhs.isDynamic() || rhs.isDynamic()) { |
73 | return DimensionSize::dynamic(); |
74 | } |
75 | return lhs.value() * rhs.value(); |
76 | } |
77 | |
78 | //===----------------------------------------------------------------------===// |
79 | // Inliner |
80 | //===----------------------------------------------------------------------===// |
81 | |
82 | namespace { |
83 | struct MeshInlinerInterface : public DialectInlinerInterface { |
84 | using DialectInlinerInterface::DialectInlinerInterface; |
85 | // Currently no restrictions are encoded for inlining. |
86 | bool isLegalToInline(Operation *, Operation *, bool) const final { |
87 | return true; |
88 | } |
89 | bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { |
90 | return true; |
91 | } |
92 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
93 | return true; |
94 | } |
95 | }; |
96 | } // namespace |
97 | |
98 | //===----------------------------------------------------------------------===// |
99 | // Mesh dialect |
100 | //===----------------------------------------------------------------------===// |
101 | |
/// Registers what the mesh dialect provides: the operations, attributes and
/// types listed in the generated `.cpp.inc` files, plus the inliner interface.
void MeshDialect::initialize() {
  // Each GET_*_LIST macro is defined immediately before the include it
  // applies to; the included text becomes the template argument list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
  addInterface<MeshInlinerInterface>();
}
117 | |
118 | Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, |
119 | Type type, Location loc) { |
120 | return arith::ConstantOp::materialize(builder, value, type, loc); |
121 | } |
122 | |
123 | //===----------------------------------------------------------------------===// |
124 | // Mesh utilities |
125 | //===----------------------------------------------------------------------===// |
126 | |
127 | static FailureOr<MeshOp> getMeshAndVerify(Operation *op, |
128 | FlatSymbolRefAttr meshSymbol, |
129 | SymbolTableCollection &symbolTable) { |
130 | mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable); |
131 | if (!mesh) { |
132 | return op->emitError() << "Undefined required mesh symbol \"" |
133 | << meshSymbol.getValue() << "\"."; |
134 | } |
135 | |
136 | return mesh; |
137 | } |
138 | |
/// Returns true iff no two *adjacent* elements of `[begin, end)` compare equal.
///
/// For a sorted range this means that the range has no duplicates at all.
/// Empty and single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  // std::adjacent_find returns `end` when no equal neighbours exist.
  return std::adjacent_find(begin, end) == end;
}
155 | |
156 | static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, |
157 | MeshOp mesh) { |
158 | SmallVector<MeshAxis> sorted = llvm::to_vector(Range&: axes); |
159 | llvm::sort(C&: sorted); |
160 | if (!isUnique(begin: sorted.begin(), end: sorted.end())) { |
161 | return emitError(loc) << "Mesh axes contains duplicate elements."; |
162 | } |
163 | |
164 | MeshAxis rank = mesh.getRank(); |
165 | for (auto axis : axes) { |
166 | if (axis >= rank || axis < 0) { |
167 | return emitError(loc) |
168 | << "0-based mesh axis index "<< axis |
169 | << " is out of bounds. The referenced mesh \""<< mesh.getSymName() |
170 | << "\" is of rank "<< rank << "."; |
171 | } |
172 | } |
173 | |
174 | return success(); |
175 | } |
176 | |
177 | template <typename Op> |
178 | static FailureOr<MeshOp> |
179 | getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { |
180 | auto mesh = |
181 | ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable); |
182 | if (failed(mesh)) { |
183 | return failure(); |
184 | } |
185 | if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) { |
186 | return failure(); |
187 | } |
188 | return mesh; |
189 | } |
190 | |
191 | template <typename InShape, typename MeshShape, typename SplitAxes, |
192 | typename OutShape> |
193 | static void shardShape(const InShape &inShape, const MeshShape &meshShape, |
194 | const SplitAxes &splitAxes, OutShape &outShape, |
195 | ArrayRef<int64_t> shardedDimsOffsets = {}, |
196 | ArrayRef<int64_t> haloSizes = {}) { |
197 | // 0d tensors cannot be sharded and must get replicated |
198 | if (inShape.empty()) { |
199 | assert(outShape.empty()); |
200 | return; |
201 | } |
202 | |
203 | std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape), |
204 | llvm::adl_begin(outShape)); |
205 | |
206 | if (!shardedDimsOffsets.empty()) { |
207 | auto isDynShape = ShapedType::isDynamicShape(meshShape); |
208 | uint64_t pos = 1; |
209 | for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { |
210 | if (!innerSplitAxes.empty()) { |
211 | auto sz = shardedDimsOffsets[pos]; |
212 | bool same = !isDynShape; |
213 | if (same) { |
214 | // Find sharded dims in shardedDimsOffsets with same static size on |
215 | // all devices. Use kDynamic for dimensions with dynamic or |
216 | // non-uniform offs in shardedDimsOffsets. |
217 | uint64_t numShards = 0; |
218 | for (auto i : innerSplitAxes.asArrayRef()) { |
219 | numShards += meshShape[i]; |
220 | } |
221 | for (size_t i = 1; i < numShards; ++i) { |
222 | if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] != |
223 | sz) { |
224 | same = false; |
225 | break; |
226 | } |
227 | } |
228 | pos += numShards + 1; |
229 | } |
230 | outShape[tensorAxis] = same ? sz : ShapedType::kDynamic; |
231 | } |
232 | } |
233 | } else { |
234 | for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { |
235 | outShape[tensorAxis] = shardDimension( |
236 | inShape[tensorAxis], |
237 | collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape)); |
238 | } |
239 | |
240 | if (!haloSizes.empty()) { |
241 | // add halo sizes if requested |
242 | int haloAxis = 0; |
243 | for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { |
244 | if (!ShapedType::isDynamic(outShape[tensorAxis]) && |
245 | !innerSplitAxes.empty()) { |
246 | if (haloSizes[haloAxis * 2] >= 0 && |
247 | haloSizes[haloAxis * 2 + 1] >= 0) { |
248 | outShape[tensorAxis] += |
249 | haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1]; |
250 | ++haloAxis; |
251 | } else { |
252 | outShape[tensorAxis] = ShapedType::kDynamic; |
253 | } |
254 | } |
255 | } |
256 | } |
257 | } |
258 | } |
259 | |
260 | ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh, |
261 | MeshSharding sharding) { |
262 | using Dim = std::decay_t<decltype(shape.getDimSize(0))>; |
263 | SmallVector<Dim> resShapeArr(shape.getShape().size()); |
264 | shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(), |
265 | resShapeArr, sharding.getStaticShardedDimsOffsets(), |
266 | sharding.getStaticHaloSizes()); |
267 | return shape.clone(resShapeArr); |
268 | } |
269 | |
270 | Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { |
271 | RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type); |
272 | if (rankedTensorType && !rankedTensorType.getShape().empty()) { |
273 | return shardShapedType(rankedTensorType, mesh, sharding); |
274 | } |
275 | return type; |
276 | } |
277 | |
278 | void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, |
279 | OpOperand &operand, |
280 | OpBuilder &builder, |
281 | ShardOp &newShardOp) { |
282 | OpBuilder::InsertionGuard insertionGuard(builder); |
283 | Value operandValue = operand.get(); |
284 | Operation *operandOp = operand.getOwner(); |
285 | builder.setInsertionPointAfterValue(operandValue); |
286 | ShardOp shardOp = dyn_cast<ShardOp>(operandOp); |
287 | if (shardOp && sharding == shardOp.getSharding() && |
288 | !shardOp.getAnnotateForUsers()) { |
289 | // No need for anything if the correct sharding is already set. |
290 | if (!newShardOp) { |
291 | newShardOp = shardOp; |
292 | } |
293 | return; |
294 | } |
295 | |
296 | if (!newShardOp) { |
297 | auto shardingOp = |
298 | builder.create<ShardingOp>(operandValue.getLoc(), sharding); |
299 | newShardOp = |
300 | builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp, |
301 | /*annotate_for_users*/ false); |
302 | } |
303 | IRRewriter rewriter(builder); |
304 | rewriter.replaceUsesWithIf( |
305 | operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { |
306 | return use.getOwner() == operandOp && use.get() == operandValue; |
307 | }); |
308 | |
309 | if (!shardOp || shardOp.getAnnotateForUsers()) { |
310 | return; |
311 | } |
312 | |
313 | auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp, |
314 | newShardOp.getSharding(), |
315 | /*annotate_for_users*/ true); |
316 | rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2); |
317 | } |
318 | |
319 | void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, |
320 | OpResult result, |
321 | OpBuilder &builder) { |
322 | ShardOp newShardOp; |
323 | for (auto &use : llvm::make_early_inc_range(Range: result.getUses())) { |
324 | maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp); |
325 | } |
326 | } |
327 | |
328 | void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, |
329 | OpOperand &operand, |
330 | OpBuilder &builder) { |
331 | OpBuilder::InsertionGuard insertionGuard(builder); |
332 | Value operandValue = operand.get(); |
333 | Operation *operandSrcOp = operandValue.getDefiningOp(); |
334 | bool isBlockArg = !operandSrcOp; |
335 | { |
336 | [[maybe_unused]] auto opType = |
337 | dyn_cast<mlir::RankedTensorType>(operandValue.getType()); |
338 | assert(!opType || opType.getRank() > 0 || isFullReplication(sharding)); |
339 | } |
340 | if (!isa<RankedTensorType>(Val: operandValue.getType()) && operandSrcOp && |
341 | operandSrcOp->hasTrait<OpTrait::ConstantLike>()) { |
342 | return; |
343 | } |
344 | |
345 | Operation *operandOp = operand.getOwner(); |
346 | ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp); |
347 | |
348 | if (shardOp && sharding == shardOp.getSharding() && |
349 | shardOp.getAnnotateForUsers()) { |
350 | // No need for anything the correct sharding is already set. |
351 | return; |
352 | } |
353 | |
354 | builder.setInsertionPoint(operandOp); |
355 | auto shardingOp = |
356 | builder.create<ShardingOp>(operand.get().getLoc(), sharding); |
357 | auto newShardOp = |
358 | builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp, |
359 | /*annotate_for_users*/ true); |
360 | IRRewriter rewriter(builder); |
361 | rewriter.replaceUsesWithIf( |
362 | operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) { |
363 | return use.getOwner() == operandOp && use.get() == operandValue; |
364 | }); |
365 | |
366 | if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) { |
367 | // No need for resharding. |
368 | return; |
369 | } |
370 | |
371 | builder.setInsertionPoint(newShardOp); |
372 | auto newPreceedingShardOp = |
373 | builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp, |
374 | /*annotate_for_users*/ false); |
375 | rewriter.replaceUsesWithIf( |
376 | newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) { |
377 | return use.getOwner() == newShardOp.getOperation(); |
378 | }); |
379 | } |
380 | |
381 | //===----------------------------------------------------------------------===// |
382 | // mesh.mesh op |
383 | //===----------------------------------------------------------------------===// |
384 | |
385 | LogicalResult MeshOp::verify() { |
386 | int64_t rank = getRank(); |
387 | |
388 | if (rank <= 0) |
389 | return emitOpError("rank of mesh is expected to be a positive integer"); |
390 | |
391 | for (int64_t dimSize : getShape()) { |
392 | if (dimSize < 0 && !ShapedType::isDynamic(dimSize)) |
393 | return emitOpError("dimension size of a mesh is expected to be " |
394 | "non-negative or dynamic"); |
395 | } |
396 | |
397 | return success(); |
398 | } |
399 | |
400 | //===----------------------------------------------------------------------===// |
401 | // mesh.mesh_shape op |
402 | //===----------------------------------------------------------------------===// |
403 | |
404 | LogicalResult |
405 | MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
406 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
407 | if (failed(mesh)) { |
408 | return failure(); |
409 | } |
410 | if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { |
411 | return failure(); |
412 | } |
413 | |
414 | size_t expectedResultsCount = |
415 | getAxes().empty() ? mesh->getRank() : getAxes().size(); |
416 | if (getResult().size() != expectedResultsCount) { |
417 | return emitError() << "Unexpected number of results "<< getResult().size() |
418 | << ". Expected "<< expectedResultsCount << "."; |
419 | } |
420 | |
421 | return success(); |
422 | } |
423 | |
424 | void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
425 | MeshOp mesh) { |
426 | build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>()); |
427 | } |
428 | |
429 | void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
430 | MeshOp mesh, ArrayRef<MeshAxis> axes) { |
431 | build(odsBuilder, odsState, |
432 | SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(), |
433 | odsBuilder.getIndexType()), |
434 | mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes)); |
435 | } |
436 | |
437 | void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
438 | StringRef mesh, ArrayRef<MeshAxis> axes) { |
439 | assert(!axes.empty()); |
440 | build(odsBuilder, odsState, |
441 | SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, |
442 | MeshAxesAttr::get(odsBuilder.getContext(), axes)); |
443 | } |
444 | |
445 | void MeshShapeOp::getAsmResultNames( |
446 | function_ref<void(Value, StringRef)> setNameFn) { |
447 | setNameFn(getResults()[0], "mesh_shape"); |
448 | } |
449 | |
450 | //===----------------------------------------------------------------------===// |
451 | // mesh.sharding |
452 | //===----------------------------------------------------------------------===// |
453 | |
454 | void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, |
455 | FlatSymbolRefAttr mesh, |
456 | ArrayRef<MeshAxesAttr> split_axes, |
457 | ArrayRef<MeshAxis> partial_axes, |
458 | mesh::ReductionKind partial_type, |
459 | ArrayRef<int64_t> static_halos, |
460 | ArrayRef<int64_t> static_offsets) { |
461 | return build( |
462 | b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), |
463 | ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes), |
464 | ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type), |
465 | ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, |
466 | ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); |
467 | } |
468 | |
469 | void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, |
470 | FlatSymbolRefAttr mesh, |
471 | ArrayRef<MeshAxesAttr> split_axes) { |
472 | return build( |
473 | b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, |
474 | ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), |
475 | {}, {}, {}, {}); |
476 | } |
477 | |
478 | void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, |
479 | llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes, |
480 | ArrayRef<int64_t> static_halos, |
481 | ArrayRef<int64_t> static_offsets) { |
482 | return build( |
483 | b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh), |
484 | MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, |
485 | ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), |
486 | ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, |
487 | ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); |
488 | } |
489 | |
490 | void ShardingOp::build( |
491 | ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, |
492 | FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes, |
493 | ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, |
494 | ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) { |
495 | mlir::SmallVector<int64_t> staticHalos, staticDims; |
496 | mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims; |
497 | dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos); |
498 | dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims); |
499 | return build( |
500 | b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {}, |
501 | ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum), |
502 | ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, |
503 | ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims); |
504 | } |
505 | |
506 | void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, |
507 | mlir::mesh::MeshSharding from) { |
508 | |
509 | build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(), |
510 | MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()), |
511 | from.getPartialAxes().empty() |
512 | ? DenseI16ArrayAttr() |
513 | : b.getDenseI16ArrayAttr(from.getPartialAxes()), |
514 | ::mlir::mesh::ReductionKindAttr::get(b.getContext(), |
515 | from.getPartialType()), |
516 | from.getStaticShardedDimsOffsets().empty() |
517 | ? DenseI64ArrayAttr() |
518 | : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()), |
519 | from.getDynamicShardedDimsOffsets(), |
520 | from.getStaticHaloSizes().empty() |
521 | ? DenseI64ArrayAttr() |
522 | : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()), |
523 | from.getDynamicHaloSizes()); |
524 | } |
525 | |
526 | LogicalResult ShardingOp::verify() { |
527 | llvm::SmallSet<MeshAxis, 4> visitedAxes; |
528 | |
529 | auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult { |
530 | for (MeshAxis axis : axesArray) { |
531 | if (axis < 0) |
532 | return emitError() << "mesh axis is expected to be non-negative"; |
533 | if (!visitedAxes.insert(axis).second) |
534 | return emitError() << "mesh axis duplicated"; |
535 | } |
536 | return success(); |
537 | }; |
538 | |
539 | for (auto subAxes : getSplitAxes().getAxes()) { |
540 | ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef(); |
541 | if (failed(checkMeshAxis(subAxesArray))) |
542 | return failure(); |
543 | } |
544 | if (getPartialAxes().has_value() && |
545 | failed(checkMeshAxis(getPartialAxes().value()))) |
546 | return failure(); |
547 | |
548 | if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) { |
549 | return emitOpError("halo sizes and shard offsets are mutually exclusive"); |
550 | } |
551 | |
552 | if (!getStaticHaloSizes().empty()) { |
553 | auto numSplitAxes = getSplitAxes().getAxes().size(); |
554 | for (auto splitAxis : getSplitAxes().getAxes()) { |
555 | if (splitAxis.empty()) { |
556 | --numSplitAxes; |
557 | } |
558 | } |
559 | if (getStaticHaloSizes().size() != numSplitAxes * 2) { |
560 | return emitError() << "halo sizes must be specified for all split axes."; |
561 | } |
562 | } |
563 | |
564 | return success(); |
565 | } |
566 | |
567 | void ShardingOp::getAsmResultNames( |
568 | function_ref<void(Value, StringRef)> setNameFn) { |
569 | setNameFn(getResult(), "sharding"); |
570 | } |
571 | |
572 | LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
573 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
574 | if (failed(mesh)) { |
575 | return failure(); |
576 | } |
577 | if (mlir::ShapedType::isDynamicShape(mesh->getShape()) && |
578 | getStaticShardedDimsOffsets().size() > 0) { |
579 | return emitError() << "sharded dims offsets are not allowed for " |
580 | "devices meshes with dynamic shape."; |
581 | } |
582 | |
583 | auto shardedDimsOffsets = getStaticShardedDimsOffsets(); |
584 | if (!shardedDimsOffsets.empty()) { |
585 | auto meshShape = mesh.value().getShape(); |
586 | assert(!ShapedType::isDynamicShape(meshShape)); |
587 | uint64_t pos = 0; |
588 | for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) { |
589 | if (!innerSplitAxes.empty()) { |
590 | int64_t numShards = 0, off = 0; |
591 | for (auto i : innerSplitAxes.asArrayRef()) { |
592 | numShards += meshShape[i]; |
593 | } |
594 | for (int64_t i = 0; i <= numShards; ++i) { |
595 | if (shardedDimsOffsets.size() <= pos + i) { |
596 | return emitError() << "sharded dims offsets has wrong size."; |
597 | } |
598 | if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) { |
599 | if (shardedDimsOffsets[pos + i] < off) { |
600 | return emitError() |
601 | << "sharded dims offsets must be non-decreasing."; |
602 | } |
603 | off = shardedDimsOffsets[pos + i]; |
604 | } |
605 | } |
606 | pos += numShards + 1; |
607 | } |
608 | } |
609 | } |
610 | return success(); |
611 | } |
612 | |
613 | namespace { |
614 | // Sharding annotations "halo sizes" and "sharded dims offsets" |
615 | // are a mix of attributes and dynamic values. This canonicalization moves |
616 | // constant values to the respective attribute lists, minimizing the number |
617 | // of values. |
618 | // It also removes sharded_dims_sizes and halos if they are effectively "empty". |
619 | class NormalizeSharding final : public OpRewritePattern<ShardingOp> { |
620 | public: |
621 | using OpRewritePattern<ShardingOp>::OpRewritePattern; |
622 | |
623 | LogicalResult matchAndRewrite(ShardingOp op, |
624 | PatternRewriter &b) const override { |
625 | auto mixedHalos = |
626 | getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b); |
627 | auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(), |
628 | op.getDynamicShardedDimsOffsets(), b); |
629 | |
630 | // No constant operands were folded, just return; |
631 | bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) || |
632 | succeeded(foldDynamicIndexList(mixedOffs, true)); |
633 | |
634 | auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos); |
635 | auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs); |
636 | |
637 | if (dynamicHalos.empty() && !staticHalos.empty()) { |
638 | if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) { |
639 | staticHalos.clear(); |
640 | modified = true; |
641 | } |
642 | } |
643 | |
644 | // Remove sharded dims offsets if they are effectively the default values, |
645 | // e.g. if they define equi-distance between all neighboring shards. |
646 | // Requires static-only offsets. Compares the first distance as the |
647 | // difference between the first two offsets. Only if all consecutive |
648 | // distances are the same, the offsets are removed. |
649 | if (dynamicOffs.empty() && !staticOffs.empty()) { |
650 | assert(staticOffs.size() >= 2); |
651 | auto diff = staticOffs[1] - staticOffs[0]; |
652 | bool all_same = staticOffs.size() > 2; |
653 | for (auto i = 2u; i < staticOffs.size(); ++i) { |
654 | if (staticOffs[i] - staticOffs[i - 1] != diff) { |
655 | all_same = false; |
656 | break; |
657 | } |
658 | } |
659 | if (all_same) { |
660 | staticOffs.clear(); |
661 | modified = true; |
662 | } |
663 | } |
664 | |
665 | if (!modified) { |
666 | return failure(); |
667 | } |
668 | |
669 | op.setStaticHaloSizes(staticHalos); |
670 | op.getDynamicHaloSizesMutable().assign(dynamicHalos); |
671 | op.setStaticShardedDimsOffsets(staticOffs); |
672 | op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs); |
673 | |
674 | return success(); |
675 | } |
676 | }; |
677 | } // namespace |
678 | |
679 | void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, |
680 | mlir::MLIRContext *context) { |
681 | results.add<NormalizeSharding>(context); |
682 | } |
683 | |
684 | //===----------------------------------------------------------------------===// |
685 | // MeshSharding |
686 | //===----------------------------------------------------------------------===// |
687 | |
688 | bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const { |
689 | if (getMesh() != rhs.getMesh()) { |
690 | return false; |
691 | } |
692 | |
693 | if (getPartialAxes().size() != rhs.getPartialAxes().size() || |
694 | (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) || |
695 | !llvm::equal(getPartialAxes(), rhs.getPartialAxes())) { |
696 | return false; |
697 | } |
698 | |
699 | auto minSize = std::min(a: getSplitAxes().size(), b: rhs.getSplitAxes().size()); |
700 | if (!llvm::equal(llvm::make_range(getSplitAxes().begin(), |
701 | getSplitAxes().begin() + minSize), |
702 | llvm::make_range(rhs.getSplitAxes().begin(), |
703 | rhs.getSplitAxes().begin() + minSize))) { |
704 | return false; |
705 | } |
706 | |
707 | return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize), |
708 | std::mem_fn(&MeshAxesAttr::empty)) && |
709 | llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize), |
710 | std::mem_fn(&MeshAxesAttr::empty)); |
711 | } |
712 | |
713 | bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const { |
714 | return equalShardSizes(rhs) && equalHaloSizes(rhs); |
715 | } |
716 | |
717 | bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const { |
718 | if (rhs.getStaticShardedDimsOffsets().size() != |
719 | getStaticShardedDimsOffsets().size() || |
720 | !llvm::equal(LRange: getStaticShardedDimsOffsets(), |
721 | RRange: rhs.getStaticShardedDimsOffsets())) { |
722 | return false; |
723 | } |
724 | if (rhs.getDynamicShardedDimsOffsets().size() != |
725 | getDynamicShardedDimsOffsets().size() || |
726 | !llvm::equal(LRange: getDynamicShardedDimsOffsets(), |
727 | RRange: rhs.getDynamicShardedDimsOffsets())) { |
728 | return false; |
729 | } |
730 | return true; |
731 | } |
732 | |
733 | bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const { |
734 | if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() || |
735 | !llvm::equal(LRange: getStaticHaloSizes(), RRange: rhs.getStaticHaloSizes())) { |
736 | return false; |
737 | } |
738 | if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() || |
739 | !llvm::equal(LRange: getDynamicHaloSizes(), RRange: rhs.getDynamicHaloSizes())) { |
740 | return false; |
741 | } |
742 | return true; |
743 | } |
744 | |
745 | bool MeshSharding::operator==(Value rhs) const { |
746 | return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs); |
747 | } |
748 | |
749 | bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); } |
750 | |
751 | bool MeshSharding::operator==(const MeshSharding &rhs) const { |
752 | return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs); |
753 | } |
754 | |
755 | bool MeshSharding::operator!=(const MeshSharding &rhs) const { |
756 | return !(*this == rhs); |
757 | } |
758 | |
759 | MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {} |
760 | |
761 | MeshSharding::MeshSharding(Value rhs) { |
762 | auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp()); |
763 | assert(shardingOp && "expected sharding op"); |
764 | auto splitAxes = shardingOp.getSplitAxes().getAxes(); |
765 | auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()); |
766 | // If splitAxes and partialAxes are empty, use "empty" constructor. |
767 | if (splitAxes.empty() && partialAxes.empty()) { |
768 | *this = MeshSharding(shardingOp.getMeshAttr()); |
769 | return; |
770 | } |
771 | *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes, |
772 | shardingOp.getPartialType().value_or(ReductionKind::Sum), |
773 | shardingOp.getStaticHaloSizes(), |
774 | shardingOp.getStaticShardedDimsOffsets(), |
775 | SmallVector<Value>(shardingOp.getDynamicHaloSizes()), |
776 | SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets())); |
777 | } |
778 | |
779 | MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, |
780 | ArrayRef<MeshAxesAttr> split_axes_, |
781 | ArrayRef<MeshAxis> partial_axes_, |
782 | ReductionKind partial_type_, |
783 | ArrayRef<int64_t> static_halo_sizes_, |
784 | ArrayRef<int64_t> static_sharded_dims_offsets_, |
785 | ArrayRef<Value> dynamic_halo_sizes_, |
786 | ArrayRef<Value> dynamic_sharded_dims_offsets_) { |
787 | MeshSharding res(mesh_); |
788 | if (split_axes_.empty() && partial_axes_.empty()) { |
789 | return res; |
790 | } |
791 | |
792 | res.split_axes.resize(split_axes_.size()); |
793 | for (auto [i, axis] : llvm::enumerate(First&: split_axes_)) { |
794 | res.split_axes[i] = |
795 | MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef()); |
796 | } |
797 | |
798 | auto clone = [](const auto src, auto &dst) { |
799 | dst.resize(src.size()); |
800 | llvm::copy(src, dst.begin()); |
801 | }; |
802 | |
803 | clone(partial_axes_, res.partial_axes); |
804 | res.partial_type = partial_type_; |
805 | clone(static_halo_sizes_, res.static_halo_sizes); |
806 | clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets); |
807 | clone(dynamic_halo_sizes_, res.dynamic_halo_sizes); |
808 | clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets); |
809 | |
810 | return res; |
811 | } |
812 | |
813 | //===----------------------------------------------------------------------===// |
814 | // mesh.shard_shape |
815 | //===----------------------------------------------------------------------===// |
816 | |
817 | void ShardShapeOp::getAsmResultNames( |
818 | function_ref<void(Value, StringRef)> setNameFn) { |
819 | setNameFn(getResult()[0], "shard_shape"); |
820 | } |
821 | |
822 | void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, |
823 | ::mlir::OperationState &odsState, |
824 | ::llvm::ArrayRef<int64_t> dims, |
825 | ArrayRef<Value> dims_dyn, ::mlir::Value sharding, |
826 | ::mlir::ValueRange device) { |
827 | SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType()); |
828 | build(odsBuilder, odsState, resType, dims, dims_dyn, sharding, |
829 | SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device); |
830 | } |
831 | |
832 | //===----------------------------------------------------------------------===// |
833 | // mesh.shard op |
834 | //===----------------------------------------------------------------------===// |
835 | |
836 | void ShardOp::getAsmResultNames( |
837 | function_ref<void(Value, StringRef)> setNameFn) { |
838 | setNameFn(getResult(), "sharding_annotated"); |
839 | } |
840 | |
841 | namespace { |
842 | // Determine if the given ShardOp is a duplicate of another ShardOp |
843 | // on the same value. This can happen if constant values are sharded. |
844 | class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> { |
845 | public: |
846 | using OpRewritePattern<ShardOp>::OpRewritePattern; |
847 | |
848 | LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override { |
849 | // Get the use-list of the value being sharded and check if it has more than |
850 | // one use. |
851 | Value value = op.getSrc(); |
852 | if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) { |
853 | return failure(); |
854 | } |
855 | |
856 | // Iterate through the uses of the value to find a duplicate ShardOp. |
857 | for (auto &use : value.getUses()) { |
858 | if (use.getOwner() != op.getOperation()) { |
859 | auto otherOp = dyn_cast<ShardOp>(use.getOwner()); |
860 | if (!otherOp || !otherOp->isBeforeInBlock(op)) { |
861 | return failure(); |
862 | } |
863 | // Create a MeshSharding object for the current and the other ShardOp |
864 | // If the two are equal replace current op with the other op. |
865 | MeshSharding currentSharding(op.getSharding()); |
866 | MeshSharding otherSharding(otherOp.getSharding()); |
867 | if (currentSharding == otherSharding) { |
868 | b.replaceAllUsesWith(op.getResult(), otherOp.getResult()); |
869 | b.eraseOp(op.getOperation()); |
870 | } else { |
871 | // use the other sharding as input for op |
872 | op.getSrcMutable().assign(otherOp.getResult()); |
873 | } |
874 | return success(); |
875 | } |
876 | } |
877 | |
878 | return failure(); |
879 | } |
880 | }; |
881 | } // namespace |
882 | |
883 | void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, |
884 | mlir::MLIRContext *context) { |
885 | results.add<FoldDuplicateShardOp>(context); |
886 | } |
887 | |
888 | //===----------------------------------------------------------------------===// |
889 | // mesh.process_multi_index op |
890 | //===----------------------------------------------------------------------===// |
891 | |
892 | LogicalResult |
893 | ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
894 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
895 | if (failed(mesh)) { |
896 | return failure(); |
897 | } |
898 | if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { |
899 | return failure(); |
900 | } |
901 | |
902 | size_t expectedResultsCount = |
903 | getAxes().empty() ? mesh->getRank() : getAxes().size(); |
904 | if (getResult().size() != expectedResultsCount) { |
905 | return emitError() << "Unexpected number of results "<< getResult().size() |
906 | << ". Expected "<< expectedResultsCount << "."; |
907 | } |
908 | |
909 | return success(); |
910 | } |
911 | |
912 | void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
913 | MeshOp mesh) { |
914 | build(odsBuilder, odsState, |
915 | SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), |
916 | mesh.getSymName(), ArrayRef<MeshAxis>()); |
917 | } |
918 | |
919 | void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
920 | StringRef mesh, ArrayRef<MeshAxis> axes) { |
921 | build(odsBuilder, odsState, |
922 | SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, |
923 | MeshAxesAttr::get(odsBuilder.getContext(), axes)); |
924 | } |
925 | |
926 | void ProcessMultiIndexOp::getAsmResultNames( |
927 | function_ref<void(Value, StringRef)> setNameFn) { |
928 | setNameFn(getResults()[0], "proc_linear_idx"); |
929 | } |
930 | |
931 | //===----------------------------------------------------------------------===// |
932 | // mesh.process_linear_index op |
933 | //===----------------------------------------------------------------------===// |
934 | |
935 | LogicalResult |
936 | ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
937 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
938 | if (failed(mesh)) { |
939 | return failure(); |
940 | } |
941 | return success(); |
942 | } |
943 | |
944 | void ProcessLinearIndexOp::build(OpBuilder &odsBuilder, |
945 | OperationState &odsState, MeshOp mesh) { |
946 | build(odsBuilder, odsState, mesh.getSymName()); |
947 | } |
948 | |
949 | void ProcessLinearIndexOp::getAsmResultNames( |
950 | function_ref<void(Value, StringRef)> setNameFn) { |
951 | setNameFn(getResult(), "proc_linear_idx"); |
952 | } |
953 | |
954 | //===----------------------------------------------------------------------===// |
955 | // mesh.neighbors_linear_indices op |
956 | //===----------------------------------------------------------------------===// |
957 | |
958 | LogicalResult |
959 | NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
960 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
961 | if (failed(mesh)) { |
962 | return failure(); |
963 | } |
964 | return success(); |
965 | } |
966 | |
967 | void NeighborsLinearIndicesOp::getAsmResultNames( |
968 | function_ref<void(Value, StringRef)> setNameFn) { |
969 | setNameFn(getNeighborDown(), "down_linear_idx"); |
970 | setNameFn(getNeighborUp(), "up_linear_idx"); |
971 | } |
972 | |
973 | //===----------------------------------------------------------------------===// |
974 | // collective communication ops |
975 | //===----------------------------------------------------------------------===// |
976 | |
977 | namespace { |
978 | |
979 | template <typename Op> |
980 | struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> { |
981 | using OpRewritePattern<Op>::OpRewritePattern; |
982 | LogicalResult matchAndRewrite(Op op, |
983 | PatternRewriter &rewriter) const override { |
984 | auto meshAxes = op.getMeshAxes(); |
985 | if (!meshAxes.empty()) { |
986 | return failure(); |
987 | } |
988 | if (op.getInput().getType() != op.getResult().getType()) { |
989 | return failure(); |
990 | } |
991 | |
992 | rewriter.replaceAllUsesWith(op.getResult(), op.getInput()); |
993 | rewriter.eraseOp(op: op.getOperation()); |
994 | return success(); |
995 | } |
996 | }; |
997 | |
998 | } // namespace |
999 | |
1000 | static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, |
1001 | ArrayRef<int64_t> device, |
1002 | Operation::operand_range deviceDynamic, |
1003 | ArrayRef<MeshAxis> meshAxes, |
1004 | ArrayRef<int64_t> meshShape) { |
1005 | if (device.size() != meshAxes.size()) { |
1006 | return emitError(loc) << "In-group device \""<< deviceName |
1007 | << "\" has unexpected multi-index size " |
1008 | << device.size() << ". Expected "<< meshAxes.size() |
1009 | << "."; |
1010 | } |
1011 | |
1012 | for (size_t i = 0; i < device.size(); ++i) { |
1013 | if (!ShapedType::isDynamic(device[i]) && |
1014 | !ShapedType::isDynamic(meshShape[meshAxes[i]]) && |
1015 | meshShape[meshAxes[i]] <= device[i]) { |
1016 | return emitError(loc) |
1017 | << "Out of bounds coordinate "<< i << " for in-group device \"" |
1018 | << deviceName << "\"." |
1019 | << " Got "<< device[i] << ", but expected value in the range [0, " |
1020 | << (meshShape[meshAxes[i]] - 1) << "]."; |
1021 | } |
1022 | } |
1023 | return success(); |
1024 | } |
1025 | |
/// Returns the product of all elements in `[begin, end)`, computed in the
/// (decayed) element type. The product of an empty range is 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  const ElementType one = static_cast<ElementType>(1);
  return std::accumulate(
      begin, end, one,
      [](const ElementType &lhs, const ElementType &rhs) -> ElementType {
        return lhs * rhs;
      });
}
1032 | |
1033 | template <typename R> |
1034 | static auto product(R &&range) { |
1035 | return product(adl_begin(range), adl_end(range)); |
1036 | } |
1037 | |
1038 | static LogicalResult verifyDimensionCompatibility(Location loc, |
1039 | int64_t expectedDimSize, |
1040 | int64_t resultDimSize, |
1041 | int64_t resultAxis) { |
1042 | if (!ShapedType::isDynamic(resultDimSize) && |
1043 | expectedDimSize != resultDimSize) { |
1044 | return emitError(loc) << "Dimension size mismatch for result axis " |
1045 | << resultAxis << ". Expected " |
1046 | << (ShapedType::isDynamic(expectedDimSize) |
1047 | ? Twine("dynamic") |
1048 | : Twine(expectedDimSize)) |
1049 | << ", but got "<< resultDimSize << "."; |
1050 | } |
1051 | |
1052 | return success(); |
1053 | } |
1054 | |
1055 | static LogicalResult verifyGatherOperandAndResultShape( |
1056 | Value operand, Value result, int64_t gatherAxis, |
1057 | ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { |
1058 | auto resultRank = cast<ShapedType>(result.getType()).getRank(); |
1059 | if (gatherAxis < 0 || gatherAxis >= resultRank) { |
1060 | return emitError(loc: result.getLoc()) |
1061 | << "Gather axis "<< gatherAxis << " is out of bounds [0, " |
1062 | << resultRank << ")."; |
1063 | } |
1064 | |
1065 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
1066 | ShapedType resultType = cast<ShapedType>(result.getType()); |
1067 | auto deviceGroupSize = |
1068 | DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); |
1069 | for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { |
1070 | auto operandDimSize = DimensionSize(operandType.getDimSize(axis)); |
1071 | auto resultDimSize = DimensionSize(resultType.getDimSize(axis)); |
1072 | auto expectedResultDimSize = |
1073 | axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize; |
1074 | if (failed(verifyDimensionCompatibility( |
1075 | result.getLoc(), expectedResultDimSize, resultDimSize, axis))) { |
1076 | return failure(); |
1077 | } |
1078 | } |
1079 | return success(); |
1080 | } |
1081 | |
1082 | static LogicalResult verifyAllToAllOperandAndResultShape( |
1083 | Value operand, Value result, int64_t splitAxis, int64_t concatAxis, |
1084 | ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { |
1085 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
1086 | ShapedType resultType = cast<ShapedType>(result.getType()); |
1087 | for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { |
1088 | if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) { |
1089 | if (failed(verifyDimensionCompatibility( |
1090 | result.getLoc(), operandType.getDimSize(axis), |
1091 | resultType.getDimSize(axis), axis))) { |
1092 | return failure(); |
1093 | } |
1094 | } |
1095 | } |
1096 | |
1097 | if (splitAxis == concatAxis) { |
1098 | return success(); |
1099 | } |
1100 | |
1101 | auto deviceGroupSize = |
1102 | DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); |
1103 | auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis)); |
1104 | auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis)); |
1105 | DimensionSize expectedResultConcatDimSize = |
1106 | operandConcatDimSize * deviceGroupSize; |
1107 | DimensionSize expectedResultSplitDimSize = |
1108 | operandSplitDimSize / deviceGroupSize; |
1109 | if (!expectedResultSplitDimSize.isDynamic() && |
1110 | int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) { |
1111 | expectedResultSplitDimSize = DimensionSize::dynamic(); |
1112 | } |
1113 | if (failed(verifyDimensionCompatibility( |
1114 | result.getLoc(), expectedResultConcatDimSize.value(), |
1115 | resultType.getDimSize(concatAxis), concatAxis))) { |
1116 | return failure(); |
1117 | } |
1118 | if (failed(verifyDimensionCompatibility( |
1119 | result.getLoc(), expectedResultSplitDimSize.value(), |
1120 | resultType.getDimSize(splitAxis), splitAxis))) { |
1121 | return failure(); |
1122 | } |
1123 | |
1124 | return success(); |
1125 | } |
1126 | |
1127 | static LogicalResult verifyScatterOrSliceOperandAndResultShape( |
1128 | Value operand, Value result, int64_t tensorAxis, |
1129 | ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { |
1130 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
1131 | ShapedType resultType = cast<ShapedType>(result.getType()); |
1132 | for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { |
1133 | if (axis != tensorAxis) { |
1134 | if (failed(verifyDimensionCompatibility( |
1135 | result.getLoc(), operandType.getDimSize(axis), |
1136 | resultType.getDimSize(axis), axis))) { |
1137 | return failure(); |
1138 | } |
1139 | } |
1140 | } |
1141 | |
1142 | auto deviceGroupSize = |
1143 | DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); |
1144 | auto operandScatterDimSize = |
1145 | DimensionSize(operandType.getDimSize(tensorAxis)); |
1146 | if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() && |
1147 | int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) { |
1148 | return emitError(loc: result.getLoc()) |
1149 | << "Operand dimension size "<< int64_t(operandScatterDimSize) |
1150 | << " is not divisible by collective device group size " |
1151 | << int64_t(deviceGroupSize) << " for tensor axis "<< tensorAxis |
1152 | << "."; |
1153 | } |
1154 | DimensionSize expectedResultTensorDimSize = |
1155 | operandScatterDimSize / deviceGroupSize; |
1156 | if (failed(verifyDimensionCompatibility( |
1157 | result.getLoc(), expectedResultTensorDimSize.value(), |
1158 | resultType.getDimSize(tensorAxis), tensorAxis))) { |
1159 | return failure(); |
1160 | } |
1161 | |
1162 | return success(); |
1163 | } |
1164 | |
1165 | static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, |
1166 | ArrayRef<MeshAxis> meshAxes, |
1167 | int64_t sliceAxis) { |
1168 | RankedTensorType operandRankedTensorType = |
1169 | cast<RankedTensorType>(operandType); |
1170 | DimensionSize operandSliceAxisSize = |
1171 | operandRankedTensorType.getShape()[sliceAxis]; |
1172 | SmallVector<int64_t> resultShape = |
1173 | llvm::to_vector(operandRankedTensorType.getShape()); |
1174 | |
1175 | resultShape[sliceAxis] = |
1176 | operandSliceAxisSize / |
1177 | DimensionSize(collectiveProcessGroupSize(meshAxes, mesh)); |
1178 | return operandRankedTensorType.clone(resultShape); |
1179 | } |
1180 | |
1181 | //===----------------------------------------------------------------------===// |
1182 | // mesh.all_gather op |
1183 | //===----------------------------------------------------------------------===// |
1184 | |
1185 | LogicalResult |
1186 | AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1187 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1188 | if (failed(mesh)) { |
1189 | return failure(); |
1190 | } |
1191 | auto gatherAxis = getGatherAxis().getSExtValue(); |
1192 | return verifyGatherOperandAndResultShape(getOperand(), getResult(), |
1193 | gatherAxis, getMeshAxes(), |
1194 | mesh.value().getShape()); |
1195 | } |
1196 | |
1197 | void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1198 | MLIRContext *context) { |
1199 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context); |
1200 | } |
1201 | |
1202 | void AllGatherOp::getAsmResultNames( |
1203 | function_ref<void(Value, StringRef)> setNameFn) { |
1204 | setNameFn(getResult(), "all_gather"); |
1205 | } |
1206 | |
1207 | //===----------------------------------------------------------------------===// |
1208 | // mesh.all_reduce op |
1209 | //===----------------------------------------------------------------------===// |
1210 | |
1211 | LogicalResult |
1212 | AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1213 | return getMeshAndVerifyAxes(*this, symbolTable); |
1214 | } |
1215 | |
1216 | void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1217 | MLIRContext *context) { |
1218 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context); |
1219 | } |
1220 | |
1221 | void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
1222 | Value input, StringRef mesh, |
1223 | ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) { |
1224 | build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input, |
1225 | reduction); |
1226 | } |
1227 | |
1228 | void AllReduceOp::getAsmResultNames( |
1229 | function_ref<void(Value, StringRef)> setNameFn) { |
1230 | setNameFn(getResult(), "all_reduce"); |
1231 | } |
1232 | |
1233 | //===----------------------------------------------------------------------===// |
1234 | // mesh.all_slice op |
1235 | //===----------------------------------------------------------------------===// |
1236 | |
1237 | LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1238 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1239 | if (failed(mesh)) { |
1240 | return failure(); |
1241 | } |
1242 | return verifyScatterOrSliceOperandAndResultShape( |
1243 | getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(), |
1244 | mesh.value().getShape()); |
1245 | } |
1246 | |
1247 | void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1248 | MLIRContext *context) { |
1249 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context); |
1250 | } |
1251 | |
1252 | void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
1253 | Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, |
1254 | int64_t sliceAxis) { |
1255 | Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis); |
1256 | build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes, |
1257 | sliceAxis); |
1258 | } |
1259 | |
1260 | void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
1261 | Type resultType, Value input, StringRef mesh, |
1262 | ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) { |
1263 | build(odsBuilder, odsState, resultType, mesh, meshAxes, input, |
1264 | APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis)); |
1265 | } |
1266 | |
1267 | void AllSliceOp::getAsmResultNames( |
1268 | function_ref<void(Value, StringRef)> setNameFn) { |
1269 | setNameFn(getResult(), "all_slice"); |
1270 | } |
1271 | |
1272 | //===----------------------------------------------------------------------===// |
1273 | // mesh.all_to_all op |
1274 | //===----------------------------------------------------------------------===// |
1275 | |
1276 | LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1277 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1278 | if (failed(mesh)) { |
1279 | return failure(); |
1280 | } |
1281 | |
1282 | return verifyAllToAllOperandAndResultShape( |
1283 | getOperand(), getResult(), getSplitAxis().getSExtValue(), |
1284 | getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape()); |
1285 | } |
1286 | |
1287 | void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1288 | MLIRContext *context) { |
1289 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context); |
1290 | } |
1291 | |
1292 | void AllToAllOp::getAsmResultNames( |
1293 | function_ref<void(Value, StringRef)> setNameFn) { |
1294 | setNameFn(getResult(), "all_to_all"); |
1295 | } |
1296 | |
1297 | //===----------------------------------------------------------------------===// |
1298 | // mesh.broadcast op |
1299 | //===----------------------------------------------------------------------===// |
1300 | |
1301 | LogicalResult |
1302 | BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1303 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1304 | if (failed(mesh)) { |
1305 | return failure(); |
1306 | } |
1307 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
1308 | getRootDynamic(), getMeshAxes(), |
1309 | mesh.value().getShape()))) { |
1310 | return failure(); |
1311 | } |
1312 | |
1313 | return success(); |
1314 | } |
1315 | |
1316 | void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1317 | MLIRContext *context) { |
1318 | patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context); |
1319 | } |
1320 | |
1321 | void BroadcastOp::getAsmResultNames( |
1322 | function_ref<void(Value, StringRef)> setNameFn) { |
1323 | setNameFn(getResult(), "broadcast"); |
1324 | } |
1325 | |
1326 | //===----------------------------------------------------------------------===// |
1327 | // mesh.gather op |
1328 | //===----------------------------------------------------------------------===// |
1329 | |
1330 | LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1331 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1332 | if (failed(mesh)) { |
1333 | return failure(); |
1334 | } |
1335 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
1336 | getRootDynamic(), getMeshAxes(), |
1337 | mesh.value().getShape()))) { |
1338 | return failure(); |
1339 | } |
1340 | |
1341 | auto gatherAxis = getGatherAxis().getSExtValue(); |
1342 | return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis, |
1343 | getMeshAxes(), |
1344 | mesh.value().getShape()); |
1345 | } |
1346 | |
1347 | void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1348 | MLIRContext *context) { |
1349 | patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context); |
1350 | } |
1351 | |
1352 | void GatherOp::getAsmResultNames( |
1353 | function_ref<void(Value, StringRef)> setNameFn) { |
1354 | setNameFn(getResult(), "gather"); |
1355 | } |
1356 | |
1357 | //===----------------------------------------------------------------------===// |
1358 | // mesh.recv op |
1359 | //===----------------------------------------------------------------------===// |
1360 | |
1361 | LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1362 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1363 | if (failed(mesh)) { |
1364 | return failure(); |
1365 | } |
1366 | if (getSource() && |
1367 | failed(verifyInGroupDevice(getLoc(), getSourceAttrName(), |
1368 | getSource().value(), getSourceDynamic(), |
1369 | getMeshAxes(), mesh.value().getShape()))) { |
1370 | return failure(); |
1371 | } |
1372 | return success(); |
1373 | } |
1374 | |
1375 | void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1376 | MLIRContext *context) { |
1377 | patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context); |
1378 | } |
1379 | |
1380 | void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { |
1381 | setNameFn(getResult(), "recv"); |
1382 | } |
1383 | |
1384 | //===----------------------------------------------------------------------===// |
1385 | // mesh.reduce op |
1386 | //===----------------------------------------------------------------------===// |
1387 | |
1388 | LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1389 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1390 | if (failed(mesh)) { |
1391 | return failure(); |
1392 | } |
1393 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
1394 | getRootDynamic(), getMeshAxes(), |
1395 | mesh.value().getShape()))) { |
1396 | return failure(); |
1397 | } |
1398 | |
1399 | return success(); |
1400 | } |
1401 | |
1402 | void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1403 | MLIRContext *context) { |
1404 | patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context); |
1405 | } |
1406 | |
1407 | void ReduceOp::getAsmResultNames( |
1408 | function_ref<void(Value, StringRef)> setNameFn) { |
1409 | setNameFn(getResult(), "reduce"); |
1410 | } |
1411 | |
1412 | //===----------------------------------------------------------------------===// |
1413 | // mesh.reduce_scatter op |
1414 | //===----------------------------------------------------------------------===// |
1415 | |
1416 | LogicalResult |
1417 | ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1418 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1419 | if (failed(mesh)) { |
1420 | return failure(); |
1421 | } |
1422 | |
1423 | return verifyScatterOrSliceOperandAndResultShape( |
1424 | getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(), |
1425 | mesh.value().getShape()); |
1426 | } |
1427 | |
1428 | void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1429 | MLIRContext *context) { |
1430 | patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context); |
1431 | } |
1432 | |
1433 | void ReduceScatterOp::getAsmResultNames( |
1434 | function_ref<void(Value, StringRef)> setNameFn) { |
1435 | setNameFn(getResult(), "reduce_scatter"); |
1436 | } |
1437 | |
1438 | //===----------------------------------------------------------------------===// |
1439 | // mesh.scatter op |
1440 | //===----------------------------------------------------------------------===// |
1441 | |
1442 | LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1443 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1444 | if (failed(mesh)) { |
1445 | return failure(); |
1446 | } |
1447 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
1448 | getRootDynamic(), getMeshAxes(), |
1449 | mesh.value().getShape()))) { |
1450 | return failure(); |
1451 | } |
1452 | |
1453 | auto scatterAxis = getScatterAxis().getSExtValue(); |
1454 | return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(), |
1455 | scatterAxis, getMeshAxes(), |
1456 | mesh.value().getShape()); |
1457 | } |
1458 | |
1459 | void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1460 | MLIRContext *context) { |
1461 | patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context); |
1462 | } |
1463 | |
1464 | void ScatterOp::getAsmResultNames( |
1465 | function_ref<void(Value, StringRef)> setNameFn) { |
1466 | setNameFn(getResult(), "scatter"); |
1467 | } |
1468 | |
1469 | //===----------------------------------------------------------------------===// |
1470 | // mesh.send op |
1471 | //===----------------------------------------------------------------------===// |
1472 | |
1473 | LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1474 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1475 | if (failed(mesh)) { |
1476 | return failure(); |
1477 | } |
1478 | if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(), |
1479 | getDestination(), getDestinationDynamic(), |
1480 | getMeshAxes(), mesh.value().getShape()))) { |
1481 | return failure(); |
1482 | } |
1483 | return success(); |
1484 | } |
1485 | |
1486 | void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1487 | MLIRContext *context) { |
1488 | patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context); |
1489 | } |
1490 | |
1491 | void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { |
1492 | setNameFn(getResult(), "send"); |
1493 | } |
1494 | |
1495 | //===----------------------------------------------------------------------===// |
1496 | // mesh.shift op |
1497 | //===----------------------------------------------------------------------===// |
1498 | |
1499 | LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1500 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
1501 | if (failed(mesh)) { |
1502 | return failure(); |
1503 | } |
1504 | |
1505 | auto meshAxes = getMeshAxes(); |
1506 | auto shiftAxis = getShiftAxis().getZExtValue(); |
1507 | if (!llvm::is_contained(meshAxes, shiftAxis)) { |
1508 | return emitError() << "Invalid shift axis "<< shiftAxis |
1509 | << ". It must be one of the grouping mesh axes."; |
1510 | } |
1511 | |
1512 | return success(); |
1513 | } |
1514 | |
1515 | void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
1516 | MLIRContext *context) { |
1517 | // TODO: remove op when offset is 0 or if it is a rotate with and |
1518 | // offset % shift_axis_mesh_dim_size == 0. |
1519 | } |
1520 | |
1521 | void ShiftOp::getAsmResultNames( |
1522 | function_ref<void(Value, StringRef)> setNameFn) { |
1523 | setNameFn(getResult(), "shift"); |
1524 | } |
1525 | |
1526 | //===----------------------------------------------------------------------===// |
1527 | // mesh.update_halo op |
1528 | //===----------------------------------------------------------------------===// |
1529 | |
1530 | LogicalResult |
1531 | UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
1532 | auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
1533 | if (failed(mesh)) { |
1534 | return failure(); |
1535 | } |
1536 | |
1537 | return success(); |
1538 | } |
1539 | |
1540 | //===----------------------------------------------------------------------===// |
1541 | // TableGen'd op method definitions |
1542 | //===----------------------------------------------------------------------===// |
1543 | |
1544 | #define GET_OP_CLASSES |
1545 | #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" |
1546 | |
1547 | #define GET_ATTRDEF_CLASSES |
1548 | #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" |
1549 | |
1550 | #define GET_TYPEDEF_CLASSES |
1551 | #include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc" |
1552 | |
1553 | #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc" |
1554 |
Definitions
- DimensionSize
- dynamic
- DimensionSize
- value
- operator int64_t
- isDynamic
- operator/
- operator*
- MeshInlinerInterface
- isLegalToInline
- isLegalToInline
- isLegalToInline
- getMeshAndVerify
- isUnique
- verifyMeshAxes
- getMeshAndVerifyAxes
- shardShape
- shardShapedType
- shardType
- maybeInsertTargetShardingAnnotation
- maybeInsertTargetShardingAnnotation
- maybeInsertSourceShardingAnnotation
- NormalizeSharding
- matchAndRewrite
- equalSplitAndPartialAxes
- equalHaloAndShardSizes
- equalShardSizes
- equalHaloSizes
- operator==
- operator!=
- operator==
- operator!=
- MeshSharding
- MeshSharding
- get
- FoldDuplicateShardOp
- matchAndRewrite
- EmptyMeshAxesCanonicalizationPattern
- matchAndRewrite
- verifyInGroupDevice
- product
- product
- verifyDimensionCompatibility
- verifyGatherOperandAndResultShape
- verifyAllToAllOperandAndResultShape
- verifyScatterOrSliceOperandAndResultShape
Improve your Profiling and Debugging skills
Find out more