1//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/IR/MeshOps.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13#include "mlir/Dialect/Utils/StaticValueUtils.h"
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/BuiltinAttributes.h"
16#include "mlir/IR/BuiltinTypeInterfaces.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/IR/Value.h"
25#include "mlir/Interfaces/ViewLikeInterface.h"
26#include "mlir/Support/LLVM.h"
27#include "mlir/Transforms/InliningUtils.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallSet.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Support/Casting.h"
34#include <algorithm>
35#include <functional>
36#include <iterator>
37#include <numeric>
38#include <optional>
39#include <utility>
40
41#define DEBUG_TYPE "mesh-ops"
42#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43
44using namespace mlir;
45using namespace mlir::mesh;
46
47#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
48
49namespace {
50
51struct DimensionSize {
52 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
53 DimensionSize(int64_t val) : val(val) {}
54 int64_t value() const { return val; }
55 operator int64_t() const { return val; }
56 bool isDynamic() const { return ShapedType::isDynamic(val); }
57
58private:
59 int64_t val;
60};
61
62} // namespace
63
64static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
65 if (lhs.isDynamic() || rhs.isDynamic()) {
66 return DimensionSize::dynamic();
67 }
68 return lhs.value() / rhs.value();
69}
70
71static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
72 if (lhs.isDynamic() || rhs.isDynamic()) {
73 return DimensionSize::dynamic();
74 }
75 return lhs.value() * rhs.value();
76}
77
78//===----------------------------------------------------------------------===//
79// Inliner
80//===----------------------------------------------------------------------===//
81
82namespace {
83struct MeshInlinerInterface : public DialectInlinerInterface {
84 using DialectInlinerInterface::DialectInlinerInterface;
85 // Currently no restrictions are encoded for inlining.
86 bool isLegalToInline(Operation *, Operation *, bool) const final {
87 return true;
88 }
89 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
90 return true;
91 }
92 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
93 return true;
94 }
95};
96} // namespace
97
98//===----------------------------------------------------------------------===//
99// Mesh dialect
100//===----------------------------------------------------------------------===//
101
102void MeshDialect::initialize() {
103 addOperations<
104#define GET_OP_LIST
105#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
106 >();
107 addAttributes<
108#define GET_ATTRDEF_LIST
109#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
110 >();
111 addTypes<
112#define GET_TYPEDEF_LIST
113#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
114 >();
115 addInterface<MeshInlinerInterface>();
116}
117
118Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
119 Type type, Location loc) {
120 return arith::ConstantOp::materialize(builder, value, type, loc);
121}
122
123//===----------------------------------------------------------------------===//
124// Mesh utilities
125//===----------------------------------------------------------------------===//
126
127static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
128 FlatSymbolRefAttr meshSymbol,
129 SymbolTableCollection &symbolTable) {
130 mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
131 if (!mesh) {
132 return op->emitError() << "Undefined required mesh symbol \""
133 << meshSymbol.getValue() << "\".";
134 }
135
136 return mesh;
137}
138
/// Returns true when no two *adjacent* elements of [begin, end) compare equal.
///
/// Only neighbours are compared, so this answers "are all elements distinct?"
/// only for a sorted range; callers are expected to sort first. Empty and
/// single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
155
156static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
157 MeshOp mesh) {
158 SmallVector<MeshAxis> sorted = llvm::to_vector(Range&: axes);
159 llvm::sort(C&: sorted);
160 if (!isUnique(begin: sorted.begin(), end: sorted.end())) {
161 return emitError(loc) << "Mesh axes contains duplicate elements.";
162 }
163
164 MeshAxis rank = mesh.getRank();
165 for (auto axis : axes) {
166 if (axis >= rank || axis < 0) {
167 return emitError(loc)
168 << "0-based mesh axis index " << axis
169 << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
170 << "\" is of rank " << rank << ".";
171 }
172 }
173
174 return success();
175}
176
177template <typename Op>
178static FailureOr<MeshOp>
179getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
180 auto mesh =
181 ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
182 if (failed(mesh)) {
183 return failure();
184 }
185 if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
186 return failure();
187 }
188 return mesh;
189}
190
191template <typename InShape, typename MeshShape, typename SplitAxes,
192 typename OutShape>
193static void shardShape(const InShape &inShape, const MeshShape &meshShape,
194 const SplitAxes &splitAxes, OutShape &outShape,
195 ArrayRef<int64_t> shardedDimsOffsets = {},
196 ArrayRef<int64_t> haloSizes = {}) {
197 // 0d tensors cannot be sharded and must get replicated
198 if (inShape.empty()) {
199 assert(outShape.empty());
200 return;
201 }
202
203 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
204 llvm::adl_begin(outShape));
205
206 if (!shardedDimsOffsets.empty()) {
207 auto isDynShape = ShapedType::isDynamicShape(meshShape);
208 uint64_t pos = 1;
209 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
210 if (!innerSplitAxes.empty()) {
211 auto sz = shardedDimsOffsets[pos];
212 bool same = !isDynShape;
213 if (same) {
214 // Find sharded dims in shardedDimsOffsets with same static size on
215 // all devices. Use kDynamic for dimensions with dynamic or
216 // non-uniform offs in shardedDimsOffsets.
217 uint64_t numShards = 0;
218 for (auto i : innerSplitAxes.asArrayRef()) {
219 numShards += meshShape[i];
220 }
221 for (size_t i = 1; i < numShards; ++i) {
222 if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
223 sz) {
224 same = false;
225 break;
226 }
227 }
228 pos += numShards + 1;
229 }
230 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
231 }
232 }
233 } else {
234 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
235 outShape[tensorAxis] = shardDimension(
236 inShape[tensorAxis],
237 collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
238 }
239
240 if (!haloSizes.empty()) {
241 // add halo sizes if requested
242 int haloAxis = 0;
243 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
244 if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
245 !innerSplitAxes.empty()) {
246 if (haloSizes[haloAxis * 2] >= 0 &&
247 haloSizes[haloAxis * 2 + 1] >= 0) {
248 outShape[tensorAxis] +=
249 haloSizes[haloAxis * 2] + haloSizes[haloAxis * 2 + 1];
250 ++haloAxis;
251 } else {
252 outShape[tensorAxis] = ShapedType::kDynamic;
253 }
254 }
255 }
256 }
257 }
258}
259
260ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
261 MeshSharding sharding) {
262 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
263 SmallVector<Dim> resShapeArr(shape.getShape().size());
264 shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
265 resShapeArr, sharding.getStaticShardedDimsOffsets(),
266 sharding.getStaticHaloSizes());
267 return shape.clone(resShapeArr);
268}
269
270Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
271 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
272 if (rankedTensorType && !rankedTensorType.getShape().empty()) {
273 return shardShapedType(rankedTensorType, mesh, sharding);
274 }
275 return type;
276}
277
278void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
279 OpOperand &operand,
280 OpBuilder &builder,
281 ShardOp &newShardOp) {
282 OpBuilder::InsertionGuard insertionGuard(builder);
283 Value operandValue = operand.get();
284 Operation *operandOp = operand.getOwner();
285 builder.setInsertionPointAfterValue(operandValue);
286 ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
287 if (shardOp && sharding == shardOp.getSharding() &&
288 !shardOp.getAnnotateForUsers()) {
289 // No need for anything if the correct sharding is already set.
290 if (!newShardOp) {
291 newShardOp = shardOp;
292 }
293 return;
294 }
295
296 if (!newShardOp) {
297 auto shardingOp =
298 builder.create<ShardingOp>(operandValue.getLoc(), sharding);
299 newShardOp =
300 builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
301 /*annotate_for_users*/ false);
302 }
303 IRRewriter rewriter(builder);
304 rewriter.replaceUsesWithIf(
305 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
306 return use.getOwner() == operandOp && use.get() == operandValue;
307 });
308
309 if (!shardOp || shardOp.getAnnotateForUsers()) {
310 return;
311 }
312
313 auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
314 newShardOp.getSharding(),
315 /*annotate_for_users*/ true);
316 rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
317}
318
319void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
320 OpResult result,
321 OpBuilder &builder) {
322 ShardOp newShardOp;
323 for (auto &use : llvm::make_early_inc_range(Range: result.getUses())) {
324 maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
325 }
326}
327
328void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
329 OpOperand &operand,
330 OpBuilder &builder) {
331 OpBuilder::InsertionGuard insertionGuard(builder);
332 Value operandValue = operand.get();
333 Operation *operandSrcOp = operandValue.getDefiningOp();
334 bool isBlockArg = !operandSrcOp;
335 {
336 [[maybe_unused]] auto opType =
337 dyn_cast<mlir::RankedTensorType>(operandValue.getType());
338 assert(!opType || opType.getRank() > 0 || isFullReplication(sharding));
339 }
340 if (!isa<RankedTensorType>(Val: operandValue.getType()) && operandSrcOp &&
341 operandSrcOp->hasTrait<OpTrait::ConstantLike>()) {
342 return;
343 }
344
345 Operation *operandOp = operand.getOwner();
346 ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
347
348 if (shardOp && sharding == shardOp.getSharding() &&
349 shardOp.getAnnotateForUsers()) {
350 // No need for anything the correct sharding is already set.
351 return;
352 }
353
354 builder.setInsertionPoint(operandOp);
355 auto shardingOp =
356 builder.create<ShardingOp>(operand.get().getLoc(), sharding);
357 auto newShardOp =
358 builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
359 /*annotate_for_users*/ true);
360 IRRewriter rewriter(builder);
361 rewriter.replaceUsesWithIf(
362 operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
363 return use.getOwner() == operandOp && use.get() == operandValue;
364 });
365
366 if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
367 // No need for resharding.
368 return;
369 }
370
371 builder.setInsertionPoint(newShardOp);
372 auto newPreceedingShardOp =
373 builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
374 /*annotate_for_users*/ false);
375 rewriter.replaceUsesWithIf(
376 newShardOp.getSrc(), newPreceedingShardOp, [&newShardOp](OpOperand &use) {
377 return use.getOwner() == newShardOp.getOperation();
378 });
379}
380
381//===----------------------------------------------------------------------===//
382// mesh.mesh op
383//===----------------------------------------------------------------------===//
384
385LogicalResult MeshOp::verify() {
386 int64_t rank = getRank();
387
388 if (rank <= 0)
389 return emitOpError("rank of mesh is expected to be a positive integer");
390
391 for (int64_t dimSize : getShape()) {
392 if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
393 return emitOpError("dimension size of a mesh is expected to be "
394 "non-negative or dynamic");
395 }
396
397 return success();
398}
399
400//===----------------------------------------------------------------------===//
401// mesh.mesh_shape op
402//===----------------------------------------------------------------------===//
403
404LogicalResult
405MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
406 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
407 if (failed(mesh)) {
408 return failure();
409 }
410 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
411 return failure();
412 }
413
414 size_t expectedResultsCount =
415 getAxes().empty() ? mesh->getRank() : getAxes().size();
416 if (getResult().size() != expectedResultsCount) {
417 return emitError() << "Unexpected number of results " << getResult().size()
418 << ". Expected " << expectedResultsCount << ".";
419 }
420
421 return success();
422}
423
424void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
425 MeshOp mesh) {
426 build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
427}
428
429void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
430 MeshOp mesh, ArrayRef<MeshAxis> axes) {
431 build(odsBuilder, odsState,
432 SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
433 odsBuilder.getIndexType()),
434 mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
435}
436
437void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
438 StringRef mesh, ArrayRef<MeshAxis> axes) {
439 assert(!axes.empty());
440 build(odsBuilder, odsState,
441 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
442 MeshAxesAttr::get(odsBuilder.getContext(), axes));
443}
444
445void MeshShapeOp::getAsmResultNames(
446 function_ref<void(Value, StringRef)> setNameFn) {
447 setNameFn(getResults()[0], "mesh_shape");
448}
449
450//===----------------------------------------------------------------------===//
451// mesh.sharding
452//===----------------------------------------------------------------------===//
453
454void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
455 FlatSymbolRefAttr mesh,
456 ArrayRef<MeshAxesAttr> split_axes,
457 ArrayRef<MeshAxis> partial_axes,
458 mesh::ReductionKind partial_type,
459 ArrayRef<int64_t> static_halos,
460 ArrayRef<int64_t> static_offsets) {
461 return build(
462 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
463 ::mlir::DenseI16ArrayAttr::get(b.getContext(), partial_axes),
464 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
465 ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
466 ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
467}
468
469void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
470 FlatSymbolRefAttr mesh,
471 ArrayRef<MeshAxesAttr> split_axes) {
472 return build(
473 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
474 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
475 {}, {}, {}, {});
476}
477
478void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
479 llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
480 ArrayRef<int64_t> static_halos,
481 ArrayRef<int64_t> static_offsets) {
482 return build(
483 b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
484 MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
485 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
486 ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
487 ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
488}
489
490void ShardingOp::build(
491 ::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
492 FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
493 ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
494 ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
495 mlir::SmallVector<int64_t> staticHalos, staticDims;
496 mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
497 dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
498 dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
499 return build(
500 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
501 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
502 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
503 ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
504}
505
506void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
507 mlir::mesh::MeshSharding from) {
508
509 build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
510 MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
511 from.getPartialAxes().empty()
512 ? DenseI16ArrayAttr()
513 : b.getDenseI16ArrayAttr(from.getPartialAxes()),
514 ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
515 from.getPartialType()),
516 from.getStaticShardedDimsOffsets().empty()
517 ? DenseI64ArrayAttr()
518 : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
519 from.getDynamicShardedDimsOffsets(),
520 from.getStaticHaloSizes().empty()
521 ? DenseI64ArrayAttr()
522 : b.getDenseI64ArrayAttr(from.getStaticHaloSizes()),
523 from.getDynamicHaloSizes());
524}
525
526LogicalResult ShardingOp::verify() {
527 llvm::SmallSet<MeshAxis, 4> visitedAxes;
528
529 auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
530 for (MeshAxis axis : axesArray) {
531 if (axis < 0)
532 return emitError() << "mesh axis is expected to be non-negative";
533 if (!visitedAxes.insert(axis).second)
534 return emitError() << "mesh axis duplicated";
535 }
536 return success();
537 };
538
539 for (auto subAxes : getSplitAxes().getAxes()) {
540 ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
541 if (failed(checkMeshAxis(subAxesArray)))
542 return failure();
543 }
544 if (getPartialAxes().has_value() &&
545 failed(checkMeshAxis(getPartialAxes().value())))
546 return failure();
547
548 if (!getStaticHaloSizes().empty() && !getStaticShardedDimsOffsets().empty()) {
549 return emitOpError("halo sizes and shard offsets are mutually exclusive");
550 }
551
552 if (!getStaticHaloSizes().empty()) {
553 auto numSplitAxes = getSplitAxes().getAxes().size();
554 for (auto splitAxis : getSplitAxes().getAxes()) {
555 if (splitAxis.empty()) {
556 --numSplitAxes;
557 }
558 }
559 if (getStaticHaloSizes().size() != numSplitAxes * 2) {
560 return emitError() << "halo sizes must be specified for all split axes.";
561 }
562 }
563
564 return success();
565}
566
567void ShardingOp::getAsmResultNames(
568 function_ref<void(Value, StringRef)> setNameFn) {
569 setNameFn(getResult(), "sharding");
570}
571
572LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
573 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
574 if (failed(mesh)) {
575 return failure();
576 }
577 if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
578 getStaticShardedDimsOffsets().size() > 0) {
579 return emitError() << "sharded dims offsets are not allowed for "
580 "devices meshes with dynamic shape.";
581 }
582
583 auto shardedDimsOffsets = getStaticShardedDimsOffsets();
584 if (!shardedDimsOffsets.empty()) {
585 auto meshShape = mesh.value().getShape();
586 assert(!ShapedType::isDynamicShape(meshShape));
587 uint64_t pos = 0;
588 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
589 if (!innerSplitAxes.empty()) {
590 int64_t numShards = 0, off = 0;
591 for (auto i : innerSplitAxes.asArrayRef()) {
592 numShards += meshShape[i];
593 }
594 for (int64_t i = 0; i <= numShards; ++i) {
595 if (shardedDimsOffsets.size() <= pos + i) {
596 return emitError() << "sharded dims offsets has wrong size.";
597 }
598 if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
599 if (shardedDimsOffsets[pos + i] < off) {
600 return emitError()
601 << "sharded dims offsets must be non-decreasing.";
602 }
603 off = shardedDimsOffsets[pos + i];
604 }
605 }
606 pos += numShards + 1;
607 }
608 }
609 }
610 return success();
611}
612
613namespace {
614// Sharding annotations "halo sizes" and "sharded dims offsets"
615// are a mix of attributes and dynamic values. This canonicalization moves
616// constant values to the respective attribute lists, minimizing the number
617// of values.
618// It also removes sharded_dims_sizes and halos if they are effectively "empty".
619class NormalizeSharding final : public OpRewritePattern<ShardingOp> {
620public:
621 using OpRewritePattern<ShardingOp>::OpRewritePattern;
622
623 LogicalResult matchAndRewrite(ShardingOp op,
624 PatternRewriter &b) const override {
625 auto mixedHalos =
626 getMixedValues(op.getStaticHaloSizes(), op.getDynamicHaloSizes(), b);
627 auto mixedOffs = getMixedValues(op.getStaticShardedDimsOffsets(),
628 op.getDynamicShardedDimsOffsets(), b);
629
630 // No constant operands were folded, just return;
631 bool modified = succeeded(foldDynamicIndexList(mixedHalos, true)) ||
632 succeeded(foldDynamicIndexList(mixedOffs, true));
633
634 auto [staticHalos, dynamicHalos] = decomposeMixedValues(mixedHalos);
635 auto [staticOffs, dynamicOffs] = decomposeMixedValues(mixedOffs);
636
637 if (dynamicHalos.empty() && !staticHalos.empty()) {
638 if (staticHalos[0] == 0 && llvm::all_equal(staticHalos)) {
639 staticHalos.clear();
640 modified = true;
641 }
642 }
643
644 // Remove sharded dims offsets if they are effectively the default values,
645 // e.g. if they define equi-distance between all neighboring shards.
646 // Requires static-only offsets. Compares the first distance as the
647 // difference between the first two offsets. Only if all consecutive
648 // distances are the same, the offsets are removed.
649 if (dynamicOffs.empty() && !staticOffs.empty()) {
650 assert(staticOffs.size() >= 2);
651 auto diff = staticOffs[1] - staticOffs[0];
652 bool all_same = staticOffs.size() > 2;
653 for (auto i = 2u; i < staticOffs.size(); ++i) {
654 if (staticOffs[i] - staticOffs[i - 1] != diff) {
655 all_same = false;
656 break;
657 }
658 }
659 if (all_same) {
660 staticOffs.clear();
661 modified = true;
662 }
663 }
664
665 if (!modified) {
666 return failure();
667 }
668
669 op.setStaticHaloSizes(staticHalos);
670 op.getDynamicHaloSizesMutable().assign(dynamicHalos);
671 op.setStaticShardedDimsOffsets(staticOffs);
672 op.getDynamicShardedDimsOffsetsMutable().assign(dynamicOffs);
673
674 return success();
675 }
676};
677} // namespace
678
679void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
680 mlir::MLIRContext *context) {
681 results.add<NormalizeSharding>(context);
682}
683
684//===----------------------------------------------------------------------===//
685// MeshSharding
686//===----------------------------------------------------------------------===//
687
688bool MeshSharding::equalSplitAndPartialAxes(const MeshSharding &rhs) const {
689 if (getMesh() != rhs.getMesh()) {
690 return false;
691 }
692
693 if (getPartialAxes().size() != rhs.getPartialAxes().size() ||
694 (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) ||
695 !llvm::equal(getPartialAxes(), rhs.getPartialAxes())) {
696 return false;
697 }
698
699 auto minSize = std::min(a: getSplitAxes().size(), b: rhs.getSplitAxes().size());
700 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
701 getSplitAxes().begin() + minSize),
702 llvm::make_range(rhs.getSplitAxes().begin(),
703 rhs.getSplitAxes().begin() + minSize))) {
704 return false;
705 }
706
707 return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
708 std::mem_fn(&MeshAxesAttr::empty)) &&
709 llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
710 std::mem_fn(&MeshAxesAttr::empty));
711}
712
713bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
714 return equalShardSizes(rhs) && equalHaloSizes(rhs);
715}
716
717bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
718 if (rhs.getStaticShardedDimsOffsets().size() !=
719 getStaticShardedDimsOffsets().size() ||
720 !llvm::equal(LRange: getStaticShardedDimsOffsets(),
721 RRange: rhs.getStaticShardedDimsOffsets())) {
722 return false;
723 }
724 if (rhs.getDynamicShardedDimsOffsets().size() !=
725 getDynamicShardedDimsOffsets().size() ||
726 !llvm::equal(LRange: getDynamicShardedDimsOffsets(),
727 RRange: rhs.getDynamicShardedDimsOffsets())) {
728 return false;
729 }
730 return true;
731}
732
733bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
734 if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
735 !llvm::equal(LRange: getStaticHaloSizes(), RRange: rhs.getStaticHaloSizes())) {
736 return false;
737 }
738 if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
739 !llvm::equal(LRange: getDynamicHaloSizes(), RRange: rhs.getDynamicHaloSizes())) {
740 return false;
741 }
742 return true;
743}
744
745bool MeshSharding::operator==(Value rhs) const {
746 return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
747}
748
749bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
750
751bool MeshSharding::operator==(const MeshSharding &rhs) const {
752 return equalSplitAndPartialAxes(rhs) && equalHaloAndShardSizes(rhs);
753}
754
755bool MeshSharding::operator!=(const MeshSharding &rhs) const {
756 return !(*this == rhs);
757}
758
759MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
760
761MeshSharding::MeshSharding(Value rhs) {
762 auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
763 assert(shardingOp && "expected sharding op");
764 auto splitAxes = shardingOp.getSplitAxes().getAxes();
765 auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
766 // If splitAxes and partialAxes are empty, use "empty" constructor.
767 if (splitAxes.empty() && partialAxes.empty()) {
768 *this = MeshSharding(shardingOp.getMeshAttr());
769 return;
770 }
771 *this = get(shardingOp.getMeshAttr(), splitAxes, partialAxes,
772 shardingOp.getPartialType().value_or(ReductionKind::Sum),
773 shardingOp.getStaticHaloSizes(),
774 shardingOp.getStaticShardedDimsOffsets(),
775 SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
776 SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
777}
778
779MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
780 ArrayRef<MeshAxesAttr> split_axes_,
781 ArrayRef<MeshAxis> partial_axes_,
782 ReductionKind partial_type_,
783 ArrayRef<int64_t> static_halo_sizes_,
784 ArrayRef<int64_t> static_sharded_dims_offsets_,
785 ArrayRef<Value> dynamic_halo_sizes_,
786 ArrayRef<Value> dynamic_sharded_dims_offsets_) {
787 MeshSharding res(mesh_);
788 if (split_axes_.empty() && partial_axes_.empty()) {
789 return res;
790 }
791
792 res.split_axes.resize(split_axes_.size());
793 for (auto [i, axis] : llvm::enumerate(First&: split_axes_)) {
794 res.split_axes[i] =
795 MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
796 }
797
798 auto clone = [](const auto src, auto &dst) {
799 dst.resize(src.size());
800 llvm::copy(src, dst.begin());
801 };
802
803 clone(partial_axes_, res.partial_axes);
804 res.partial_type = partial_type_;
805 clone(static_halo_sizes_, res.static_halo_sizes);
806 clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
807 clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
808 clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
809
810 return res;
811}
812
813//===----------------------------------------------------------------------===//
814// mesh.shard_shape
815//===----------------------------------------------------------------------===//
816
817void ShardShapeOp::getAsmResultNames(
818 function_ref<void(Value, StringRef)> setNameFn) {
819 setNameFn(getResult()[0], "shard_shape");
820}
821
822void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
823 ::mlir::OperationState &odsState,
824 ::llvm::ArrayRef<int64_t> dims,
825 ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
826 ::mlir::ValueRange device) {
827 SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
828 build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
829 SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
830}
831
832//===----------------------------------------------------------------------===//
833// mesh.shard op
834//===----------------------------------------------------------------------===//
835
836void ShardOp::getAsmResultNames(
837 function_ref<void(Value, StringRef)> setNameFn) {
838 setNameFn(getResult(), "sharding_annotated");
839}
840
841namespace {
842// Determine if the given ShardOp is a duplicate of another ShardOp
843// on the same value. This can happen if constant values are sharded.
844class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
845public:
846 using OpRewritePattern<ShardOp>::OpRewritePattern;
847
848 LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &b) const override {
849 // Get the use-list of the value being sharded and check if it has more than
850 // one use.
851 Value value = op.getSrc();
852 if (value.hasOneUse() || value.getDefiningOp<ShardOp>()) {
853 return failure();
854 }
855
856 // Iterate through the uses of the value to find a duplicate ShardOp.
857 for (auto &use : value.getUses()) {
858 if (use.getOwner() != op.getOperation()) {
859 auto otherOp = dyn_cast<ShardOp>(use.getOwner());
860 if (!otherOp || !otherOp->isBeforeInBlock(op)) {
861 return failure();
862 }
863 // Create a MeshSharding object for the current and the other ShardOp
864 // If the two are equal replace current op with the other op.
865 MeshSharding currentSharding(op.getSharding());
866 MeshSharding otherSharding(otherOp.getSharding());
867 if (currentSharding == otherSharding) {
868 b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
869 b.eraseOp(op.getOperation());
870 } else {
871 // use the other sharding as input for op
872 op.getSrcMutable().assign(otherOp.getResult());
873 }
874 return success();
875 }
876 }
877
878 return failure();
879 }
880};
881} // namespace
882
883void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
884 mlir::MLIRContext *context) {
885 results.add<FoldDuplicateShardOp>(context);
886}
887
888//===----------------------------------------------------------------------===//
889// mesh.process_multi_index op
890//===----------------------------------------------------------------------===//
891
892LogicalResult
893ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
894 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
895 if (failed(mesh)) {
896 return failure();
897 }
898 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
899 return failure();
900 }
901
902 size_t expectedResultsCount =
903 getAxes().empty() ? mesh->getRank() : getAxes().size();
904 if (getResult().size() != expectedResultsCount) {
905 return emitError() << "Unexpected number of results " << getResult().size()
906 << ". Expected " << expectedResultsCount << ".";
907 }
908
909 return success();
910}
911
912void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
913 MeshOp mesh) {
914 build(odsBuilder, odsState,
915 SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
916 mesh.getSymName(), ArrayRef<MeshAxis>());
917}
918
919void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
920 StringRef mesh, ArrayRef<MeshAxis> axes) {
921 build(odsBuilder, odsState,
922 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
923 MeshAxesAttr::get(odsBuilder.getContext(), axes));
924}
925
926void ProcessMultiIndexOp::getAsmResultNames(
927 function_ref<void(Value, StringRef)> setNameFn) {
928 setNameFn(getResults()[0], "proc_linear_idx");
929}
930
931//===----------------------------------------------------------------------===//
932// mesh.process_linear_index op
933//===----------------------------------------------------------------------===//
934
935LogicalResult
936ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
937 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
938 if (failed(mesh)) {
939 return failure();
940 }
941 return success();
942}
943
944void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
945 OperationState &odsState, MeshOp mesh) {
946 build(odsBuilder, odsState, mesh.getSymName());
947}
948
949void ProcessLinearIndexOp::getAsmResultNames(
950 function_ref<void(Value, StringRef)> setNameFn) {
951 setNameFn(getResult(), "proc_linear_idx");
952}
953
954//===----------------------------------------------------------------------===//
955// mesh.neighbors_linear_indices op
956//===----------------------------------------------------------------------===//
957
958LogicalResult
959NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
960 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
961 if (failed(mesh)) {
962 return failure();
963 }
964 return success();
965}
966
967void NeighborsLinearIndicesOp::getAsmResultNames(
968 function_ref<void(Value, StringRef)> setNameFn) {
969 setNameFn(getNeighborDown(), "down_linear_idx");
970 setNameFn(getNeighborUp(), "up_linear_idx");
971}
972
973//===----------------------------------------------------------------------===//
974// collective communication ops
975//===----------------------------------------------------------------------===//
976
977namespace {
978
979template <typename Op>
980struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
981 using OpRewritePattern<Op>::OpRewritePattern;
982 LogicalResult matchAndRewrite(Op op,
983 PatternRewriter &rewriter) const override {
984 auto meshAxes = op.getMeshAxes();
985 if (!meshAxes.empty()) {
986 return failure();
987 }
988 if (op.getInput().getType() != op.getResult().getType()) {
989 return failure();
990 }
991
992 rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
993 rewriter.eraseOp(op: op.getOperation());
994 return success();
995 }
996};
997
998} // namespace
999
1000static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
1001 ArrayRef<int64_t> device,
1002 Operation::operand_range deviceDynamic,
1003 ArrayRef<MeshAxis> meshAxes,
1004 ArrayRef<int64_t> meshShape) {
1005 if (device.size() != meshAxes.size()) {
1006 return emitError(loc) << "In-group device \"" << deviceName
1007 << "\" has unexpected multi-index size "
1008 << device.size() << ". Expected " << meshAxes.size()
1009 << ".";
1010 }
1011
1012 for (size_t i = 0; i < device.size(); ++i) {
1013 if (!ShapedType::isDynamic(device[i]) &&
1014 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
1015 meshShape[meshAxes[i]] <= device[i]) {
1016 return emitError(loc)
1017 << "Out of bounds coordinate " << i << " for in-group device \""
1018 << deviceName << "\"."
1019 << " Got " << device[i] << ", but expected value in the range [0, "
1020 << (meshShape[meshAxes[i]] - 1) << "].";
1021 }
1022 }
1023 return success();
1024}
1025
/// Returns the product of the elements in `[begin, end)`, computed in the
/// decayed element type of the iterator. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using Elem = std::decay_t<decltype(*begin)>;
  const std::multiplies<Elem> times;
  const Elem one = static_cast<Elem>(1);
  return std::accumulate(begin, end, one, times);
}
1032
/// Range overload: multiplies all elements of `range` by delegating to the
/// iterator-pair overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
1037
1038static LogicalResult verifyDimensionCompatibility(Location loc,
1039 int64_t expectedDimSize,
1040 int64_t resultDimSize,
1041 int64_t resultAxis) {
1042 if (!ShapedType::isDynamic(resultDimSize) &&
1043 expectedDimSize != resultDimSize) {
1044 return emitError(loc) << "Dimension size mismatch for result axis "
1045 << resultAxis << ". Expected "
1046 << (ShapedType::isDynamic(expectedDimSize)
1047 ? Twine("dynamic")
1048 : Twine(expectedDimSize))
1049 << ", but got " << resultDimSize << ".";
1050 }
1051
1052 return success();
1053}
1054
1055static LogicalResult verifyGatherOperandAndResultShape(
1056 Value operand, Value result, int64_t gatherAxis,
1057 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1058 auto resultRank = cast<ShapedType>(result.getType()).getRank();
1059 if (gatherAxis < 0 || gatherAxis >= resultRank) {
1060 return emitError(loc: result.getLoc())
1061 << "Gather axis " << gatherAxis << " is out of bounds [0, "
1062 << resultRank << ").";
1063 }
1064
1065 ShapedType operandType = cast<ShapedType>(operand.getType());
1066 ShapedType resultType = cast<ShapedType>(result.getType());
1067 auto deviceGroupSize =
1068 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1069 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1070 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
1071 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
1072 auto expectedResultDimSize =
1073 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
1074 if (failed(verifyDimensionCompatibility(
1075 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
1076 return failure();
1077 }
1078 }
1079 return success();
1080}
1081
1082static LogicalResult verifyAllToAllOperandAndResultShape(
1083 Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
1084 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1085 ShapedType operandType = cast<ShapedType>(operand.getType());
1086 ShapedType resultType = cast<ShapedType>(result.getType());
1087 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1088 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
1089 if (failed(verifyDimensionCompatibility(
1090 result.getLoc(), operandType.getDimSize(axis),
1091 resultType.getDimSize(axis), axis))) {
1092 return failure();
1093 }
1094 }
1095 }
1096
1097 if (splitAxis == concatAxis) {
1098 return success();
1099 }
1100
1101 auto deviceGroupSize =
1102 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1103 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
1104 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
1105 DimensionSize expectedResultConcatDimSize =
1106 operandConcatDimSize * deviceGroupSize;
1107 DimensionSize expectedResultSplitDimSize =
1108 operandSplitDimSize / deviceGroupSize;
1109 if (!expectedResultSplitDimSize.isDynamic() &&
1110 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
1111 expectedResultSplitDimSize = DimensionSize::dynamic();
1112 }
1113 if (failed(verifyDimensionCompatibility(
1114 result.getLoc(), expectedResultConcatDimSize.value(),
1115 resultType.getDimSize(concatAxis), concatAxis))) {
1116 return failure();
1117 }
1118 if (failed(verifyDimensionCompatibility(
1119 result.getLoc(), expectedResultSplitDimSize.value(),
1120 resultType.getDimSize(splitAxis), splitAxis))) {
1121 return failure();
1122 }
1123
1124 return success();
1125}
1126
1127static LogicalResult verifyScatterOrSliceOperandAndResultShape(
1128 Value operand, Value result, int64_t tensorAxis,
1129 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1130 ShapedType operandType = cast<ShapedType>(operand.getType());
1131 ShapedType resultType = cast<ShapedType>(result.getType());
1132 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1133 if (axis != tensorAxis) {
1134 if (failed(verifyDimensionCompatibility(
1135 result.getLoc(), operandType.getDimSize(axis),
1136 resultType.getDimSize(axis), axis))) {
1137 return failure();
1138 }
1139 }
1140 }
1141
1142 auto deviceGroupSize =
1143 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1144 auto operandScatterDimSize =
1145 DimensionSize(operandType.getDimSize(tensorAxis));
1146 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
1147 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
1148 return emitError(loc: result.getLoc())
1149 << "Operand dimension size " << int64_t(operandScatterDimSize)
1150 << " is not divisible by collective device group size "
1151 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1152 << ".";
1153 }
1154 DimensionSize expectedResultTensorDimSize =
1155 operandScatterDimSize / deviceGroupSize;
1156 if (failed(verifyDimensionCompatibility(
1157 result.getLoc(), expectedResultTensorDimSize.value(),
1158 resultType.getDimSize(tensorAxis), tensorAxis))) {
1159 return failure();
1160 }
1161
1162 return success();
1163}
1164
1165static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
1166 ArrayRef<MeshAxis> meshAxes,
1167 int64_t sliceAxis) {
1168 RankedTensorType operandRankedTensorType =
1169 cast<RankedTensorType>(operandType);
1170 DimensionSize operandSliceAxisSize =
1171 operandRankedTensorType.getShape()[sliceAxis];
1172 SmallVector<int64_t> resultShape =
1173 llvm::to_vector(operandRankedTensorType.getShape());
1174
1175 resultShape[sliceAxis] =
1176 operandSliceAxisSize /
1177 DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1178 return operandRankedTensorType.clone(resultShape);
1179}
1180
1181//===----------------------------------------------------------------------===//
1182// mesh.all_gather op
1183//===----------------------------------------------------------------------===//
1184
1185LogicalResult
1186AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1187 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1188 if (failed(mesh)) {
1189 return failure();
1190 }
1191 auto gatherAxis = getGatherAxis().getSExtValue();
1192 return verifyGatherOperandAndResultShape(getOperand(), getResult(),
1193 gatherAxis, getMeshAxes(),
1194 mesh.value().getShape());
1195}
1196
1197void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1198 MLIRContext *context) {
1199 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
1200}
1201
1202void AllGatherOp::getAsmResultNames(
1203 function_ref<void(Value, StringRef)> setNameFn) {
1204 setNameFn(getResult(), "all_gather");
1205}
1206
1207//===----------------------------------------------------------------------===//
1208// mesh.all_reduce op
1209//===----------------------------------------------------------------------===//
1210
1211LogicalResult
1212AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1213 return getMeshAndVerifyAxes(*this, symbolTable);
1214}
1215
1216void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1217 MLIRContext *context) {
1218 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
1219}
1220
1221void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1222 Value input, StringRef mesh,
1223 ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
1224 build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1225 reduction);
1226}
1227
1228void AllReduceOp::getAsmResultNames(
1229 function_ref<void(Value, StringRef)> setNameFn) {
1230 setNameFn(getResult(), "all_reduce");
1231}
1232
1233//===----------------------------------------------------------------------===//
1234// mesh.all_slice op
1235//===----------------------------------------------------------------------===//
1236
1237LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1238 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1239 if (failed(mesh)) {
1240 return failure();
1241 }
1242 return verifyScatterOrSliceOperandAndResultShape(
1243 getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
1244 mesh.value().getShape());
1245}
1246
1247void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1248 MLIRContext *context) {
1249 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
1250}
1251
1252void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1253 Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1254 int64_t sliceAxis) {
1255 Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1256 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1257 sliceAxis);
1258}
1259
1260void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1261 Type resultType, Value input, StringRef mesh,
1262 ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1263 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1264 APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
1265}
1266
1267void AllSliceOp::getAsmResultNames(
1268 function_ref<void(Value, StringRef)> setNameFn) {
1269 setNameFn(getResult(), "all_slice");
1270}
1271
1272//===----------------------------------------------------------------------===//
1273// mesh.all_to_all op
1274//===----------------------------------------------------------------------===//
1275
1276LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1277 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1278 if (failed(mesh)) {
1279 return failure();
1280 }
1281
1282 return verifyAllToAllOperandAndResultShape(
1283 getOperand(), getResult(), getSplitAxis().getSExtValue(),
1284 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1285}
1286
1287void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1288 MLIRContext *context) {
1289 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
1290}
1291
1292void AllToAllOp::getAsmResultNames(
1293 function_ref<void(Value, StringRef)> setNameFn) {
1294 setNameFn(getResult(), "all_to_all");
1295}
1296
1297//===----------------------------------------------------------------------===//
1298// mesh.broadcast op
1299//===----------------------------------------------------------------------===//
1300
1301LogicalResult
1302BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1303 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1304 if (failed(mesh)) {
1305 return failure();
1306 }
1307 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1308 getRootDynamic(), getMeshAxes(),
1309 mesh.value().getShape()))) {
1310 return failure();
1311 }
1312
1313 return success();
1314}
1315
1316void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1317 MLIRContext *context) {
1318 patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
1319}
1320
1321void BroadcastOp::getAsmResultNames(
1322 function_ref<void(Value, StringRef)> setNameFn) {
1323 setNameFn(getResult(), "broadcast");
1324}
1325
1326//===----------------------------------------------------------------------===//
1327// mesh.gather op
1328//===----------------------------------------------------------------------===//
1329
1330LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1331 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1332 if (failed(mesh)) {
1333 return failure();
1334 }
1335 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1336 getRootDynamic(), getMeshAxes(),
1337 mesh.value().getShape()))) {
1338 return failure();
1339 }
1340
1341 auto gatherAxis = getGatherAxis().getSExtValue();
1342 return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
1343 getMeshAxes(),
1344 mesh.value().getShape());
1345}
1346
1347void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1348 MLIRContext *context) {
1349 patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
1350}
1351
1352void GatherOp::getAsmResultNames(
1353 function_ref<void(Value, StringRef)> setNameFn) {
1354 setNameFn(getResult(), "gather");
1355}
1356
1357//===----------------------------------------------------------------------===//
1358// mesh.recv op
1359//===----------------------------------------------------------------------===//
1360
1361LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1362 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1363 if (failed(mesh)) {
1364 return failure();
1365 }
1366 if (getSource() &&
1367 failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
1368 getSource().value(), getSourceDynamic(),
1369 getMeshAxes(), mesh.value().getShape()))) {
1370 return failure();
1371 }
1372 return success();
1373}
1374
1375void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1376 MLIRContext *context) {
1377 patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
1378}
1379
1380void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1381 setNameFn(getResult(), "recv");
1382}
1383
1384//===----------------------------------------------------------------------===//
1385// mesh.reduce op
1386//===----------------------------------------------------------------------===//
1387
1388LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1389 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1390 if (failed(mesh)) {
1391 return failure();
1392 }
1393 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1394 getRootDynamic(), getMeshAxes(),
1395 mesh.value().getShape()))) {
1396 return failure();
1397 }
1398
1399 return success();
1400}
1401
1402void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1403 MLIRContext *context) {
1404 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
1405}
1406
1407void ReduceOp::getAsmResultNames(
1408 function_ref<void(Value, StringRef)> setNameFn) {
1409 setNameFn(getResult(), "reduce");
1410}
1411
1412//===----------------------------------------------------------------------===//
1413// mesh.reduce_scatter op
1414//===----------------------------------------------------------------------===//
1415
1416LogicalResult
1417ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1418 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1419 if (failed(mesh)) {
1420 return failure();
1421 }
1422
1423 return verifyScatterOrSliceOperandAndResultShape(
1424 getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
1425 mesh.value().getShape());
1426}
1427
1428void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1429 MLIRContext *context) {
1430 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
1431}
1432
1433void ReduceScatterOp::getAsmResultNames(
1434 function_ref<void(Value, StringRef)> setNameFn) {
1435 setNameFn(getResult(), "reduce_scatter");
1436}
1437
1438//===----------------------------------------------------------------------===//
1439// mesh.scatter op
1440//===----------------------------------------------------------------------===//
1441
1442LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1443 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1444 if (failed(mesh)) {
1445 return failure();
1446 }
1447 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
1448 getRootDynamic(), getMeshAxes(),
1449 mesh.value().getShape()))) {
1450 return failure();
1451 }
1452
1453 auto scatterAxis = getScatterAxis().getSExtValue();
1454 return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
1455 scatterAxis, getMeshAxes(),
1456 mesh.value().getShape());
1457}
1458
1459void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1460 MLIRContext *context) {
1461 patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
1462}
1463
1464void ScatterOp::getAsmResultNames(
1465 function_ref<void(Value, StringRef)> setNameFn) {
1466 setNameFn(getResult(), "scatter");
1467}
1468
1469//===----------------------------------------------------------------------===//
1470// mesh.send op
1471//===----------------------------------------------------------------------===//
1472
1473LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1474 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1475 if (failed(mesh)) {
1476 return failure();
1477 }
1478 if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
1479 getDestination(), getDestinationDynamic(),
1480 getMeshAxes(), mesh.value().getShape()))) {
1481 return failure();
1482 }
1483 return success();
1484}
1485
1486void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1487 MLIRContext *context) {
1488 patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
1489}
1490
1491void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1492 setNameFn(getResult(), "send");
1493}
1494
1495//===----------------------------------------------------------------------===//
1496// mesh.shift op
1497//===----------------------------------------------------------------------===//
1498
1499LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1500 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1501 if (failed(mesh)) {
1502 return failure();
1503 }
1504
1505 auto meshAxes = getMeshAxes();
1506 auto shiftAxis = getShiftAxis().getZExtValue();
1507 if (!llvm::is_contained(meshAxes, shiftAxis)) {
1508 return emitError() << "Invalid shift axis " << shiftAxis
1509 << ". It must be one of the grouping mesh axes.";
1510 }
1511
1512 return success();
1513}
1514
1515void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1516 MLIRContext *context) {
1517 // TODO: remove op when offset is 0 or if it is a rotate with and
1518 // offset % shift_axis_mesh_dim_size == 0.
1519}
1520
1521void ShiftOp::getAsmResultNames(
1522 function_ref<void(Value, StringRef)> setNameFn) {
1523 setNameFn(getResult(), "shift");
1524}
1525
1526//===----------------------------------------------------------------------===//
1527// mesh.update_halo op
1528//===----------------------------------------------------------------------===//
1529
1530LogicalResult
1531UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1532 auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
1533 if (failed(mesh)) {
1534 return failure();
1535 }
1536
1537 return success();
1538}
1539
1540//===----------------------------------------------------------------------===//
1541// TableGen'd op method definitions
1542//===----------------------------------------------------------------------===//
1543
1544#define GET_OP_CLASSES
1545#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
1546
1547#define GET_ATTRDEF_CLASSES
1548#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
1549
1550#define GET_TYPEDEF_CLASSES
1551#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
1552
1553#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
1554

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Mesh/IR/MeshOps.cpp