1//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/IR/MeshOps.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
13#include "mlir/Dialect/Utils/StaticValueUtils.h"
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/BuiltinAttributes.h"
16#include "mlir/IR/BuiltinTypeInterfaces.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/IR/Location.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/IR/TypeUtilities.h"
23#include "mlir/Interfaces/ViewLikeInterface.h"
24#include "mlir/Support/LLVM.h"
25#include "mlir/Support/LogicalResult.h"
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallSet.h"
29#include "llvm/ADT/SmallVector.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include <algorithm>
32#include <functional>
33#include <iterator>
34#include <numeric>
35#include <optional>
36#include <utility>
37
38#define DEBUG_TYPE "mesh-ops"
39#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
40
41using namespace mlir;
42using namespace mlir::mesh;
43
44#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
45
46namespace {
47
48struct DimensionSize {
49 static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
50 DimensionSize(int64_t val) : val(val) {}
51 int64_t value() const { return val; }
52 operator int64_t() const { return val; }
53 bool isDynamic() const { return ShapedType::isDynamic(val); }
54
55private:
56 int64_t val;
57};
58
59} // namespace
60
61static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
62 if (lhs.isDynamic() || rhs.isDynamic()) {
63 return DimensionSize::dynamic();
64 }
65 return lhs.value() / rhs.value();
66}
67
68static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
69 if (lhs.isDynamic() || rhs.isDynamic()) {
70 return DimensionSize::dynamic();
71 }
72 return lhs.value() * rhs.value();
73}
74
75//===----------------------------------------------------------------------===//
76// Mesh dialect
77//===----------------------------------------------------------------------===//
78
/// Registers the mesh dialect's operations and attributes. The type lists are
/// produced by TableGen and spliced in from the GET_OP_LIST and
/// GET_ATTRDEF_LIST sections of the generated .inc files.
void MeshDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
}
89
90Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
91 Type type, Location loc) {
92 return arith::ConstantOp::materialize(builder, value, type, loc);
93}
94
95//===----------------------------------------------------------------------===//
96// Mesh utilities
97//===----------------------------------------------------------------------===//
98
99static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
100 FlatSymbolRefAttr meshSymbol,
101 SymbolTableCollection &symbolTable) {
102 mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
103 if (!mesh) {
104 return op->emitError() << "Undefined required mesh symbol \""
105 << meshSymbol.getValue() << "\".";
106 }
107
108 return mesh;
109}
110
/// Returns true if no two neighbouring elements of [begin, end) compare equal
/// with `==`. For a sorted range this means all elements are distinct; empty
/// and single-element ranges are trivially unique.
template <typename It>
bool isUnique(It begin, It end) {
  return std::adjacent_find(begin, end) == end;
}
127
128static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
129 MeshOp mesh) {
130 SmallVector<MeshAxis> sorted = llvm::to_vector(Range&: axes);
131 llvm::sort(C&: sorted);
132 if (!isUnique(begin: sorted.begin(), end: sorted.end())) {
133 return emitError(loc) << "Mesh axes contains duplicate elements.";
134 }
135
136 MeshAxis rank = mesh.getRank();
137 for (auto axis : axes) {
138 if (axis >= rank || axis < 0) {
139 return emitError(loc)
140 << "0-based mesh axis index " << axis
141 << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
142 << "\" is of rank " << rank << ".";
143 }
144 }
145
146 return success();
147}
148
149template <typename InShape, typename MeshShape, typename SplitAxes,
150 typename OutShape>
151static void shardShape(const InShape &inShape, const MeshShape &meshShape,
152 const SplitAxes &splitAxes, OutShape &outShape) {
153 std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
154 llvm::adl_begin(outShape));
155 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
156 outShape[tensorAxis] = shardDimension(
157 inShape[tensorAxis],
158 collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
159 }
160}
161
162ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
163 MeshShardingAttr sharding) {
164 using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
165 SmallVector<Dim> resShapeArr(shape.getShape().size());
166 shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
167 resShapeArr);
168 return shape.clone(resShapeArr);
169}
170
171Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
172 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
173 if (rankedTensorType) {
174 return shardShapedType(rankedTensorType, mesh, sharding);
175 }
176
177 assert(!sharding);
178 return type;
179}
180
181//===----------------------------------------------------------------------===//
182// mesh.mesh op
183//===----------------------------------------------------------------------===//
184
185LogicalResult MeshOp::verify() {
186 int64_t rank = getRank();
187
188 if (rank <= 0)
189 return emitOpError("rank of mesh is expected to be a positive integer");
190
191 for (int64_t dimSize : getShape()) {
192 if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
193 return emitOpError("dimension size of a mesh is expected to be "
194 "non-negative or dynamic");
195 }
196
197 return success();
198}
199
200//===----------------------------------------------------------------------===//
201// mesh.mesh_shape op
202//===----------------------------------------------------------------------===//
203
204LogicalResult
205MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
206 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
207 if (failed(mesh)) {
208 return failure();
209 }
210 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
211 return failure();
212 }
213
214 size_t expectedResultsCount =
215 getAxes().empty() ? mesh->getRank() : getAxes().size();
216 if (getResult().size() != expectedResultsCount) {
217 return emitError() << "Unexpected number of results " << getResult().size()
218 << ". Expected " << expectedResultsCount << ".";
219 }
220
221 return success();
222}
223
224void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
225 MeshOp mesh) {
226 build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
227}
228
229void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
230 MeshOp mesh, ArrayRef<MeshAxis> axes) {
231 build(odsBuilder, odsState,
232 SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
233 odsBuilder.getIndexType()),
234 mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
235}
236
237void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
238 StringRef mesh, ArrayRef<MeshAxis> axes) {
239 assert(!axes.empty());
240 build(odsBuilder, odsState,
241 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
242 MeshAxesAttr::get(odsBuilder.getContext(), axes));
243}
244
245void MeshShapeOp::getAsmResultNames(
246 function_ref<void(Value, StringRef)> setNameFn) {
247 setNameFn(getResults()[0], "mesh_shape");
248}
249
250//===----------------------------------------------------------------------===//
251// mesh.shard attr
252//===----------------------------------------------------------------------===//
253
254LogicalResult
255MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
256 FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
257 ArrayRef<MeshAxis> partialAxes, ReductionKind) {
258 // TODO: At present mesh symbol ref is not verified. This is due to the
259 // difficulty in fetching the corresponding symbol op based on an attribute.
260
261 llvm::SmallSet<MeshAxis, 4> visitedAxes;
262
263 auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
264 for (MeshAxis axis : axesArray) {
265 if (axis < 0)
266 return emitError() << "mesh axis is expected to be non-negative";
267 if (!visitedAxes.insert(axis).second)
268 return emitError() << "mesh axis duplicated";
269 }
270 return success();
271 };
272
273 for (MeshAxesAttr subAxes : splitAxes) {
274 ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
275 if (failed(checkMeshAxis(subAxesArray)))
276 return failure();
277 }
278 if (failed(checkMeshAxis(partialAxes)))
279 return failure();
280 return success();
281}
282
283bool MeshShardingAttr::operator==(Attribute rhs) const {
284 MeshShardingAttr rhsAsMeshShardingAttr =
285 mlir::dyn_cast<MeshShardingAttr>(rhs);
286 return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
287}
288
289bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
290 if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
291 return false;
292 }
293
294 if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
295 return false;
296 }
297
298 auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
299 if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
300 getSplitAxes().begin() + minSize),
301 llvm::make_range(rhs.getSplitAxes().begin(),
302 rhs.getSplitAxes().begin() + minSize))) {
303 return false;
304 }
305
306 return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
307 getSplitAxes().end()),
308 std::mem_fn(&MeshAxesAttr::empty)) &&
309 llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
310 rhs.getSplitAxes().end()),
311 std::mem_fn(&MeshAxesAttr::empty));
312}
313
314//===----------------------------------------------------------------------===//
315// mesh.shard op
316//===----------------------------------------------------------------------===//
317
318void ShardOp::getAsmResultNames(
319 function_ref<void(Value, StringRef)> setNameFn) {
320 setNameFn(getResult(), "sharding_annotated");
321}
322
323//===----------------------------------------------------------------------===//
324// mesh.process_multi_index op
325//===----------------------------------------------------------------------===//
326
327LogicalResult
328ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
329 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
330 if (failed(mesh)) {
331 return failure();
332 }
333 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
334 return failure();
335 }
336
337 size_t expectedResultsCount =
338 getAxes().empty() ? mesh->getRank() : getAxes().size();
339 if (getResult().size() != expectedResultsCount) {
340 return emitError() << "Unexpected number of results " << getResult().size()
341 << ". Expected " << expectedResultsCount << ".";
342 }
343
344 return success();
345}
346
347void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
348 MeshOp mesh) {
349 build(odsBuilder, odsState,
350 SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
351 mesh.getSymName(), ArrayRef<MeshAxis>());
352}
353
354void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
355 StringRef mesh, ArrayRef<MeshAxis> axes) {
356 build(odsBuilder, odsState,
357 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
358 MeshAxesAttr::get(odsBuilder.getContext(), axes));
359}
360
361void ProcessMultiIndexOp::getAsmResultNames(
362 function_ref<void(Value, StringRef)> setNameFn) {
363 setNameFn(getResults()[0], "proc_linear_idx");
364}
365
366//===----------------------------------------------------------------------===//
367// mesh.process_linear_index op
368//===----------------------------------------------------------------------===//
369
370LogicalResult
371ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
372 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
373 if (failed(mesh)) {
374 return failure();
375 }
376 return success();
377}
378
379void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
380 OperationState &odsState, MeshOp mesh) {
381 build(odsBuilder, odsState, mesh.getSymName());
382}
383
384void ProcessLinearIndexOp::getAsmResultNames(
385 function_ref<void(Value, StringRef)> setNameFn) {
386 setNameFn(getResult(), "proc_linear_idx");
387}
388
389//===----------------------------------------------------------------------===//
390// collective communication ops
391//===----------------------------------------------------------------------===//
392
393namespace {
394
395template <typename Op>
396struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
397 using OpRewritePattern<Op>::OpRewritePattern;
398 LogicalResult matchAndRewrite(Op op,
399 PatternRewriter &rewriter) const override {
400 auto meshAxes = op.getMeshAxes();
401 if (!meshAxes.empty()) {
402 return failure();
403 }
404 if (op.getInput().getType() != op.getResult().getType()) {
405 return failure();
406 }
407
408 rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
409 rewriter.eraseOp(op: op.getOperation());
410 return success();
411 }
412};
413
414} // namespace
415
416static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
417 ArrayRef<int64_t> device,
418 Operation::operand_range deviceDynamic,
419 ArrayRef<MeshAxis> meshAxes,
420 ArrayRef<int64_t> meshShape) {
421 if (device.size() != meshAxes.size()) {
422 return emitError(loc) << "In-group device \"" << deviceName
423 << "\" has unexpected multi-index size "
424 << device.size() << ". Expected " << meshAxes.size()
425 << ".";
426 }
427
428 for (size_t i = 0; i < device.size(); ++i) {
429 if (!ShapedType::isDynamic(device[i]) &&
430 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
431 meshShape[meshAxes[i]] <= device[i]) {
432 return emitError(loc)
433 << "Out of bounds coordinate " << i << " for in-group device \""
434 << deviceName << "\"."
435 << " Got " << device[i] << ", but expected value in the range [0, "
436 << (meshShape[meshAxes[i]] - 1) << "].";
437 }
438 }
439 return success();
440}
441
442template <typename Op>
443static FailureOr<MeshOp>
444getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
445 auto mesh =
446 ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
447 if (failed(mesh)) {
448 return failure();
449 }
450 if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
451 return failure();
452 }
453 return mesh;
454}
455
/// Multiplies together the elements of [begin, end). The result has the
/// decayed element type, and an empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (It it = begin; it != end; ++it)
    result = std::multiplies<ElementType>()(result, *it);
  return result;
}

/// Range overload of `product`.
template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}
467
468static LogicalResult verifyDimensionCompatibility(Location loc,
469 int64_t expectedDimSize,
470 int64_t resultDimSize,
471 int64_t resultAxis) {
472 if (!ShapedType::isDynamic(resultDimSize) &&
473 expectedDimSize != resultDimSize) {
474 return emitError(loc) << "Dimension size mismatch for result axis "
475 << resultAxis << ". Expected "
476 << (ShapedType::isDynamic(expectedDimSize)
477 ? Twine("dynamic")
478 : Twine(expectedDimSize))
479 << ", but got " << resultDimSize << ".";
480 }
481
482 return success();
483}
484
485static LogicalResult verifyGatherOperandAndResultShape(
486 Value operand, Value result, int64_t gatherAxis,
487 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
488 auto resultRank = cast<ShapedType>(result.getType()).getRank();
489 if (gatherAxis < 0 || gatherAxis >= resultRank) {
490 return emitError(loc: result.getLoc())
491 << "Gather axis " << gatherAxis << " is out of bounds [0, "
492 << resultRank << ").";
493 }
494
495 ShapedType operandType = cast<ShapedType>(operand.getType());
496 ShapedType resultType = cast<ShapedType>(result.getType());
497 auto deviceGroupSize =
498 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
499 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
500 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
501 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
502 auto expectedResultDimSize =
503 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
504 if (failed(verifyDimensionCompatibility(
505 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
506 return failure();
507 }
508 }
509 return success();
510}
511
512static LogicalResult verifyAllToAllOperandAndResultShape(
513 Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
514 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
515 ShapedType operandType = cast<ShapedType>(operand.getType());
516 ShapedType resultType = cast<ShapedType>(result.getType());
517 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
518 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
519 if (failed(verifyDimensionCompatibility(
520 result.getLoc(), operandType.getDimSize(axis),
521 resultType.getDimSize(axis), axis))) {
522 return failure();
523 }
524 }
525 }
526
527 if (splitAxis == concatAxis) {
528 return success();
529 }
530
531 auto deviceGroupSize =
532 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
533 auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
534 auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
535 DimensionSize expectedResultConcatDimSize =
536 operandConcatDimSize * deviceGroupSize;
537 DimensionSize expectedResultSplitDimSize =
538 operandSplitDimSize / deviceGroupSize;
539 if (!expectedResultSplitDimSize.isDynamic() &&
540 int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
541 expectedResultSplitDimSize = DimensionSize::dynamic();
542 }
543 if (failed(verifyDimensionCompatibility(
544 result.getLoc(), expectedResultConcatDimSize.value(),
545 resultType.getDimSize(concatAxis), concatAxis))) {
546 return failure();
547 }
548 if (failed(verifyDimensionCompatibility(
549 result.getLoc(), expectedResultSplitDimSize.value(),
550 resultType.getDimSize(splitAxis), splitAxis))) {
551 return failure();
552 }
553
554 return success();
555}
556
557static LogicalResult verifyScatterOrSliceOperandAndResultShape(
558 Value operand, Value result, int64_t tensorAxis,
559 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
560 ShapedType operandType = cast<ShapedType>(operand.getType());
561 ShapedType resultType = cast<ShapedType>(result.getType());
562 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
563 if (axis != tensorAxis) {
564 if (failed(verifyDimensionCompatibility(
565 result.getLoc(), operandType.getDimSize(axis),
566 resultType.getDimSize(axis), axis))) {
567 return failure();
568 }
569 }
570 }
571
572 auto deviceGroupSize =
573 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
574 auto operandScatterDimSize =
575 DimensionSize(operandType.getDimSize(tensorAxis));
576 if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
577 int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
578 return emitError(loc: result.getLoc())
579 << "Operand dimension size " << int64_t(operandScatterDimSize)
580 << " is not divisible by collective device group size "
581 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
582 << ".";
583 }
584 DimensionSize expectedResultTensorDimSize =
585 operandScatterDimSize / deviceGroupSize;
586 if (failed(verifyDimensionCompatibility(
587 result.getLoc(), expectedResultTensorDimSize.value(),
588 resultType.getDimSize(tensorAxis), tensorAxis))) {
589 return failure();
590 }
591
592 return success();
593}
594
595static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
596 ArrayRef<MeshAxis> meshAxes,
597 int64_t sliceAxis) {
598 RankedTensorType operandRankedTensorType =
599 cast<RankedTensorType>(operandType);
600 DimensionSize operandSliceAxisSize =
601 operandRankedTensorType.getShape()[sliceAxis];
602 SmallVector<int64_t> resultShape =
603 llvm::to_vector(operandRankedTensorType.getShape());
604
605 resultShape[sliceAxis] =
606 operandSliceAxisSize /
607 DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
608 return operandRankedTensorType.clone(resultShape);
609}
610
611//===----------------------------------------------------------------------===//
612// mesh.all_gather op
613//===----------------------------------------------------------------------===//
614
615LogicalResult
616AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
617 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
618 if (failed(mesh)) {
619 return failure();
620 }
621 auto gatherAxis = getGatherAxis().getSExtValue();
622 return verifyGatherOperandAndResultShape(getOperand(), getResult(),
623 gatherAxis, getMeshAxes(),
624 mesh.value().getShape());
625}
626
627void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
628 MLIRContext *context) {
629 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
630}
631
632void AllGatherOp::getAsmResultNames(
633 function_ref<void(Value, StringRef)> setNameFn) {
634 setNameFn(getResult(), "all_gather");
635}
636
637//===----------------------------------------------------------------------===//
638// mesh.all_reduce op
639//===----------------------------------------------------------------------===//
640
641LogicalResult
642AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
643 return getMeshAndVerifyAxes(*this, symbolTable);
644}
645
646void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
647 MLIRContext *context) {
648 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
649}
650
651void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
652 Value input, StringRef mesh,
653 ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
654 build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
655 reduction);
656}
657
658void AllReduceOp::getAsmResultNames(
659 function_ref<void(Value, StringRef)> setNameFn) {
660 setNameFn(getResult(), "all_reduce");
661}
662
663//===----------------------------------------------------------------------===//
664// mesh.all_slice op
665//===----------------------------------------------------------------------===//
666
667LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
668 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
669 if (failed(mesh)) {
670 return failure();
671 }
672 return verifyScatterOrSliceOperandAndResultShape(
673 getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
674 mesh.value().getShape());
675}
676
677void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
678 MLIRContext *context) {
679 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
680}
681
682void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
683 Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
684 int64_t sliceAxis) {
685 Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
686 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
687 sliceAxis);
688}
689
690void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
691 Type resultType, Value input, StringRef mesh,
692 ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
693 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
694 APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
695}
696
697void AllSliceOp::getAsmResultNames(
698 function_ref<void(Value, StringRef)> setNameFn) {
699 setNameFn(getResult(), "all_slice");
700}
701
702//===----------------------------------------------------------------------===//
703// mesh.all_to_all op
704//===----------------------------------------------------------------------===//
705
706LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
707 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
708 if (failed(mesh)) {
709 return failure();
710 }
711
712 return verifyAllToAllOperandAndResultShape(
713 getOperand(), getResult(), getSplitAxis().getSExtValue(),
714 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
715}
716
717void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
718 MLIRContext *context) {
719 patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
720}
721
722void AllToAllOp::getAsmResultNames(
723 function_ref<void(Value, StringRef)> setNameFn) {
724 setNameFn(getResult(), "all_to_all");
725}
726
727//===----------------------------------------------------------------------===//
728// mesh.broadcast op
729//===----------------------------------------------------------------------===//
730
731LogicalResult
732BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
733 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
734 if (failed(mesh)) {
735 return failure();
736 }
737 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
738 getRootDynamic(), getMeshAxes(),
739 mesh.value().getShape()))) {
740 return failure();
741 }
742
743 return success();
744}
745
746void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
747 MLIRContext *context) {
748 patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
749}
750
751void BroadcastOp::getAsmResultNames(
752 function_ref<void(Value, StringRef)> setNameFn) {
753 setNameFn(getResult(), "broadcast");
754}
755
756//===----------------------------------------------------------------------===//
757// mesh.gather op
758//===----------------------------------------------------------------------===//
759
760LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
761 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
762 if (failed(mesh)) {
763 return failure();
764 }
765 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
766 getRootDynamic(), getMeshAxes(),
767 mesh.value().getShape()))) {
768 return failure();
769 }
770
771 auto gatherAxis = getGatherAxis().getSExtValue();
772 return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
773 getMeshAxes(),
774 mesh.value().getShape());
775}
776
777void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
778 MLIRContext *context) {
779 patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
780}
781
782void GatherOp::getAsmResultNames(
783 function_ref<void(Value, StringRef)> setNameFn) {
784 setNameFn(getResult(), "gather");
785}
786
787//===----------------------------------------------------------------------===//
788// mesh.recv op
789//===----------------------------------------------------------------------===//
790
791LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
792 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
793 if (failed(mesh)) {
794 return failure();
795 }
796 if (getSource() &&
797 failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
798 getSource().value(), getSourceDynamic(),
799 getMeshAxes(), mesh.value().getShape()))) {
800 return failure();
801 }
802 return success();
803}
804
805void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
806 MLIRContext *context) {
807 patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
808}
809
810void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
811 setNameFn(getResult(), "recv");
812}
813
814//===----------------------------------------------------------------------===//
815// mesh.reduce op
816//===----------------------------------------------------------------------===//
817
818LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
819 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
820 if (failed(mesh)) {
821 return failure();
822 }
823 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
824 getRootDynamic(), getMeshAxes(),
825 mesh.value().getShape()))) {
826 return failure();
827 }
828
829 return success();
830}
831
832void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
833 MLIRContext *context) {
834 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
835}
836
837void ReduceOp::getAsmResultNames(
838 function_ref<void(Value, StringRef)> setNameFn) {
839 setNameFn(getResult(), "reduce");
840}
841
842//===----------------------------------------------------------------------===//
843// mesh.reduce_scatter op
844//===----------------------------------------------------------------------===//
845
846LogicalResult
847ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
848 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
849 if (failed(mesh)) {
850 return failure();
851 }
852
853 return verifyScatterOrSliceOperandAndResultShape(
854 getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
855 mesh.value().getShape());
856}
857
858void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
859 MLIRContext *context) {
860 patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
861}
862
863void ReduceScatterOp::getAsmResultNames(
864 function_ref<void(Value, StringRef)> setNameFn) {
865 setNameFn(getResult(), "reduce_scatter");
866}
867
868//===----------------------------------------------------------------------===//
869// mesh.scatter op
870//===----------------------------------------------------------------------===//
871
872LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
873 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
874 if (failed(mesh)) {
875 return failure();
876 }
877 if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
878 getRootDynamic(), getMeshAxes(),
879 mesh.value().getShape()))) {
880 return failure();
881 }
882
883 auto scatterAxis = getScatterAxis().getSExtValue();
884 return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
885 scatterAxis, getMeshAxes(),
886 mesh.value().getShape());
887}
888
889void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
890 MLIRContext *context) {
891 patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
892}
893
894void ScatterOp::getAsmResultNames(
895 function_ref<void(Value, StringRef)> setNameFn) {
896 setNameFn(getResult(), "scatter");
897}
898
899//===----------------------------------------------------------------------===//
900// mesh.send op
901//===----------------------------------------------------------------------===//
902
903LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
904 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
905 if (failed(mesh)) {
906 return failure();
907 }
908 if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
909 getDestination(), getDestinationDynamic(),
910 getMeshAxes(), mesh.value().getShape()))) {
911 return failure();
912 }
913 return success();
914}
915
916void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
917 MLIRContext *context) {
918 patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
919}
920
921void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
922 setNameFn(getResult(), "send");
923}
924
925//===----------------------------------------------------------------------===//
926// mesh.shift op
927//===----------------------------------------------------------------------===//
928
929LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
930 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
931 if (failed(mesh)) {
932 return failure();
933 }
934
935 auto meshAxes = getMeshAxes();
936 auto shiftAxis = getShiftAxis().getZExtValue();
937 if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
938 return emitError() << "Invalid shift axis " << shiftAxis
939 << ". It must be one of the grouping mesh axes.";
940 }
941
942 return success();
943}
944
945void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
946 MLIRContext *context) {
947 // TODO: remove op when offset is 0 or if it is a rotate with and
948 // offset % shift_axis_mesh_dim_size == 0.
949}
950
951void ShiftOp::getAsmResultNames(
952 function_ref<void(Value, StringRef)> setNameFn) {
953 setNameFn(getResult(), "shift");
954}
955
956//===----------------------------------------------------------------------===//
957// TableGen'd op method definitions
958//===----------------------------------------------------------------------===//
959
960#define GET_OP_CLASSES
961#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
962
963#define GET_ATTRDEF_CLASSES
964#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
965
966#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
967

// Source: mlir/lib/Dialect/Mesh/IR/MeshOps.cpp