//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/Mesh/IR/MeshOps.h"
11#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
12#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
13#include "mlir/IR/BuiltinDialect.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/BuiltinTypeInterfaces.h"
16#include "mlir/IR/Diagnostics.h"
17#include "mlir/IR/ImplicitLocOpBuilder.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/IR/Value.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Support/LogicalResult.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24
25using namespace mlir;
26using namespace mlir::mesh;
27
28namespace {
29
30struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
31 using OpRewritePattern<ShardOp>::OpRewritePattern;
32
33 LogicalResult matchAndRewrite(ShardOp op,
34 PatternRewriter &rewriter) const override {
35 if (op.getAnnotateForUsers()) {
36 return failure();
37 }
38
39 SymbolTableCollection symbolTable;
40 mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
41 op, op.getShard().getMesh());
42
43 bool foundUser = false;
44 for (auto user : op->getUsers()) {
45 if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
46 if (targetShardOp.getAnnotateForUsers() &&
47 mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
48 targetShardOp, targetShardOp.getShard().getMesh())) {
49 foundUser = true;
50 break;
51 }
52 }
53 }
54
55 if (!foundUser) {
56 return failure();
57 }
58
59 for (auto user : op->getUsers()) {
60 auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
61 if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
62 symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
63 targetShardOp, targetShardOp.getShard().getMesh()) != mesh) {
64 continue;
65 }
66
67 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
68 ShapedType sourceShardShape =
69 shardShapedType(op.getResult().getType(), mesh, op.getShard());
70 TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
71 builder
72 .create<UnrealizedConversionCastOp>(sourceShardShape,
73 op.getOperand())
74 ->getResult(0));
75 TypedValue<ShapedType> targetShard =
76 reshard(builder, mesh, op, targetShardOp, sourceShard);
77 Value newTargetUnsharded =
78 builder
79 .create<UnrealizedConversionCastOp>(
80 targetShardOp.getResult().getType(), targetShard)
81 ->getResult(0);
82 rewriter.replaceAllUsesWith(targetShardOp.getResult(),
83 newTargetUnsharded);
84 }
85
86 return success();
87 }
88};
89
90struct TestMeshReshardingPass
91 : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
92 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
93
94 void runOnOperation() override {
95 RewritePatternSet patterns(&getContext());
96 patterns.insert<TestMeshReshardingRewritePattern>(arg: &getContext());
97 if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
98 std::move(patterns)))) {
99 return signalPassFailure();
100 }
101 }
102 void getDependentDialects(DialectRegistry &registry) const override {
103 reshardingRegisterDependentDialects(registry);
104 registry.insert<BuiltinDialect>();
105 }
106 StringRef getArgument() const final {
107 return "test-mesh-resharding-spmdization";
108 }
109 StringRef getDescription() const final {
110 return "Test Mesh dialect resharding spmdization.";
111 }
112};
113} // namespace
114
115namespace mlir {
116namespace test {
117void registerTestMeshReshardingSpmdizationPass() {
118 PassRegistration<TestMeshReshardingPass>();
119}
120} // namespace test
121} // namespace mlir
122

// Source: mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp