1 | //===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/Mesh/IR/MeshDialect.h" |
13 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
14 | #include "mlir/IR/Attributes.h" |
15 | #include "mlir/IR/BuiltinAttributes.h" |
16 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/DialectImplementation.h" |
20 | #include "mlir/IR/Location.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | #include "mlir/IR/TypeUtilities.h" |
23 | #include "mlir/Interfaces/ViewLikeInterface.h" |
24 | #include "mlir/Support/LLVM.h" |
25 | #include "mlir/Support/LogicalResult.h" |
26 | #include "llvm/ADT/ArrayRef.h" |
27 | #include "llvm/ADT/STLExtras.h" |
28 | #include "llvm/ADT/SmallSet.h" |
29 | #include "llvm/ADT/SmallVector.h" |
30 | #include "llvm/ADT/TypeSwitch.h" |
31 | #include <algorithm> |
32 | #include <functional> |
33 | #include <iterator> |
34 | #include <numeric> |
35 | #include <optional> |
36 | #include <utility> |
37 | |
38 | #define DEBUG_TYPE "mesh-ops" |
39 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
40 | |
41 | using namespace mlir; |
42 | using namespace mlir::mesh; |
43 | |
44 | #include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc" |
45 | |
46 | namespace { |
47 | |
48 | struct DimensionSize { |
49 | static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); } |
50 | DimensionSize(int64_t val) : val(val) {} |
51 | int64_t value() const { return val; } |
52 | operator int64_t() const { return val; } |
53 | bool isDynamic() const { return ShapedType::isDynamic(val); } |
54 | |
55 | private: |
56 | int64_t val; |
57 | }; |
58 | |
59 | } // namespace |
60 | |
61 | static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) { |
62 | if (lhs.isDynamic() || rhs.isDynamic()) { |
63 | return DimensionSize::dynamic(); |
64 | } |
65 | return lhs.value() / rhs.value(); |
66 | } |
67 | |
68 | static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) { |
69 | if (lhs.isDynamic() || rhs.isDynamic()) { |
70 | return DimensionSize::dynamic(); |
71 | } |
72 | return lhs.value() * rhs.value(); |
73 | } |
74 | |
75 | //===----------------------------------------------------------------------===// |
76 | // Mesh dialect |
77 | //===----------------------------------------------------------------------===// |
78 | |
/// Registers the dialect's operations and attributes.
///
/// The template argument lists are supplied by the included `MeshOps.cpp.inc`
/// and `MeshAttributes.cpp.inc` files; `GET_OP_LIST` / `GET_ATTRDEF_LIST` are
/// defined immediately before each include.
void MeshDialect::initialize() {
  // Operation list comes from MeshOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
      >();
  // Attribute list comes from MeshAttributes.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
      >();
}
89 | |
90 | Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, |
91 | Type type, Location loc) { |
92 | return arith::ConstantOp::materialize(builder, value, type, loc); |
93 | } |
94 | |
95 | //===----------------------------------------------------------------------===// |
96 | // Mesh utilities |
97 | //===----------------------------------------------------------------------===// |
98 | |
99 | static FailureOr<MeshOp> getMeshAndVerify(Operation *op, |
100 | FlatSymbolRefAttr meshSymbol, |
101 | SymbolTableCollection &symbolTable) { |
102 | mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable); |
103 | if (!mesh) { |
104 | return op->emitError() << "Undefined required mesh symbol \"" |
105 | << meshSymbol.getValue() << "\"." ; |
106 | } |
107 | |
108 | return mesh; |
109 | } |
110 | |
/// Returns true when no two *adjacent* elements of `[begin, end)` compare
/// equal.
///
/// Only neighbouring elements are compared, so the range must be sorted (or at
/// least have equal elements grouped together) for the result to mean "all
/// elements are distinct". Empty and single-element ranges are trivially
/// unique.
template <typename It>
bool isUnique(It begin, It end) {
  // adjacent_find returns `end` exactly when no equal neighbouring pair exists;
  // it also covers the empty and single-element cases.
  return std::adjacent_find(begin, end) == end;
}
127 | |
128 | static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, |
129 | MeshOp mesh) { |
130 | SmallVector<MeshAxis> sorted = llvm::to_vector(Range&: axes); |
131 | llvm::sort(C&: sorted); |
132 | if (!isUnique(begin: sorted.begin(), end: sorted.end())) { |
133 | return emitError(loc) << "Mesh axes contains duplicate elements." ; |
134 | } |
135 | |
136 | MeshAxis rank = mesh.getRank(); |
137 | for (auto axis : axes) { |
138 | if (axis >= rank || axis < 0) { |
139 | return emitError(loc) |
140 | << "0-based mesh axis index " << axis |
141 | << " is out of bounds. The referenced mesh \"" << mesh.getSymName() |
142 | << "\" is of rank " << rank << "." ; |
143 | } |
144 | } |
145 | |
146 | return success(); |
147 | } |
148 | |
149 | template <typename InShape, typename MeshShape, typename SplitAxes, |
150 | typename OutShape> |
151 | static void shardShape(const InShape &inShape, const MeshShape &meshShape, |
152 | const SplitAxes &splitAxes, OutShape &outShape) { |
153 | std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape), |
154 | llvm::adl_begin(outShape)); |
155 | for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { |
156 | outShape[tensorAxis] = shardDimension( |
157 | inShape[tensorAxis], |
158 | collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape)); |
159 | } |
160 | } |
161 | |
162 | ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh, |
163 | MeshShardingAttr sharding) { |
164 | using Dim = std::decay_t<decltype(shape.getDimSize(0))>; |
165 | SmallVector<Dim> resShapeArr(shape.getShape().size()); |
166 | shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(), |
167 | resShapeArr); |
168 | return shape.clone(resShapeArr); |
169 | } |
170 | |
171 | Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) { |
172 | RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type); |
173 | if (rankedTensorType) { |
174 | return shardShapedType(rankedTensorType, mesh, sharding); |
175 | } |
176 | |
177 | assert(!sharding); |
178 | return type; |
179 | } |
180 | |
181 | //===----------------------------------------------------------------------===// |
182 | // mesh.mesh op |
183 | //===----------------------------------------------------------------------===// |
184 | |
185 | LogicalResult MeshOp::verify() { |
186 | int64_t rank = getRank(); |
187 | |
188 | if (rank <= 0) |
189 | return emitOpError("rank of mesh is expected to be a positive integer" ); |
190 | |
191 | for (int64_t dimSize : getShape()) { |
192 | if (dimSize < 0 && !ShapedType::isDynamic(dimSize)) |
193 | return emitOpError("dimension size of a mesh is expected to be " |
194 | "non-negative or dynamic" ); |
195 | } |
196 | |
197 | return success(); |
198 | } |
199 | |
200 | //===----------------------------------------------------------------------===// |
201 | // mesh.mesh_shape op |
202 | //===----------------------------------------------------------------------===// |
203 | |
204 | LogicalResult |
205 | MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
206 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
207 | if (failed(mesh)) { |
208 | return failure(); |
209 | } |
210 | if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { |
211 | return failure(); |
212 | } |
213 | |
214 | size_t expectedResultsCount = |
215 | getAxes().empty() ? mesh->getRank() : getAxes().size(); |
216 | if (getResult().size() != expectedResultsCount) { |
217 | return emitError() << "Unexpected number of results " << getResult().size() |
218 | << ". Expected " << expectedResultsCount << "." ; |
219 | } |
220 | |
221 | return success(); |
222 | } |
223 | |
224 | void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
225 | MeshOp mesh) { |
226 | build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>()); |
227 | } |
228 | |
229 | void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
230 | MeshOp mesh, ArrayRef<MeshAxis> axes) { |
231 | build(odsBuilder, odsState, |
232 | SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(), |
233 | odsBuilder.getIndexType()), |
234 | mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes)); |
235 | } |
236 | |
237 | void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
238 | StringRef mesh, ArrayRef<MeshAxis> axes) { |
239 | assert(!axes.empty()); |
240 | build(odsBuilder, odsState, |
241 | SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, |
242 | MeshAxesAttr::get(odsBuilder.getContext(), axes)); |
243 | } |
244 | |
245 | void MeshShapeOp::getAsmResultNames( |
246 | function_ref<void(Value, StringRef)> setNameFn) { |
247 | setNameFn(getResults()[0], "mesh_shape" ); |
248 | } |
249 | |
250 | //===----------------------------------------------------------------------===// |
251 | // mesh.shard attr |
252 | //===----------------------------------------------------------------------===// |
253 | |
254 | LogicalResult |
255 | MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
256 | FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes, |
257 | ArrayRef<MeshAxis> partialAxes, ReductionKind) { |
258 | // TODO: At present mesh symbol ref is not verified. This is due to the |
259 | // difficulty in fetching the corresponding symbol op based on an attribute. |
260 | |
261 | llvm::SmallSet<MeshAxis, 4> visitedAxes; |
262 | |
263 | auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult { |
264 | for (MeshAxis axis : axesArray) { |
265 | if (axis < 0) |
266 | return emitError() << "mesh axis is expected to be non-negative" ; |
267 | if (!visitedAxes.insert(axis).second) |
268 | return emitError() << "mesh axis duplicated" ; |
269 | } |
270 | return success(); |
271 | }; |
272 | |
273 | for (MeshAxesAttr subAxes : splitAxes) { |
274 | ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef(); |
275 | if (failed(checkMeshAxis(subAxesArray))) |
276 | return failure(); |
277 | } |
278 | if (failed(checkMeshAxis(partialAxes))) |
279 | return failure(); |
280 | return success(); |
281 | } |
282 | |
283 | bool MeshShardingAttr::operator==(Attribute rhs) const { |
284 | MeshShardingAttr rhsAsMeshShardingAttr = |
285 | mlir::dyn_cast<MeshShardingAttr>(rhs); |
286 | return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr; |
287 | } |
288 | |
289 | bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const { |
290 | if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) { |
291 | return false; |
292 | } |
293 | |
294 | if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) { |
295 | return false; |
296 | } |
297 | |
298 | auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size()); |
299 | if (!llvm::equal(llvm::make_range(getSplitAxes().begin(), |
300 | getSplitAxes().begin() + minSize), |
301 | llvm::make_range(rhs.getSplitAxes().begin(), |
302 | rhs.getSplitAxes().begin() + minSize))) { |
303 | return false; |
304 | } |
305 | |
306 | return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize, |
307 | getSplitAxes().end()), |
308 | std::mem_fn(&MeshAxesAttr::empty)) && |
309 | llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize, |
310 | rhs.getSplitAxes().end()), |
311 | std::mem_fn(&MeshAxesAttr::empty)); |
312 | } |
313 | |
314 | //===----------------------------------------------------------------------===// |
315 | // mesh.shard op |
316 | //===----------------------------------------------------------------------===// |
317 | |
318 | void ShardOp::getAsmResultNames( |
319 | function_ref<void(Value, StringRef)> setNameFn) { |
320 | setNameFn(getResult(), "sharding_annotated" ); |
321 | } |
322 | |
323 | //===----------------------------------------------------------------------===// |
324 | // mesh.process_multi_index op |
325 | //===----------------------------------------------------------------------===// |
326 | |
327 | LogicalResult |
328 | ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
329 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
330 | if (failed(mesh)) { |
331 | return failure(); |
332 | } |
333 | if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { |
334 | return failure(); |
335 | } |
336 | |
337 | size_t expectedResultsCount = |
338 | getAxes().empty() ? mesh->getRank() : getAxes().size(); |
339 | if (getResult().size() != expectedResultsCount) { |
340 | return emitError() << "Unexpected number of results " << getResult().size() |
341 | << ". Expected " << expectedResultsCount << "." ; |
342 | } |
343 | |
344 | return success(); |
345 | } |
346 | |
347 | void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
348 | MeshOp mesh) { |
349 | build(odsBuilder, odsState, |
350 | SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), |
351 | mesh.getSymName(), ArrayRef<MeshAxis>()); |
352 | } |
353 | |
354 | void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
355 | StringRef mesh, ArrayRef<MeshAxis> axes) { |
356 | build(odsBuilder, odsState, |
357 | SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, |
358 | MeshAxesAttr::get(odsBuilder.getContext(), axes)); |
359 | } |
360 | |
361 | void ProcessMultiIndexOp::getAsmResultNames( |
362 | function_ref<void(Value, StringRef)> setNameFn) { |
363 | setNameFn(getResults()[0], "proc_linear_idx" ); |
364 | } |
365 | |
366 | //===----------------------------------------------------------------------===// |
367 | // mesh.process_linear_index op |
368 | //===----------------------------------------------------------------------===// |
369 | |
370 | LogicalResult |
371 | ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
372 | auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); |
373 | if (failed(mesh)) { |
374 | return failure(); |
375 | } |
376 | return success(); |
377 | } |
378 | |
379 | void ProcessLinearIndexOp::build(OpBuilder &odsBuilder, |
380 | OperationState &odsState, MeshOp mesh) { |
381 | build(odsBuilder, odsState, mesh.getSymName()); |
382 | } |
383 | |
384 | void ProcessLinearIndexOp::getAsmResultNames( |
385 | function_ref<void(Value, StringRef)> setNameFn) { |
386 | setNameFn(getResult(), "proc_linear_idx" ); |
387 | } |
388 | |
389 | //===----------------------------------------------------------------------===// |
390 | // collective communication ops |
391 | //===----------------------------------------------------------------------===// |
392 | |
393 | namespace { |
394 | |
395 | template <typename Op> |
396 | struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> { |
397 | using OpRewritePattern<Op>::OpRewritePattern; |
398 | LogicalResult matchAndRewrite(Op op, |
399 | PatternRewriter &rewriter) const override { |
400 | auto meshAxes = op.getMeshAxes(); |
401 | if (!meshAxes.empty()) { |
402 | return failure(); |
403 | } |
404 | if (op.getInput().getType() != op.getResult().getType()) { |
405 | return failure(); |
406 | } |
407 | |
408 | rewriter.replaceAllUsesWith(op.getResult(), op.getInput()); |
409 | rewriter.eraseOp(op: op.getOperation()); |
410 | return success(); |
411 | } |
412 | }; |
413 | |
414 | } // namespace |
415 | |
416 | static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, |
417 | ArrayRef<int64_t> device, |
418 | Operation::operand_range deviceDynamic, |
419 | ArrayRef<MeshAxis> meshAxes, |
420 | ArrayRef<int64_t> meshShape) { |
421 | if (device.size() != meshAxes.size()) { |
422 | return emitError(loc) << "In-group device \"" << deviceName |
423 | << "\" has unexpected multi-index size " |
424 | << device.size() << ". Expected " << meshAxes.size() |
425 | << "." ; |
426 | } |
427 | |
428 | for (size_t i = 0; i < device.size(); ++i) { |
429 | if (!ShapedType::isDynamic(device[i]) && |
430 | !ShapedType::isDynamic(meshShape[meshAxes[i]]) && |
431 | meshShape[meshAxes[i]] <= device[i]) { |
432 | return emitError(loc) |
433 | << "Out of bounds coordinate " << i << " for in-group device \"" |
434 | << deviceName << "\"." |
435 | << " Got " << device[i] << ", but expected value in the range [0, " |
436 | << (meshShape[meshAxes[i]] - 1) << "]." ; |
437 | } |
438 | } |
439 | return success(); |
440 | } |
441 | |
442 | template <typename Op> |
443 | static FailureOr<MeshOp> |
444 | getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { |
445 | auto mesh = |
446 | ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable); |
447 | if (failed(mesh)) { |
448 | return failure(); |
449 | } |
450 | if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) { |
451 | return failure(); |
452 | } |
453 | return mesh; |
454 | } |
455 | |
/// Returns the product of all elements in `[begin, end)`, accumulated in the
/// (decayed) element type. An empty range yields 1.
template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  ElementType result = static_cast<ElementType>(1);
  for (; begin != end; ++begin) {
    result = result * *begin;
  }
  return result;
}
462 | |
/// Range overload: returns the product of all elements of `range` by
/// forwarding its begin/end iterators to the iterator-pair overload.
template <typename R>
static auto product(R &&range) {
  auto first = adl_begin(range);
  auto last = adl_end(range);
  return product(first, last);
}
467 | |
468 | static LogicalResult verifyDimensionCompatibility(Location loc, |
469 | int64_t expectedDimSize, |
470 | int64_t resultDimSize, |
471 | int64_t resultAxis) { |
472 | if (!ShapedType::isDynamic(resultDimSize) && |
473 | expectedDimSize != resultDimSize) { |
474 | return emitError(loc) << "Dimension size mismatch for result axis " |
475 | << resultAxis << ". Expected " |
476 | << (ShapedType::isDynamic(expectedDimSize) |
477 | ? Twine("dynamic" ) |
478 | : Twine(expectedDimSize)) |
479 | << ", but got " << resultDimSize << "." ; |
480 | } |
481 | |
482 | return success(); |
483 | } |
484 | |
485 | static LogicalResult verifyGatherOperandAndResultShape( |
486 | Value operand, Value result, int64_t gatherAxis, |
487 | ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { |
488 | auto resultRank = cast<ShapedType>(result.getType()).getRank(); |
489 | if (gatherAxis < 0 || gatherAxis >= resultRank) { |
490 | return emitError(loc: result.getLoc()) |
491 | << "Gather axis " << gatherAxis << " is out of bounds [0, " |
492 | << resultRank << ")." ; |
493 | } |
494 | |
495 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
496 | ShapedType resultType = cast<ShapedType>(result.getType()); |
497 | auto deviceGroupSize = |
498 | DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); |
499 | for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { |
500 | auto operandDimSize = DimensionSize(operandType.getDimSize(axis)); |
501 | auto resultDimSize = DimensionSize(resultType.getDimSize(axis)); |
502 | auto expectedResultDimSize = |
503 | axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize; |
504 | if (failed(verifyDimensionCompatibility( |
505 | result.getLoc(), expectedResultDimSize, resultDimSize, axis))) { |
506 | return failure(); |
507 | } |
508 | } |
509 | return success(); |
510 | } |
511 | |
512 | static LogicalResult verifyAllToAllOperandAndResultShape( |
513 | Value operand, Value result, int64_t splitAxis, int64_t concatAxis, |
514 | ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { |
515 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
516 | ShapedType resultType = cast<ShapedType>(result.getType()); |
517 | for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { |
518 | if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) { |
519 | if (failed(verifyDimensionCompatibility( |
520 | result.getLoc(), operandType.getDimSize(axis), |
521 | resultType.getDimSize(axis), axis))) { |
522 | return failure(); |
523 | } |
524 | } |
525 | } |
526 | |
527 | if (splitAxis == concatAxis) { |
528 | return success(); |
529 | } |
530 | |
531 | auto deviceGroupSize = |
532 | DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); |
533 | auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis)); |
534 | auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis)); |
535 | DimensionSize expectedResultConcatDimSize = |
536 | operandConcatDimSize * deviceGroupSize; |
537 | DimensionSize expectedResultSplitDimSize = |
538 | operandSplitDimSize / deviceGroupSize; |
539 | if (!expectedResultSplitDimSize.isDynamic() && |
540 | int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) { |
541 | expectedResultSplitDimSize = DimensionSize::dynamic(); |
542 | } |
543 | if (failed(verifyDimensionCompatibility( |
544 | result.getLoc(), expectedResultConcatDimSize.value(), |
545 | resultType.getDimSize(concatAxis), concatAxis))) { |
546 | return failure(); |
547 | } |
548 | if (failed(verifyDimensionCompatibility( |
549 | result.getLoc(), expectedResultSplitDimSize.value(), |
550 | resultType.getDimSize(splitAxis), splitAxis))) { |
551 | return failure(); |
552 | } |
553 | |
554 | return success(); |
555 | } |
556 | |
557 | static LogicalResult verifyScatterOrSliceOperandAndResultShape( |
558 | Value operand, Value result, int64_t tensorAxis, |
559 | ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { |
560 | ShapedType operandType = cast<ShapedType>(operand.getType()); |
561 | ShapedType resultType = cast<ShapedType>(result.getType()); |
562 | for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { |
563 | if (axis != tensorAxis) { |
564 | if (failed(verifyDimensionCompatibility( |
565 | result.getLoc(), operandType.getDimSize(axis), |
566 | resultType.getDimSize(axis), axis))) { |
567 | return failure(); |
568 | } |
569 | } |
570 | } |
571 | |
572 | auto deviceGroupSize = |
573 | DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); |
574 | auto operandScatterDimSize = |
575 | DimensionSize(operandType.getDimSize(tensorAxis)); |
576 | if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() && |
577 | int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) { |
578 | return emitError(loc: result.getLoc()) |
579 | << "Operand dimension size " << int64_t(operandScatterDimSize) |
580 | << " is not divisible by collective device group size " |
581 | << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis |
582 | << "." ; |
583 | } |
584 | DimensionSize expectedResultTensorDimSize = |
585 | operandScatterDimSize / deviceGroupSize; |
586 | if (failed(verifyDimensionCompatibility( |
587 | result.getLoc(), expectedResultTensorDimSize.value(), |
588 | resultType.getDimSize(tensorAxis), tensorAxis))) { |
589 | return failure(); |
590 | } |
591 | |
592 | return success(); |
593 | } |
594 | |
595 | static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, |
596 | ArrayRef<MeshAxis> meshAxes, |
597 | int64_t sliceAxis) { |
598 | RankedTensorType operandRankedTensorType = |
599 | cast<RankedTensorType>(operandType); |
600 | DimensionSize operandSliceAxisSize = |
601 | operandRankedTensorType.getShape()[sliceAxis]; |
602 | SmallVector<int64_t> resultShape = |
603 | llvm::to_vector(operandRankedTensorType.getShape()); |
604 | |
605 | resultShape[sliceAxis] = |
606 | operandSliceAxisSize / |
607 | DimensionSize(collectiveProcessGroupSize(meshAxes, mesh)); |
608 | return operandRankedTensorType.clone(resultShape); |
609 | } |
610 | |
611 | //===----------------------------------------------------------------------===// |
612 | // mesh.all_gather op |
613 | //===----------------------------------------------------------------------===// |
614 | |
615 | LogicalResult |
616 | AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
617 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
618 | if (failed(mesh)) { |
619 | return failure(); |
620 | } |
621 | auto gatherAxis = getGatherAxis().getSExtValue(); |
622 | return verifyGatherOperandAndResultShape(getOperand(), getResult(), |
623 | gatherAxis, getMeshAxes(), |
624 | mesh.value().getShape()); |
625 | } |
626 | |
627 | void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
628 | MLIRContext *context) { |
629 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context); |
630 | } |
631 | |
632 | void AllGatherOp::getAsmResultNames( |
633 | function_ref<void(Value, StringRef)> setNameFn) { |
634 | setNameFn(getResult(), "all_gather" ); |
635 | } |
636 | |
637 | //===----------------------------------------------------------------------===// |
638 | // mesh.all_reduce op |
639 | //===----------------------------------------------------------------------===// |
640 | |
641 | LogicalResult |
642 | AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
643 | return getMeshAndVerifyAxes(*this, symbolTable); |
644 | } |
645 | |
646 | void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
647 | MLIRContext *context) { |
648 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context); |
649 | } |
650 | |
651 | void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
652 | Value input, StringRef mesh, |
653 | ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) { |
654 | build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input, |
655 | reduction); |
656 | } |
657 | |
658 | void AllReduceOp::getAsmResultNames( |
659 | function_ref<void(Value, StringRef)> setNameFn) { |
660 | setNameFn(getResult(), "all_reduce" ); |
661 | } |
662 | |
663 | //===----------------------------------------------------------------------===// |
664 | // mesh.all_slice op |
665 | //===----------------------------------------------------------------------===// |
666 | |
667 | LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
668 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
669 | if (failed(mesh)) { |
670 | return failure(); |
671 | } |
672 | return verifyScatterOrSliceOperandAndResultShape( |
673 | getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(), |
674 | mesh.value().getShape()); |
675 | } |
676 | |
677 | void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
678 | MLIRContext *context) { |
679 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context); |
680 | } |
681 | |
682 | void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
683 | Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, |
684 | int64_t sliceAxis) { |
685 | Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis); |
686 | build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes, |
687 | sliceAxis); |
688 | } |
689 | |
690 | void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
691 | Type resultType, Value input, StringRef mesh, |
692 | ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) { |
693 | build(odsBuilder, odsState, resultType, mesh, meshAxes, input, |
694 | APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis)); |
695 | } |
696 | |
697 | void AllSliceOp::getAsmResultNames( |
698 | function_ref<void(Value, StringRef)> setNameFn) { |
699 | setNameFn(getResult(), "all_slice" ); |
700 | } |
701 | |
702 | //===----------------------------------------------------------------------===// |
703 | // mesh.all_to_all op |
704 | //===----------------------------------------------------------------------===// |
705 | |
706 | LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
707 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
708 | if (failed(mesh)) { |
709 | return failure(); |
710 | } |
711 | |
712 | return verifyAllToAllOperandAndResultShape( |
713 | getOperand(), getResult(), getSplitAxis().getSExtValue(), |
714 | getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape()); |
715 | } |
716 | |
717 | void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
718 | MLIRContext *context) { |
719 | patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context); |
720 | } |
721 | |
722 | void AllToAllOp::getAsmResultNames( |
723 | function_ref<void(Value, StringRef)> setNameFn) { |
724 | setNameFn(getResult(), "all_to_all" ); |
725 | } |
726 | |
727 | //===----------------------------------------------------------------------===// |
728 | // mesh.broadcast op |
729 | //===----------------------------------------------------------------------===// |
730 | |
731 | LogicalResult |
732 | BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
733 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
734 | if (failed(mesh)) { |
735 | return failure(); |
736 | } |
737 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
738 | getRootDynamic(), getMeshAxes(), |
739 | mesh.value().getShape()))) { |
740 | return failure(); |
741 | } |
742 | |
743 | return success(); |
744 | } |
745 | |
746 | void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
747 | MLIRContext *context) { |
748 | patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context); |
749 | } |
750 | |
751 | void BroadcastOp::getAsmResultNames( |
752 | function_ref<void(Value, StringRef)> setNameFn) { |
753 | setNameFn(getResult(), "broadcast" ); |
754 | } |
755 | |
756 | //===----------------------------------------------------------------------===// |
757 | // mesh.gather op |
758 | //===----------------------------------------------------------------------===// |
759 | |
760 | LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
761 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
762 | if (failed(mesh)) { |
763 | return failure(); |
764 | } |
765 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
766 | getRootDynamic(), getMeshAxes(), |
767 | mesh.value().getShape()))) { |
768 | return failure(); |
769 | } |
770 | |
771 | auto gatherAxis = getGatherAxis().getSExtValue(); |
772 | return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis, |
773 | getMeshAxes(), |
774 | mesh.value().getShape()); |
775 | } |
776 | |
777 | void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
778 | MLIRContext *context) { |
779 | patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context); |
780 | } |
781 | |
782 | void GatherOp::getAsmResultNames( |
783 | function_ref<void(Value, StringRef)> setNameFn) { |
784 | setNameFn(getResult(), "gather" ); |
785 | } |
786 | |
787 | //===----------------------------------------------------------------------===// |
788 | // mesh.recv op |
789 | //===----------------------------------------------------------------------===// |
790 | |
791 | LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
792 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
793 | if (failed(mesh)) { |
794 | return failure(); |
795 | } |
796 | if (getSource() && |
797 | failed(verifyInGroupDevice(getLoc(), getSourceAttrName(), |
798 | getSource().value(), getSourceDynamic(), |
799 | getMeshAxes(), mesh.value().getShape()))) { |
800 | return failure(); |
801 | } |
802 | return success(); |
803 | } |
804 | |
805 | void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
806 | MLIRContext *context) { |
807 | patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context); |
808 | } |
809 | |
810 | void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { |
811 | setNameFn(getResult(), "recv" ); |
812 | } |
813 | |
814 | //===----------------------------------------------------------------------===// |
815 | // mesh.reduce op |
816 | //===----------------------------------------------------------------------===// |
817 | |
818 | LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
819 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
820 | if (failed(mesh)) { |
821 | return failure(); |
822 | } |
823 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
824 | getRootDynamic(), getMeshAxes(), |
825 | mesh.value().getShape()))) { |
826 | return failure(); |
827 | } |
828 | |
829 | return success(); |
830 | } |
831 | |
832 | void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
833 | MLIRContext *context) { |
834 | patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context); |
835 | } |
836 | |
837 | void ReduceOp::getAsmResultNames( |
838 | function_ref<void(Value, StringRef)> setNameFn) { |
839 | setNameFn(getResult(), "reduce" ); |
840 | } |
841 | |
842 | //===----------------------------------------------------------------------===// |
843 | // mesh.reduce_scatter op |
844 | //===----------------------------------------------------------------------===// |
845 | |
846 | LogicalResult |
847 | ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
848 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
849 | if (failed(mesh)) { |
850 | return failure(); |
851 | } |
852 | |
853 | return verifyScatterOrSliceOperandAndResultShape( |
854 | getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(), |
855 | mesh.value().getShape()); |
856 | } |
857 | |
858 | void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
859 | MLIRContext *context) { |
860 | patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context); |
861 | } |
862 | |
863 | void ReduceScatterOp::getAsmResultNames( |
864 | function_ref<void(Value, StringRef)> setNameFn) { |
865 | setNameFn(getResult(), "reduce_scatter" ); |
866 | } |
867 | |
868 | //===----------------------------------------------------------------------===// |
869 | // mesh.scatter op |
870 | //===----------------------------------------------------------------------===// |
871 | |
872 | LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
873 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
874 | if (failed(mesh)) { |
875 | return failure(); |
876 | } |
877 | if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), |
878 | getRootDynamic(), getMeshAxes(), |
879 | mesh.value().getShape()))) { |
880 | return failure(); |
881 | } |
882 | |
883 | auto scatterAxis = getScatterAxis().getSExtValue(); |
884 | return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(), |
885 | scatterAxis, getMeshAxes(), |
886 | mesh.value().getShape()); |
887 | } |
888 | |
889 | void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
890 | MLIRContext *context) { |
891 | patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context); |
892 | } |
893 | |
894 | void ScatterOp::getAsmResultNames( |
895 | function_ref<void(Value, StringRef)> setNameFn) { |
896 | setNameFn(getResult(), "scatter" ); |
897 | } |
898 | |
899 | //===----------------------------------------------------------------------===// |
900 | // mesh.send op |
901 | //===----------------------------------------------------------------------===// |
902 | |
903 | LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
904 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
905 | if (failed(mesh)) { |
906 | return failure(); |
907 | } |
908 | if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(), |
909 | getDestination(), getDestinationDynamic(), |
910 | getMeshAxes(), mesh.value().getShape()))) { |
911 | return failure(); |
912 | } |
913 | return success(); |
914 | } |
915 | |
916 | void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
917 | MLIRContext *context) { |
918 | patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context); |
919 | } |
920 | |
921 | void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { |
922 | setNameFn(getResult(), "send" ); |
923 | } |
924 | |
925 | //===----------------------------------------------------------------------===// |
926 | // mesh.shift op |
927 | //===----------------------------------------------------------------------===// |
928 | |
929 | LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
930 | auto mesh = getMeshAndVerifyAxes(*this, symbolTable); |
931 | if (failed(mesh)) { |
932 | return failure(); |
933 | } |
934 | |
935 | auto meshAxes = getMeshAxes(); |
936 | auto shiftAxis = getShiftAxis().getZExtValue(); |
937 | if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) { |
938 | return emitError() << "Invalid shift axis " << shiftAxis |
939 | << ". It must be one of the grouping mesh axes." ; |
940 | } |
941 | |
942 | return success(); |
943 | } |
944 | |
945 | void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
946 | MLIRContext *context) { |
947 | // TODO: remove op when offset is 0 or if it is a rotate with and |
948 | // offset % shift_axis_mesh_dim_size == 0. |
949 | } |
950 | |
951 | void ShiftOp::getAsmResultNames( |
952 | function_ref<void(Value, StringRef)> setNameFn) { |
953 | setNameFn(getResult(), "shift" ); |
954 | } |
955 | |
956 | //===----------------------------------------------------------------------===// |
957 | // TableGen'd op method definitions |
958 | //===----------------------------------------------------------------------===// |
959 | |
960 | #define GET_OP_CLASSES |
961 | #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" |
962 | |
963 | #define GET_ATTRDEF_CLASSES |
964 | #include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" |
965 | |
966 | #include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc" |
967 | |