//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/Mesh/IR/MeshOps.h"
11#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
12#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
13#include "mlir/IR/BuiltinDialect.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/BuiltinTypeInterfaces.h"
16#include "mlir/IR/Diagnostics.h"
17#include "mlir/IR/ImplicitLocOpBuilder.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/IR/Value.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23
24using namespace mlir;
25using namespace mlir::mesh;
26
27namespace {
28
29struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
30 using OpRewritePattern<ShardOp>::OpRewritePattern;
31
32 LogicalResult matchAndRewrite(ShardOp op,
33 PatternRewriter &rewriter) const override {
34 if (op.getAnnotateForUsers()) {
35 return failure();
36 }
37
38 SymbolTableCollection symbolTable;
39 mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
40 op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
41
42 bool foundUser = false;
43 for (auto user : op->getUsers()) {
44 if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
45 if (targetShardOp.getAnnotateForUsers() &&
46 mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
47 targetShardOp,
48 cast<ShardingOp>(
49 targetShardOp.getSharding().getDefiningOp())
50 .getMeshAttr())) {
51 foundUser = true;
52 break;
53 }
54 }
55 }
56
57 if (!foundUser) {
58 return failure();
59 }
60
61 for (auto user : op->getUsers()) {
62 auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
63 if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
64 symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
65 targetShardOp,
66 cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
67 .getMeshAttr()) != mesh) {
68 continue;
69 }
70
71 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
72 ShapedType sourceShardShape =
73 shardShapedType(op.getResult().getType(), mesh, op.getSharding());
74 TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
75 builder
76 .create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
77 ->getResult(0));
78 TypedValue<ShapedType> targetShard =
79 reshard(builder, mesh, op, targetShardOp, sourceShard);
80 Value newTargetUnsharded =
81 builder
82 .create<UnrealizedConversionCastOp>(
83 targetShardOp.getResult().getType(), targetShard)
84 ->getResult(0);
85 rewriter.replaceAllUsesWith(targetShardOp.getResult(),
86 newTargetUnsharded);
87 }
88
89 return success();
90 }
91};
92
93struct TestMeshReshardingPass
94 : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
95 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
96
97 void runOnOperation() override {
98 RewritePatternSet patterns(&getContext());
99 patterns.insert<TestMeshReshardingRewritePattern>(arg: &getContext());
100 if (failed(applyPatternsGreedily(getOperation().getOperation(),
101 std::move(patterns)))) {
102 return signalPassFailure();
103 }
104 }
105 void getDependentDialects(DialectRegistry &registry) const override {
106 reshardingRegisterDependentDialects(registry);
107 registry.insert<BuiltinDialect>();
108 }
109 StringRef getArgument() const final {
110 return "test-mesh-resharding-spmdization";
111 }
112 StringRef getDescription() const final {
113 return "Test Mesh dialect resharding spmdization.";
114 }
115};
116} // namespace
117
118namespace mlir {
119namespace test {
120void registerTestMeshReshardingSpmdizationPass() {
121 PassRegistration<TestMeshReshardingPass>();
122}
123} // namespace test
124} // namespace mlir
125

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp