1//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/IR/MeshOps.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13#include "mlir/Dialect/Utils/StaticValueUtils.h"
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/BuiltinAttributes.h"
16#include "mlir/IR/BuiltinTypeInterfaces.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/IR/Value.h"
25#include "mlir/Interfaces/ViewLikeInterface.h"
26#include "mlir/Support/LLVM.h"
27#include "mlir/Transforms/InliningUtils.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include <algorithm>
34#include <functional>
35#include <iterator>
36#include <numeric>
37#include <optional>
38#include <utility>
39
40#define DEBUG_TYPE "mesh-ops"
41#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
42
43using namespace mlir;
44using namespace mlir::mesh;
45
46#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
47
48namespace {
49
50struct DimensionSize {
51 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
52 DimensionSize(int64_t val) : val(val) {}
53 int64_t value() const { return val; }
54 operator int64_t() const { return val; }
55 bool isDynamic() const { return ShapedType::isDynamic(dValue: val); }
56
57private:
58 int64_t val;
59};
60
61} // namespace
62
63static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
64 if (lhs.isDynamic() || rhs.isDynamic()) {
65 return DimensionSize::dynamic();
66 }
67 return lhs.value() / rhs.value();
68}
69
70static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
71 if (lhs.isDynamic() || rhs.isDynamic()) {
72 return DimensionSize::dynamic();
73 }
74 return lhs.value() * rhs.value();
75}
76
77SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
78 const Location &loc,
79 llvm::ArrayRef<int64_t> statics,
80 ValueRange dynamics,
81 Type type) {
82 SmallVector<Value> values;
83 auto dyn = dynamics.begin();
84 Type i64 = b.getI64Type();
85 if (!type)
86 type = i64;
87 assert((i64 == type || b.getIndexType() == type) &&
88 "expected an i64 or an intex type");
89 for (auto s : statics) {
90 if (s == ShapedType::kDynamic) {
91 values.emplace_back(Args: *(dyn++));
92 } else {
93 TypedAttr val = type == i64 ? b.getI64IntegerAttr(value: s) : b.getIndexAttr(value: s);
94 values.emplace_back(Args: b.create<arith::ConstantOp>(location: loc, args&: type, args&: val));
95 }
96 }
97 return values;
98}
99
100//===----------------------------------------------------------------------===//
101// Inliner
102//===----------------------------------------------------------------------===//
103
104namespace {
105struct MeshInlinerInterface : public DialectInlinerInterface {
106 using DialectInlinerInterface::DialectInlinerInterface;
107 // Currently no restrictions are encoded for inlining.
108 bool isLegalToInline(Operation *, Operation *, bool) const final {
109 return true;
110 }
111 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
112 return true;
113 }
114 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
115 return true;
116 }
117};
118} // namespace
119
120//===----------------------------------------------------------------------===//
121// Mesh dialect
122//===----------------------------------------------------------------------===//
123
/// Registers the dialect's operations, attributes and types and attaches the
/// inliner interface. The op/attribute/type lists are expanded from the
/// included tablegen .inc files, selected by the GET_*_LIST macros.
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
      >();
  // Mesh IR may be inlined without restriction (MeshInlinerInterface always
  // answers "legal").
  addInterface<MeshInlinerInterface>();
}
139
140Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
141 Type type, Location loc) {
142 return arith::ConstantOp::materialize(builder, value, type, loc);
143}
144
145//===----------------------------------------------------------------------===//
146// Mesh utilities
147//===----------------------------------------------------------------------===//
148
149static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
150 FlatSymbolRefAttr meshSymbol,
151 SymbolTableCollection &symbolTable) {
152 mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTableCollection&: symbolTable);
153 if (!mesh) {
154 return op->emitError() << "Undefined required mesh symbol \""
155 << meshSymbol.getValue() << "\".";
156 }
157
158 return mesh;
159}
160
/// Returns true if no two *adjacent* elements of [begin, end) compare equal.
/// On a sorted range this is equivalent to "all elements are distinct".
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
177
178static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
179 MeshOp mesh) {
180 SmallVector<MeshAxis> sorted = llvm::to_vector(Range&: axes);
181 llvm::sort(C&: sorted);
182 if (!isUnique(begin: sorted.begin(), end: sorted.end())) {
183 return emitError(loc) << "Mesh axes contains duplicate elements.";
184 }
185
186 MeshAxis rank = mesh.getRank();
187 for (auto axis : axes) {
188 if (axis >= rank || axis < 0) {
189 return emitError(loc)
190 << "0-based mesh axis index " << axis
191 << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
192 << "\" is of rank " << rank << ".";
193 }
194 }
195
196 return success();
197}
198
199template <typename Op>
200static FailureOr<MeshOp>
201getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
202 auto mesh =
203 ::getMeshAndVerify(op: op.getOperation(), meshSymbol: op.getMeshAttr(), symbolTable);
204 if (failed(mesh)) {
205 return failure();
206 }
207 if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
208 return failure();
209 }
210 return mesh;
211}
212
213template <typename InShape, typename MeshShape, typename SplitAxes,
214 typename OutShape>
215static void shardShape(const InShape &inShape, const MeshShape &meshShape,
216 const SplitAxes &splitAxes, OutShape &outShape,
217 ArrayRef<int64_t> shardedDimsOffsets = {},
218 ArrayRef<int64_t> haloSizes = {}) {
219 // 0d tensors cannot be sharded and must get replicated
220 if (inShape.empty()) {
221 assert(outShape.empty());
222 return;
223 }
224
225 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
226 llvm::adl_begin(outShape));
227
228 if (!shardedDimsOffsets.empty()) {
229 auto isDynShape = ShapedType::isDynamicShape(dSizes: meshShape);
230 uint64_t pos = 1;
231 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
232 if (!innerSplitAxes.empty()) {
233 auto sz = shardedDimsOffsets[pos];
234 bool same = !isDynShape;
235 if (same) {
236 // Find sharded dims in shardedDimsOffsets with same static size on
237 // all devices. Use kDynamic for dimensions with dynamic or
238 // non-uniform offs in shardedDimsOffsets.
239 uint64_t numShards = 0;
240 for (auto i : innerSplitAxes.asArrayRef()) {
241 numShards += meshShape[i];
242 }
243 for (size_t i = 1; i < numShards; ++i) {
244 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
245 sz) {
246 same = false;
247 break;
248 }
249 }
250 pos += numShards + 1;
251 }
252 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
253 }
254 }
255 } else {
256 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
257 outShape[tensorAxis] = shardDimension(
258 inShape[tensorAxis],
259 collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
260 }
261
262 if (!haloSizes.empty()) {
263 // add halo sizes if requested
264 int haloAxis = 0;
265 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
266 if (ShapedType::isStatic(dValue: outShape[tensorAxis]) &&
267 !innerSplitAxes.empty()) {
268 if (haloSizes[haloAxis * 2] >= 0 &&
269 haloSizes[haloAxis * 2 + 1] >= 0) {
270 outShape[tensorAxis] +=
271 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
272 ++haloAxis;
273 } else {
274 outShape[tensorAxis] = ShapedType::kDynamic;
275 }
276 }
277 }
278 }
279 }
280}
281
282ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
283 MeshSharding sharding) {
284 using Dim = std::decay_t<decltype(shape.getDimSize(idx: 0))>;
285 SmallVector<Dim> resShapeArr(shape.getShape().size());
286 shardShape(inShape: shape.getShape(), meshShape: mesh.getShape(), splitAxes: sharding.getSplitAxes(),
287 outShape&: resShapeArr, shardedDimsOffsets: sharding.getStaticShardedDimsOffsets(),
288 haloSizes: sharding.getStaticHaloSizes());
289 return shape.clone(shape: resShapeArr);
290}
291
292Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
293 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(Val&: type);
294 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
295 return shardShapedType(shape: rankedTensorType, mesh, sharding);
296 }
297 return type;
298}
299
/// Routes the use of `operandValue` by `operandOp` through a mesh.shard op
/// that carries `sharding` with annotate_for_users = false.
///
/// `newShardOp` is an in/out cache shared by all uses of one result (see
/// maybeInsertTargetShardingAnnotation):
/// - If `operandOp` already is such a shard op with the same sharding and no
///   cached op exists yet, that op is cached and nothing else happens.
/// - Otherwise, a new mesh.sharding + mesh.shard pair is created right after
///   `operandValue` when nothing is cached yet.
/// - Only the operands of `operandOp` that equal `operandValue` are
///   redirected to the cached op.
///
/// If `operandOp` is a shard op with annotate_for_users = false (but a
/// different sharding), an extra shard op with annotate_for_users = true is
/// chained after `newShardOp`, and every other use of `newShardOp` (including
/// `operandOp`'s) is moved to it.
static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
                                                    Value &operandValue,
                                                    Operation *operandOp,
                                                    OpBuilder &builder,
                                                    ShardOp &newShardOp) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointAfterValue(operandValue);
  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
  // Sharding is compared through MeshSharding::operator==(Value).
  if (shardOp && sharding == shardOp.getSharding() &&
      !shardOp.getAnnotateForUsers()) {
    // No need for anything if the correct sharding is already set.
    if (!newShardOp) {
      newShardOp = shardOp;
    }
    return;
  }

  // NOTE(review): when `newShardOp` comes from an earlier call (possibly an
  // existing shard op located further down the block), uses are redirected to
  // it without checking that it dominates `operandOp`. Confirm that callers
  // only hit this with uses that are dominated.
  if (!newShardOp) {
    auto shardingOp =
        builder.create<ShardingOp>(operandValue.getLoc(), sharding);
    newShardOp =
        builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                /*annotate_for_users*/ false);
  }
  operandValue.replaceUsesWithIf(
      newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (!shardOp || shardOp.getAnnotateForUsers()) {
    return;
  }

  // NOTE(review): the insertion point is still right after `operandValue`'s
  // definition. If `newShardOp` was created by an earlier call, this new op
  // may end up before its own operand; verify.
  auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                             newShardOp.getSharding(),
                                             /*annotate_for_users*/ true);
  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
338
339void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
340 OpResult result,
341 OpBuilder &builder) {
342 ShardOp newShardOp;
343 SmallVector<std::pair<Value, Operation *>> uses;
344 for (auto &use : result.getUses()) {
345 uses.emplace_back(Args: use.get(), Args: use.getOwner());
346 }
347 for (auto &[operandValue, operandOp] : uses) {
348 maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
349 builder, newShardOp);
350 }
351}
352
/// Makes `operand` be consumed through a mesh.shard op that carries
/// `sharding` with annotate_for_users = true. The op is inserted right before
/// the operand's owner.
///
/// Nothing is done when:
/// - the operand is not a ranked tensor and is produced by a ConstantLike op;
/// - the operand already comes from a shard op with the same sharding and
///   annotate_for_users = true.
///
/// If the operand comes from a shard op with annotate_for_users = true but a
/// different sharding, an extra shard op (annotate_for_users = false, same
/// `sharding`) is placed in front, giving the chain
///   old shard -> shard(false, sharding) -> shard(true, sharding) -> owner.
void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                                     OpOperand &operand,
                                                     OpBuilder &builder) {
  OpBuilder::InsertionGuard insertionGuard(builder);
  Value operandValue = operand.get();
  Operation *operandSrcOp = operandValue.getDefiningOp();
  bool isBlockArg = !operandSrcOp;
  {
    // 0-d ranked tensors may only carry a fully replicated sharding.
    [[maybe_unused]] auto opType =
        dyn_cast<mlir::RankedTensorType>(operandValue.getType());
    assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
  }
  if (!isa<RankedTensorType>(operandValue.getType()) && operandSrcOp &&
      operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
    return;
  }

  Operation *operandOp = operand.getOwner();
  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);

  if (shardOp && sharding == shardOp.getSharding() &&
      shardOp.getAnnotateForUsers()) {
    // No need for anything the correct sharding is already set.
    return;
  }

  builder.setInsertionPoint(operandOp);
  auto shardingOp =
      builder.create<ShardingOp>(operand.get().getLoc(), sharding);
  auto newShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ true);
  IRRewriter rewriter(builder);
  // Only the owner's operands that equal `operandValue` are redirected; other
  // users keep the original value.
  rewriter.replaceUsesWithIf(
      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
        return use.getOwner() == operandOp && use.get() == operandValue;
      });

  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
    // No need for resharding.
    return;
  }

  // The source was annotated for users with a different sharding: insert a
  // target-side shard op (annotate_for_users = false) in between. It reuses
  // `shardingOp`, which was created before `newShardOp` and thus dominates.
  builder.setInsertionPoint(newShardOp);
  auto newPreceedingShardOp =
      builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                              /*annotate_for_users*/ false);
  rewriter.replaceUsesWithIf(
      newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
        return use.getOwner() == newShardOp.getOperation();
      });
}
405
406//===----------------------------------------------------------------------===//
407// mesh.mesh op
408//===----------------------------------------------------------------------===//
409
410LogicalResult MeshOp::verify() {
411 int64_t rank = getRank();
412
413 if (rank <= 0)
414 return emitOpError(message: "rank of mesh is expected to be a positive integer");
415
416 for (int64_t dimSize : getShape()) {
417 if (dimSize < 0 && ShapedType::isStatic(dValue: dimSize))
418 return emitOpError(message: "dimension size of a mesh is expected to be "
419 "non-negative or dynamic");
420 }
421
422 return success();
423}
424
425//===----------------------------------------------------------------------===//
426// mesh.mesh_shape op
427//===----------------------------------------------------------------------===//
428
429LogicalResult
430MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
431 auto mesh = ::getMeshAndVerify(op: getOperation(), meshSymbol: getMeshAttr(), symbolTable);
432 if (failed(Result: mesh)) {
433 return failure();
434 }
435 if (failed(Result: verifyMeshAxes(loc: getLoc(), axes: getAxes(), mesh: mesh.value()))) {
436 return failure();
437 }
438
439 size_t expectedResultsCount =
440 getAxes().empty() ? mesh->getRank() : getAxes().size();
441 if (getResult().size() != expectedResultsCount) {
442 return emitError() << "Unexpected number of results " << getResult().size()
443 << ". Expected " << expectedResultsCount << ".";
444 }
445
446 return success();
447}
448
449void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
450 MeshOp mesh) {
451 build(odsBuilder, odsState, mesh, axes: SmallVector<MeshAxis>());
452}
453
454void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
455 MeshOp mesh, ArrayRef<MeshAxis> axes) {
456 build(odsBuilder, odsState,
457 result: SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
458 odsBuilder.getIndexType()),
459 mesh: mesh.getSymName(), axes: MeshAxesAttr::get(context: odsBuilder.getContext(), content: axes));
460}
461
462void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
463 StringRef mesh, ArrayRef<MeshAxis> axes) {
464 assert(!axes.empty());
465 build(odsBuilder, odsState,
466 result: SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
467 axes: MeshAxesAttr::get(context: odsBuilder.getContext(), content: axes));
468}
469
470void MeshShapeOp::getAsmResultNames(
471 function_ref<void(Value, StringRef)> setNameFn) {
472 setNameFn(getResults()[0], "mesh_shape");
473}
474
475//===----------------------------------------------------------------------===//
476// mesh.sharding
477//===----------------------------------------------------------------------===//
478
479void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
480 FlatSymbolRefAttr mesh,
481 ArrayRef<MeshAxesAttr> split_axes,
482 ArrayRef<MeshAxis> partial_axes,
483 mesh::ReductionKind partial_type,
484 ArrayRef<int64_t> static_halos,
485 ArrayRef<int64_t> static_offsets) {
486 return build(
487 odsBuilder&: b, odsState, mesh, split_axes: MeshAxesArrayAttr::get(context: b.getContext(), axes: split_axes),
488 partial_axes: ::mlir::DenseI16ArrayAttr::get(context: b.getContext(), content: partial_axes),
489 partial_type: ::mlir::mesh::ReductionKindAttr::get(context: b.getContext(), value: partial_type),
490 static_sharded_dims_offsets: ::mlir::DenseI64ArrayAttr::get(context: b.getContext(), content: static_halos), dynamic_sharded_dims_offsets: {},
491 static_halo_sizes: ::mlir::DenseI64ArrayAttr::get(context: b.getContext(), content: static_offsets), dynamic_halo_sizes: {});
492}
493
494void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
495 FlatSymbolRefAttr mesh,
496 ArrayRef<MeshAxesAttr> split_axes) {
497 return build(
498 odsBuilder&: b, odsState, mesh, split_axes: MeshAxesArrayAttr::get(context: b.getContext(), axes: split_axes), partial_axes: {},
499 partial_type: ::mlir::mesh::ReductionKindAttr::get(context: b.getContext(), value: ReductionKind::Sum),
500 static_sharded_dims_offsets: {}, dynamic_sharded_dims_offsets: {}, static_halo_sizes: {}, dynamic_halo_sizes: {});
501}
502
503void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
504 llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
505 ArrayRef<int64_t> static_halos,
506 ArrayRef<int64_t> static_offsets) {
507 return build(
508 odsBuilder&: b, odsState, mesh: FlatSymbolRefAttr::get(ctx: b.getContext(), value: mesh),
509 split_axes: MeshAxesArrayAttr::get(context: b.getContext(), axes: split_axes), partial_axes: {},
510 partial_type: ::mlir::mesh::ReductionKindAttr::get(context: b.getContext(), value: ReductionKind::Sum),
511 static_sharded_dims_offsets: ::mlir::DenseI64ArrayAttr::get(context: b.getContext(), content: static_halos), dynamic_sharded_dims_offsets: {},
512 static_halo_sizes: ::mlir::DenseI64ArrayAttr::get(context: b.getContext(), content: static_offsets), dynamic_halo_sizes: {});
513}
514
515void ShardingOp::build(
516 ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
517 FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
518 ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
519 ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
520 mlir::SmallVector<int64_t> staticHalos, staticDims;
521 mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
522 dispatchIndexOpFoldResults(ofrs: halo_sizes, dynamicVec&: dynamicHalos, staticVec&: staticHalos);
523 dispatchIndexOpFoldResults(ofrs: sharded_dims_offsets, dynamicVec&: dynamicDims, staticVec&: staticDims);
524 return build(
525 odsBuilder&: b, odsState, mesh, split_axes: MeshAxesArrayAttr::get(context: b.getContext(), axes: split_axes), partial_axes: {},
526 partial_type: ::mlir::mesh::ReductionKindAttr::get(context: b.getContext(), value: ReductionKind::Sum),
527 static_sharded_dims_offsets: ::mlir::DenseI64ArrayAttr::get(context: b.getContext(), content: staticHalos), dynamic_sharded_dims_offsets: dynamicHalos,
528 static_halo_sizes: ::mlir::DenseI64ArrayAttr::get(context: b.getContext(), content: staticDims), dynamic_halo_sizes: dynamicDims);
529}
530
531void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
532 mlir::mesh::MeshSharding from) {
533
534 build(odsBuilder&: b, odsState, result: ShardingType::get(ctx: b.getContext()), mesh: from.getMeshAttr(),
535 split_axes: MeshAxesArrayAttr::get(context: b.getContext(), axes: from.getSplitAxes()),
536 partial_axes: from.getPartialAxes().empty()
537 ? DenseI16ArrayAttr()
538 : b.getDenseI16ArrayAttr(values: from.getPartialAxes()),
539 partial_type: ::mlir::mesh::ReductionKindAttr::get(context: b.getContext(),
540 value: from.getPartialType()),
541 static_sharded_dims_offsets: from.getStaticShardedDimsOffsets().empty()
542 ? DenseI64ArrayAttr()
543 : b.getDenseI64ArrayAttr(values: from.getStaticShardedDimsOffsets()),
544 dynamic_sharded_dims_offsets: from.getDynamicShardedDimsOffsets(),
545 static_halo_sizes: from.getStaticHaloSizes().empty()
546 ? DenseI64ArrayAttr()
547 : b.getDenseI64ArrayAttr(values: from.getStaticHaloSizes()),
548 dynamic_halo_sizes: from.getDynamicHaloSizes());
549}
550
551LogicalResult ShardingOp::verify() {
552 llvm::SmallSet<MeshAxis, 4> visitedAxes;
553
554 auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
555 for (MeshAxis axis : axesArray) {
556 if (axis < 0)
557 return emitError() << "mesh axis is expected to be non-negative";
558 if (!visitedAxes.insert(V: axis).second)
559 return emitError() << "mesh axis duplicated";
560 }
561 return success();
562 };
563
564 for (auto subAxes : getSplitAxes().getAxes()) {
565 ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
566 if (failed(Result: checkMeshAxis(subAxesArray)))
567 return failure();
568 }
569 if (getPartialAxes().has_value() &&
570 failed(Result: checkMeshAxis(getPartialAxes().value())))
571 return failure();
572
573 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
574 return emitOpError(message: "halo sizes and shard offsets are mutually exclusive");
575 }
576
577 if (!getStaticHaloSizes().empty()) {
578 auto numSplitAxes = getSplitAxes().getAxes().size();
579 for (auto splitAxis : getSplitAxes().getAxes()) {
580 if (splitAxis.empty()) {
581 --numSplitAxes;
582 }
583 }
584 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
585 return emitError() << "halo sizes must be specified for all split axes.";
586 }
587 }
588
589 return success();
590}
591
592void ShardingOp::getAsmResultNames(
593 function_ref<void(Value, StringRef)> setNameFn) {
594 setNameFn(getResult(), "sharding");
595}
596
597LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
598 auto mesh = ::getMeshAndVerify(op: getOperation(), meshSymbol: getMeshAttr(), symbolTable);
599 if (failed(Result: mesh)) {
600 return failure();
601 }
602 if (mlir::ShapedType::isDynamicShape(dSizes: mesh->getShape()) &&
603 getStaticShardedDimsOffsets().size() > 0) {
604 return emitError() << "sharded dims offsets are not allowed for "
605 "devices meshes with dynamic shape.";
606 }
607
608 auto shardedDimsOffsets = getStaticShardedDimsOffsets();
609 if (!shardedDimsOffsets.empty()) {
610 auto meshShape = mesh.value().getShape();
611 assert(ShapedType::isStaticShape(meshShape));
612 uint64_t pos = 0;
613 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(First: getSplitAxes())) {
614 if (!innerSplitAxes.empty()) {
615 int64_t numShards = 0, off = 0;
616 for (auto i : innerSplitAxes.asArrayRef()) {
617 numShards += meshShape[i];
618 }
619 for (int64_t i = 0; i <= numShards; ++i) {
620 if (shardedDimsOffsets.size() <= pos + i) {
621 return emitError() << "sharded dims offsets has wrong size.";
622 }
623 if (ShapedType::isStatic(dValue: shardedDimsOffsets[pos + i])) {
624 if (shardedDimsOffsets[pos + i] < off) {
625 return emitError()
626 << "sharded dims offsets must be non-decreasing.";
627 }
628 off = shardedDimsOffsets[pos + i];
629 }
630 }
631 pos += numShards + 1;
632 }
633 }
634 }
635 return success();
636}
637
638namespace {
639// Sharding annotations "halo sizes" and "sharded dims offsets"
640// are a mix of attributes and dynamic values. This canonicalization moves
641// constant values to the respective attribute lists, minimizing the number
642// of values.
643// It also removes sharded_dims_sizes and halos if they are effectively "empty".
644class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
645public:
646 using OpRewritePattern<ShardingOp>::OpRewritePattern;
647
648 LogicalResult matchAndRewrite(ShardingOp op,
649 PatternRewriter &b) const override {
650 auto mixedHalos =
651 getMixedValues(staticValues: op.getStaticHaloSizes(), dynamicValues: op.getDynamicHaloSizes(), b);
652 auto mixedOffs = getMixedValues(staticValues: op.getStaticShardedDimsOffsets(),
653 dynamicValues: op.getDynamicShardedDimsOffsets(), b);
654
655 // No constant operands were folded, just return;
656 bool modified = succeeded(Result: foldDynamicIndexList(ofrs&: mixedHalos, onlyNonNegative: true)) ||
657 succeeded(Result: foldDynamicIndexList(ofrs&: mixedOffs, onlyNonNegative: true));
658
659 auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedValues: mixedHalos);
660 auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedValues: mixedOffs);
661
662 if (dynamicHalos.empty() && !staticHalos.empty()) {
663 if (staticHalos[0] == 0 && llvm::all_equal(Range&: staticHalos)) {
664 staticHalos.clear();
665 modified = true;
666 }
667 }
668
669 // Remove sharded dims offsets if they are effectively the default values,
670 // e.g. if they define equi-distance between all neighboring shards.
671 // Requires static-only offsets. Compares the first distance as the
672 // difference between the first two offsets. Only if all consecutive
673 // distances are the same, the offsets are removed.
674 if (dynamicOffs.empty() && !staticOffs.empty()) {
675 assert(staticOffs.size() >= 2);
676 auto diff = staticOffs[1] - staticOffs[0];
677 bool all_same = staticOffs.size() > 2;
678 for (auto i = 2u; i < staticOffs.size(); ++i) {
679 if (staticOffs[i] - staticOffs[i - 1] != diff) {
680 all_same = false;
681 break;
682 }
683 }
684 if (all_same) {
685 staticOffs.clear();
686 modified = true;
687 }
688 }
689
690 if (!modified) {
691 return failure();
692 }
693
694 op.setStaticHaloSizes(staticHalos);
695 op.getDynamicHaloSizesMutable().assign(values: dynamicHalos);
696 op.setStaticShardedDimsOffsets(staticOffs);
697 op.getDynamicShardedDimsOffsetsMutable().assign(values: dynamicOffs);
698
699 return success();
700 }
701};
702} // namespace
703
704void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
705 mlir::MLIRContext *context) {
706 results.add<NormalizeSharding>(arg&: context);
707}
708
709//===----------------------------------------------------------------------===//
710// MeshSharding
711//===----------------------------------------------------------------------===//
712
713bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
714 if (getMesh() != rhs.getMesh()) {
715 return false;
716 }
717
718 if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
719 (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
720 !llvm::equal(LRange: getPartialAxes(), RRange: rhs.getPartialAxes())) {
721 return false;
722 }
723
724 auto minSize = std::min(a: getSplitAxes().size(), b: rhs.getSplitAxes().size());
725 if (!llvm::equal(LRange: llvm::make_range(x: getSplitAxes().begin(),
726 y: getSplitAxes().begin() + minSize),
727 RRange: llvm::make_range(x: rhs.getSplitAxes().begin(),
728 y: rhs.getSplitAxes().begin() + minSize))) {
729 return false;
730 }
731
732 return llvm::all_of(Range: llvm::drop_begin(RangeOrContainer: getSplitAxes(), N: minSize),
733 P: std::mem_fn(pm: &MeshAxesAttr::empty)) &&
734 llvm::all_of(Range: llvm::drop_begin(RangeOrContainer: rhs.getSplitAxes(), N: minSize),
735 P: std::mem_fn(pm: &MeshAxesAttr::empty));
736}
737
738bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
739 return equalShardSizes(rhs) && equalHaloSizes(rhs);
740}
741
742bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
743 if (rhs.getStaticShardedDimsOffsets().size() !=
744 getStaticShardedDimsOffsets().size() ||
745 !llvm::equal(LRange: getStaticShardedDimsOffsets(),
746 RRange: rhs.getStaticShardedDimsOffsets())) {
747 return false;
748 }
749 if (rhs.getDynamicShardedDimsOffsets().size() !=
750 getDynamicShardedDimsOffsets().size() ||
751 !llvm::equal(LRange: getDynamicShardedDimsOffsets(),
752 RRange: rhs.getDynamicShardedDimsOffsets())) {
753 return false;
754 }
755 return true;
756}
757
758bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
759 if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
760 !llvm::equal(LRange: getStaticHaloSizes(), RRange: rhs.getStaticHaloSizes())) {
761 return false;
762 }
763 if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
764 !llvm::equal(LRange: getDynamicHaloSizes(), RRange: rhs.getDynamicHaloSizes())) {
765 return false;
766 }
767 return true;
768}
769
770bool MeshSharding::operator==(Value rhs) const {
771 return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
772}
773
774bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
775
776bool MeshSharding::operator==(const MeshSharding &rhs) const {
777 return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
778}
779
780bool MeshSharding::operator!=(const MeshSharding &rhs) const {
781 return !(*this == rhs);
782}
783
784MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
785
786MeshSharding::MeshSharding(Value rhs) {
787 auto shardingOp = mlir::dyn_cast<ShardingOp>(Val: rhs.getDefiningOp());
788 assert(shardingOp && "expected sharding op");
789 auto splitAxes = shardingOp.getSplitAxes().getAxes();
790 auto partialAxes = shardingOp.getPartialAxes().value_or(u: ArrayRef<MeshAxis>());
791 // If splitAxes and partialAxes are empty, use "empty" constructor.
792 if (splitAxes.empty() && partialAxes.empty()) {
793 *this = MeshSharding(shardingOp.getMeshAttr());
794 return;
795 }
796 *this = get(mesh_: shardingOp.getMeshAttr(), split_axes_: splitAxes, partial_axes_: partialAxes,
797 partial_type_: shardingOp.getPartialType().value_or(u: ReductionKind::Sum),
798 static_halo_sizes_: shardingOp.getStaticHaloSizes(),
799 static_sharded_dims_offsets_: shardingOp.getStaticShardedDimsOffsets(),
800 dynamic_halo_sizes_: SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
801 dynamic_sharded_dims_offsets_: SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
802}
803
804MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
805 ArrayRef<MeshAxesAttr> split_axes_,
806 ArrayRef<MeshAxis> partial_axes_,
807 ReductionKind partial_type_,
808 ArrayRef<int64_t> static_halo_sizes_,
809 ArrayRef<int64_t> static_sharded_dims_offsets_,
810 ArrayRef<Value> dynamic_halo_sizes_,
811 ArrayRef<Value> dynamic_sharded_dims_offsets_) {
812 MeshSharding res(mesh_);
813 if (split_axes_.empty() && partial_axes_.empty()) {
814 return res;
815 }
816
817 res.split_axes.resize(N: split_axes_.size());
818 for (auto [i, axis] : llvm::enumerate(First&: split_axes_)) {
819 res.split_axes[i] =
820 MeshAxesAttr::get(context: mesh_.getContext(), content: axis.asArrayRef());
821 }
822
823 auto clone = [](const auto src, auto &dst) {
824 dst.resize(src.size());
825 llvm::copy(src, dst.begin());
826 };
827
828 clone(partial_axes_, res.partial_axes);
829 res.partial_type = partial_type_;
830 clone(static_halo_sizes_, res.static_halo_sizes);
831 clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
832 clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
833 clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
834
835 return res;
836}
837
838//===----------------------------------------------------------------------===//
839// mesh.shard_shape
840//===----------------------------------------------------------------------===//
841
842void ShardShapeOp::getAsmResultNames(
843 function_ref<void(Value, StringRef)> setNameFn) {
844 setNameFn(getResult()[0], "shard_shape");
845}
846
847void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
848 ::mlir::OperationState &odsState,
849 ::llvm::ArrayRef<int64_t> dims,
850 ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
851 ::mlir::ValueRange device) {
852 SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
853 build(odsBuilder, odsState, result: resType, dims, dims_dynamic: dims_dyn, sharding,
854 device: SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device_dynamic: device);
855}
856
857//===----------------------------------------------------------------------===//
858// mesh.shard op
859//===----------------------------------------------------------------------===//
860
861void ShardOp::getAsmResultNames(
862 function_ref<void(Value, StringRef)> setNameFn) {
863 setNameFn(getResult(), "sharding_annotated");
864}
865
866namespace {
867// Determine if the given ShardOp is a duplicate of another ShardOp
868// on the same value. This can happen if constant values are sharded.
869class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
870public:
871 using OpRewritePattern<ShardOp>::OpRewritePattern;
872
873 LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
874 // Get the use-list of the value being sharded and check if it has more than
875 // one use.
876 Value value = op.getSrc();
877 if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
878 return failure();
879 }
880
881 // Iterate through the uses of the value to find a duplicate ShardOp.
882 for (auto &use : value.getUses()) {
883 if (use.getOwner() != op.getOperation()) {
884 auto otherOp = dyn_cast<ShardOp>(Val: use.getOwner());
885 if (!otherOp || !otherOp->isBeforeInBlock(other: op)) {
886 return failure();
887 }
888 // Create a MeshSharding object for the current and the other ShardOp
889 // If the two are equal replace current op with the other op.
890 MeshSharding currentSharding(op.getSharding());
891 MeshSharding otherSharding(otherOp.getSharding());
892 if (currentSharding == otherSharding) {
893 b.replaceAllUsesWith(from: op.getResult(), to: otherOp.getResult());
894 b.eraseOp(op: op.getOperation());
895 } else {
896 // use the other sharding as input for op
897 op.getSrcMutable().assign(value: otherOp.getResult());
898 }
899 return success();
900 }
901 }
902
903 return failure();
904 }
905};
906} // namespace
907
908void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
909 mlir::MLIRContext *context) {
910 results.add<FoldDuplicateShardOp>(arg&: context);
911}
912
913//===----------------------------------------------------------------------===//
914// mesh.process_multi_index op
915//===----------------------------------------------------------------------===//
916
917LogicalResult
918ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
919 auto mesh = ::getMeshAndVerify(op: getOperation(), meshSymbol: getMeshAttr(), symbolTable);
920 if (failed(Result: mesh)) {
921 return failure();
922 }
923 if (failed(Result: verifyMeshAxes(loc: getLoc(), axes: getAxes(), mesh: mesh.value()))) {
924 return failure();
925 }
926
927 size_t expectedResultsCount =
928 getAxes().empty() ? mesh->getRank() : getAxes().size();
929 if (getResult().size() != expectedResultsCount) {
930 return emitError() << "Unexpected number of results " << getResult().size()
931 << ". Expected " << expectedResultsCount << ".";
932 }
933
934 return success();
935}
936
937void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
938 MeshOp mesh) {
939 build(odsBuilder, odsState,
940 result: SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
941 mesh: mesh.getSymName(), axes: ArrayRef<MeshAxis>());
942}
943
944void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
945 StringRef mesh, ArrayRef<MeshAxis> axes) {
946 build(odsBuilder, odsState,
947 result: SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
948 axes: MeshAxesAttr::get(context: odsBuilder.getContext(), content: axes));
949}
950
951void ProcessMultiIndexOp::getAsmResultNames(
952 function_ref<void(Value, StringRef)> setNameFn) {
953 setNameFn(getResults()[0], "proc_linear_idx");
954}
955
956//===----------------------------------------------------------------------===//
957// mesh.process_linear_index op
958//===----------------------------------------------------------------------===//
959
960LogicalResult
961ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
962 auto mesh = ::getMeshAndVerify(op: getOperation(), meshSymbol: getMeshAttr(), symbolTable);
963 if (failed(Result: mesh)) {
964 return failure();
965 }
966 return success();
967}
968
969void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
970 OperationState &odsState, MeshOp mesh) {
971 build(odsBuilder, odsState, mesh: mesh.getSymName());
972}
973
974void ProcessLinearIndexOp::getAsmResultNames(
975 function_ref<void(Value, StringRef)> setNameFn) {
976 setNameFn(getResult(), "proc_linear_idx");
977}
978
979//===----------------------------------------------------------------------===//
980// mesh.neighbors_linear_indices op
981//===----------------------------------------------------------------------===//
982
983LogicalResult
984NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
985 auto mesh = ::getMeshAndVerify(op: getOperation(), meshSymbol: getMeshAttr(), symbolTable);
986 if (failed(Result: mesh)) {
987 return failure();
988 }
989 return success();
990}
991
992void NeighborsLinearIndicesOp::getAsmResultNames(
993 function_ref<void(Value, StringRef)> setNameFn) {
994 setNameFn(getNeighborDown(), "down_linear_idx");
995 setNameFn(getNeighborUp(), "up_linear_idx");
996}
997
998//===----------------------------------------------------------------------===//
999// collective communication ops
1000//===----------------------------------------------------------------------===//
1001
1002namespace {
1003
1004template <typename Op>
1005struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
1006 using OpRewritePattern<Op>::OpRewritePattern;
1007 LogicalResult matchAndRewrite(Op op,
1008 PatternRewriter &rewriter) const override {
1009 auto meshAxes = op.getMeshAxes();
1010 if (!meshAxes.empty()) {
1011 return failure();
1012 }
1013 if (op.getInput().getType() != op.getResult().getType()) {
1014 return failure();
1015 }
1016
1017 rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
1018 rewriter.eraseOp(op: op.getOperation());
1019 return success();
1020 }
1021};
1022
1023} // namespace
1024
1025static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
1026 ArrayRef<int64_t> device,
1027 Operation::operand_range deviceDynamic,
1028 ArrayRef<MeshAxis> meshAxes,
1029 ArrayRef<int64_t> meshShape) {
1030 if (device.size() != meshAxes.size()) {
1031 return emitError(loc) << "In-group device \"" << deviceName
1032 << "\" has unexpected multi-index size "
1033 << device.size() << ". Expected " << meshAxes.size()
1034 << ".";
1035 }
1036
1037 for (size_t i = 0; i < device.size(); ++i) {
1038 if (ShapedType::isStatic(dValue: device[i]) &&
1039 ShapedType::isStatic(dValue: meshShape[meshAxes[i]]) &&
1040 meshShape[meshAxes[i]] <= device[i]) {
1041 return emitError(loc)
1042 << "Out of bounds coordinate " << i << " for in-group device \""
1043 << deviceName << "\"."
1044 << " Got " << device[i] << ", but expected value in the range [0, "
1045 << (meshShape[meshAxes[i]] - 1) << "].";
1046 }
1047 }
1048 return success();
1049}
1050
/// Multiplies together the elements of [begin, end). An empty range yields 1.
/// The result has the decayed element type of the iterator.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  std::multiplies<ElementType> mul;
  ElementType acc = static_cast<ElementType>(1);
  for (; begin != end; ++begin)
    acc = mul(acc, *begin);
  return acc;
}
1057
/// Range overload of product(): multiplies all elements of `range`.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
1062
1063static LogicalResult verifyDimensionCompatibility(Location loc,
1064 int64_t expectedDimSize,
1065 int64_t resultDimSize,
1066 int64_t resultAxis) {
1067 if (ShapedType::isStatic(dValue: resultDimSize) && expectedDimSize != resultDimSize) {
1068 return emitError(loc) << "Dimension size mismatch for result axis "
1069 << resultAxis << ". Expected "
1070 << (ShapedType::isDynamic(dValue: expectedDimSize)
1071 ? Twine("dynamic")
1072 : Twine(expectedDimSize))
1073 << ", but got " << resultDimSize << ".";
1074 }
1075
1076 return success();
1077}
1078
1079static LogicalResult verifyGatherOperandAndResultShape(
1080 Value operand, Value result, int64_t gatherAxis,
1081 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1082 auto resultRank = cast<ShapedType>(Val: result.getType()).getRank();
1083 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1084 return emitError(loc: result.getLoc())
1085 << "Gather axis " << gatherAxis << " is out of bounds [0, "
1086 << resultRank << ").";
1087 }
1088
1089 ShapedType operandType = cast<ShapedType>(Val: operand.getType());
1090 ShapedType resultType = cast<ShapedType>(Val: result.getType());
1091 auto deviceGroupSize =
1092 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1093 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1094 auto operandDimSize = DimensionSize(operandType.getDimSize(idx: axis));
1095 auto resultDimSize = DimensionSize(resultType.getDimSize(idx: axis));
1096 auto expectedResultDimSize =
1097 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1098 if (failed(Result: verifyDimensionCompatibility(
1099 loc: result.getLoc(), expectedDimSize: expectedResultDimSize, resultDimSize, resultAxis: axis))) {
1100 return failure();
1101 }
1102 }
1103 return success();
1104}
1105
1106static LogicalResult verifyAllToAllOperandAndResultShape(
1107 Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1108 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1109 ShapedType operandType = cast<ShapedType>(Val: operand.getType());
1110 ShapedType resultType = cast<ShapedType>(Val: result.getType());
1111 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1112 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1113 if (failed(Result: verifyDimensionCompatibility(
1114 loc: result.getLoc(), expectedDimSize: operandType.getDimSize(idx: axis),
1115 resultDimSize: resultType.getDimSize(idx: axis), resultAxis: axis))) {
1116 return failure();
1117 }
1118 }
1119 }
1120
1121 if (splitAxis == concatAxis) {
1122 return success();
1123 }
1124
1125 auto deviceGroupSize =
1126 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1127 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(idx: concatAxis));
1128 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(idx: splitAxis));
1129 DimensionSize expectedResultConcatDimSize =
1130 operandConcatDimSize * deviceGroupSize;
1131 DimensionSize expectedResultSplitDimSize =
1132 operandSplitDimSize / deviceGroupSize;
1133 if (!expectedResultSplitDimSize.isDynamic() &&
1134 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1135 expectedResultSplitDimSize = DimensionSize::dynamic();
1136 }
1137 if (failed(Result: verifyDimensionCompatibility(
1138 loc: result.getLoc(), expectedDimSize: expectedResultConcatDimSize.value(),
1139 resultDimSize: resultType.getDimSize(idx: concatAxis), resultAxis: concatAxis))) {
1140 return failure();
1141 }
1142 if (failed(Result: verifyDimensionCompatibility(
1143 loc: result.getLoc(), expectedDimSize: expectedResultSplitDimSize.value(),
1144 resultDimSize: resultType.getDimSize(idx: splitAxis), resultAxis: splitAxis))) {
1145 return failure();
1146 }
1147
1148 return success();
1149}
1150
1151static LogicalResult verifyScatterOrSliceOperandAndResultShape(
1152 Value operand, Value result, int64_t tensorAxis,
1153 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1154 ShapedType operandType = cast<ShapedType>(Val: operand.getType());
1155 ShapedType resultType = cast<ShapedType>(Val: result.getType());
1156 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1157 if (axis != tensorAxis) {
1158 if (failed(Result: verifyDimensionCompatibility(
1159 loc: result.getLoc(), expectedDimSize: operandType.getDimSize(idx: axis),
1160 resultDimSize: resultType.getDimSize(idx: axis), resultAxis: axis))) {
1161 return failure();
1162 }
1163 }
1164 }
1165
1166 auto deviceGroupSize =
1167 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1168 auto operandScatterDimSize =
1169 DimensionSize(operandType.getDimSize(idx: tensorAxis));
1170 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1171 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1172 return emitError(loc: result.getLoc())
1173 << "Operand dimension size " << int64_t(operandScatterDimSize)
1174 << " is not divisible by collective device group size "
1175 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1176 << ".";
1177 }
1178 DimensionSize expectedResultTensorDimSize =
1179 operandScatterDimSize / deviceGroupSize;
1180 if (failed(Result: verifyDimensionCompatibility(
1181 loc: result.getLoc(), expectedDimSize: expectedResultTensorDimSize.value(),
1182 resultDimSize: resultType.getDimSize(idx: tensorAxis), resultAxis: tensorAxis))) {
1183 return failure();
1184 }
1185
1186 return success();
1187}
1188
1189static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
1190 ArrayRef<MeshAxis> meshAxes,
1191 int64_t sliceAxis) {
1192 RankedTensorType operandRankedTensorType =
1193 cast<RankedTensorType>(Val&: operandType);
1194 DimensionSize operandSliceAxisSize =
1195 operandRankedTensorType.getShape()[sliceAxis];
1196 SmallVector<int64_t> resultShape =
1197 llvm::to_vector(Range: operandRankedTensorType.getShape());
1198
1199 resultShape[sliceAxis] =
1200 operandSliceAxisSize /
1201 DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1202 return operandRankedTensorType.clone(shape: resultShape);
1203}
1204
1205//===----------------------------------------------------------------------===//
1206// mesh.all_gather op
1207//===----------------------------------------------------------------------===//
1208
1209LogicalResult
1210AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1211 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1212 if (failed(Result: mesh)) {
1213 return failure();
1214 }
1215 auto gatherAxis = getGatherAxis().getSExtValue();
1216 return verifyGatherOperandAndResultShape(operand: getOperand(), result: getResult(),
1217 gatherAxis, meshAxes: getMeshAxes(),
1218 meshShape: mesh.value().getShape());
1219}
1220
1221void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1222 MLIRContext *context) {
1223 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(arg&: context);
1224}
1225
1226void AllGatherOp::getAsmResultNames(
1227 function_ref<void(Value, StringRef)> setNameFn) {
1228 setNameFn(getResult(), "all_gather");
1229}
1230
1231//===----------------------------------------------------------------------===//
1232// mesh.all_reduce op
1233//===----------------------------------------------------------------------===//
1234
1235LogicalResult
1236AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1237 return getMeshAndVerifyAxes(op: *this, symbolTable);
1238}
1239
1240void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1241 MLIRContext *context) {
1242 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(arg&: context);
1243}
1244
1245void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1246 Value input, StringRef mesh,
1247 ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
1248 build(odsBuilder, odsState, result: input.getType(), mesh, mesh_axes: meshAxes, input,
1249 reduction);
1250}
1251
1252void AllReduceOp::getAsmResultNames(
1253 function_ref<void(Value, StringRef)> setNameFn) {
1254 setNameFn(getResult(), "all_reduce");
1255}
1256
1257//===----------------------------------------------------------------------===//
1258// mesh.all_slice op
1259//===----------------------------------------------------------------------===//
1260
1261LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1262 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1263 if (failed(Result: mesh)) {
1264 return failure();
1265 }
1266 return verifyScatterOrSliceOperandAndResultShape(
1267 operand: getOperand(), result: getResult(), tensorAxis: getSliceAxis().getSExtValue(), meshAxes: getMeshAxes(),
1268 meshShape: mesh.value().getShape());
1269}
1270
1271void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1272 MLIRContext *context) {
1273 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(arg&: context);
1274}
1275
1276void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1277 Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1278 int64_t sliceAxis) {
1279 Type resultType = sliceResultType(operandType: input.getType(), mesh, meshAxes, sliceAxis);
1280 build(odsBuilder, odsState, result_type: resultType, input, mesh: mesh.getSymName(), meshAxes,
1281 sliceAxis);
1282}
1283
1284void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1285 Type resultType, Value input, StringRef mesh,
1286 ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1287 build(odsBuilder, odsState, result: resultType, mesh, mesh_axes: meshAxes, input,
1288 slice_axis: APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1289}
1290
1291void AllSliceOp::getAsmResultNames(
1292 function_ref<void(Value, StringRef)> setNameFn) {
1293 setNameFn(getResult(), "all_slice");
1294}
1295
1296//===----------------------------------------------------------------------===//
1297// mesh.all_to_all op
1298//===----------------------------------------------------------------------===//
1299
1300LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1301 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1302 if (failed(Result: mesh)) {
1303 return failure();
1304 }
1305
1306 return verifyAllToAllOperandAndResultShape(
1307 operand: getOperand(), result: getResult(), splitAxis: getSplitAxis().getSExtValue(),
1308 concatAxis: getConcatAxis().getSExtValue(), meshAxes: getMeshAxes(), meshShape: mesh.value().getShape());
1309}
1310
1311void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1312 MLIRContext *context) {
1313 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(arg&: context);
1314}
1315
1316void AllToAllOp::getAsmResultNames(
1317 function_ref<void(Value, StringRef)> setNameFn) {
1318 setNameFn(getResult(), "all_to_all");
1319}
1320
1321//===----------------------------------------------------------------------===//
1322// mesh.broadcast op
1323//===----------------------------------------------------------------------===//
1324
1325LogicalResult
1326BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1327 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1328 if (failed(Result: mesh)) {
1329 return failure();
1330 }
1331 if (failed(Result: verifyInGroupDevice(loc: getLoc(), deviceName: getRootAttrName(), device: getRoot(),
1332 deviceDynamic: getRootDynamic(), meshAxes: getMeshAxes(),
1333 meshShape: mesh.value().getShape()))) {
1334 return failure();
1335 }
1336
1337 return success();
1338}
1339
1340void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1341 MLIRContext *context) {
1342 patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(arg&: context);
1343}
1344
1345void BroadcastOp::getAsmResultNames(
1346 function_ref<void(Value, StringRef)> setNameFn) {
1347 setNameFn(getResult(), "broadcast");
1348}
1349
1350//===----------------------------------------------------------------------===//
1351// mesh.gather op
1352//===----------------------------------------------------------------------===//
1353
1354LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1355 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1356 if (failed(Result: mesh)) {
1357 return failure();
1358 }
1359 if (failed(Result: verifyInGroupDevice(loc: getLoc(), deviceName: getRootAttrName(), device: getRoot(),
1360 deviceDynamic: getRootDynamic(), meshAxes: getMeshAxes(),
1361 meshShape: mesh.value().getShape()))) {
1362 return failure();
1363 }
1364
1365 auto gatherAxis = getGatherAxis().getSExtValue();
1366 return verifyGatherOperandAndResultShape(operand: getInput(), result: getResult(), gatherAxis,
1367 meshAxes: getMeshAxes(),
1368 meshShape: mesh.value().getShape());
1369}
1370
1371void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1372 MLIRContext *context) {
1373 patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(arg&: context);
1374}
1375
1376void GatherOp::getAsmResultNames(
1377 function_ref<void(Value, StringRef)> setNameFn) {
1378 setNameFn(getResult(), "gather");
1379}
1380
1381//===----------------------------------------------------------------------===//
1382// mesh.recv op
1383//===----------------------------------------------------------------------===//
1384
1385LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1386 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1387 if (failed(Result: mesh)) {
1388 return failure();
1389 }
1390 if (getSource() &&
1391 failed(Result: verifyInGroupDevice(loc: getLoc(), deviceName: getSourceAttrName(),
1392 device: getSource().value(), deviceDynamic: getSourceDynamic(),
1393 meshAxes: getMeshAxes(), meshShape: mesh.value().getShape()))) {
1394 return failure();
1395 }
1396 return success();
1397}
1398
1399void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1400 MLIRContext *context) {
1401 patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(arg&: context);
1402}
1403
1404void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1405 setNameFn(getResult(), "recv");
1406}
1407
1408//===----------------------------------------------------------------------===//
1409// mesh.reduce op
1410//===----------------------------------------------------------------------===//
1411
1412LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1413 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1414 if (failed(Result: mesh)) {
1415 return failure();
1416 }
1417 if (failed(Result: verifyInGroupDevice(loc: getLoc(), deviceName: getRootAttrName(), device: getRoot(),
1418 deviceDynamic: getRootDynamic(), meshAxes: getMeshAxes(),
1419 meshShape: mesh.value().getShape()))) {
1420 return failure();
1421 }
1422
1423 return success();
1424}
1425
1426void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1427 MLIRContext *context) {
1428 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(arg&: context);
1429}
1430
1431void ReduceOp::getAsmResultNames(
1432 function_ref<void(Value, StringRef)> setNameFn) {
1433 setNameFn(getResult(), "reduce");
1434}
1435
1436//===----------------------------------------------------------------------===//
1437// mesh.reduce_scatter op
1438//===----------------------------------------------------------------------===//
1439
1440LogicalResult
1441ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1442 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1443 if (failed(Result: mesh)) {
1444 return failure();
1445 }
1446
1447 return verifyScatterOrSliceOperandAndResultShape(
1448 operand: getOperand(), result: getResult(), tensorAxis: getScatterAxis().getSExtValue(), meshAxes: getMeshAxes(),
1449 meshShape: mesh.value().getShape());
1450}
1451
1452void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1453 MLIRContext *context) {
1454 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(arg&: context);
1455}
1456
1457void ReduceScatterOp::getAsmResultNames(
1458 function_ref<void(Value, StringRef)> setNameFn) {
1459 setNameFn(getResult(), "reduce_scatter");
1460}
1461
1462//===----------------------------------------------------------------------===//
1463// mesh.scatter op
1464//===----------------------------------------------------------------------===//
1465
1466LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1467 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1468 if (failed(Result: mesh)) {
1469 return failure();
1470 }
1471 if (failed(Result: verifyInGroupDevice(loc: getLoc(), deviceName: getRootAttrName(), device: getRoot(),
1472 deviceDynamic: getRootDynamic(), meshAxes: getMeshAxes(),
1473 meshShape: mesh.value().getShape()))) {
1474 return failure();
1475 }
1476
1477 auto scatterAxis = getScatterAxis().getSExtValue();
1478 return verifyScatterOrSliceOperandAndResultShape(operand: getInput(), result: getResult(),
1479 tensorAxis: scatterAxis, meshAxes: getMeshAxes(),
1480 meshShape: mesh.value().getShape());
1481}
1482
1483void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1484 MLIRContext *context) {
1485 patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(arg&: context);
1486}
1487
1488void ScatterOp::getAsmResultNames(
1489 function_ref<void(Value, StringRef)> setNameFn) {
1490 setNameFn(getResult(), "scatter");
1491}
1492
1493//===----------------------------------------------------------------------===//
1494// mesh.send op
1495//===----------------------------------------------------------------------===//
1496
1497LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1498 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1499 if (failed(Result: mesh)) {
1500 return failure();
1501 }
1502 if (failed(Result: verifyInGroupDevice(loc: getLoc(), deviceName: getDestinationAttrName(),
1503 device: getDestination(), deviceDynamic: getDestinationDynamic(),
1504 meshAxes: getMeshAxes(), meshShape: mesh.value().getShape()))) {
1505 return failure();
1506 }
1507 return success();
1508}
1509
1510void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1511 MLIRContext *context) {
1512 patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(arg&: context);
1513}
1514
1515void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1516 setNameFn(getResult(), "send");
1517}
1518
1519//===----------------------------------------------------------------------===//
1520// mesh.shift op
1521//===----------------------------------------------------------------------===//
1522
1523LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1524 auto mesh = getMeshAndVerifyAxes(op: *this, symbolTable);
1525 if (failed(Result: mesh)) {
1526 return failure();
1527 }
1528
1529 auto meshAxes = getMeshAxes();
1530 auto shiftAxis = getShiftAxis().getZExtValue();
1531 if (!llvm::is_contained(Range&: meshAxes, Element: shiftAxis)) {
1532 return emitError() << "Invalid shift axis " << shiftAxis
1533 << ". It must be one of the grouping mesh axes.";
1534 }
1535
1536 return success();
1537}
1538
1539void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1540 MLIRContext *context) {
1541 // TODO: remove op when offset is 0 or if it is a rotate with and
1542 // offset % shift_axis_mesh_dim_size == 0.
1543}
1544
1545void ShiftOp::getAsmResultNames(
1546 function_ref<void(Value, StringRef)> setNameFn) {
1547 setNameFn(getResult(), "shift");
1548}
1549
1550//===----------------------------------------------------------------------===//
1551// mesh.update_halo op
1552//===----------------------------------------------------------------------===//
1553
1554LogicalResult
1555UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1556 auto mesh = getMeshAndVerify(op: getOperation(), meshSymbol: getMeshAttr(), symbolTable);
1557 if (failed(Result: mesh)) {
1558 return failure();
1559 }
1560
1561 return success();
1562}
1563
1564//===----------------------------------------------------------------------===//
1565// TableGen'd op method definitions
1566//===----------------------------------------------------------------------===//
1567
1568#define GET_OP_CLASSES
1569#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1570
1571#define GET_ATTRDEF_CLASSES
1572#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1573
1574#define GET_TYPEDEF_CLASSES
1575#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1576
1577#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
1578

source code of mlir/lib/Dialect/Mesh/IR/MeshOps.cpp