//===- TestOpLowering.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
11#include "mlir/Dialect/Utils/IndexingUtils.h"
12#include "mlir/IR/SymbolTable.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Support/LogicalResult.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17using namespace mlir;
18
19namespace {
20
21struct TestAllSliceOpLoweringPass
22 : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> {
23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass)
24
25 void runOnOperation() override {
26 RewritePatternSet patterns(&getContext());
27 SymbolTableCollection symbolTableCollection;
28 mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
29 LogicalResult status =
30 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
31 (void)status;
32 assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
33 }
34 void getDependentDialects(DialectRegistry &registry) const override {
35 mesh::registerAllSliceOpLoweringDialects(registry);
36 }
37 StringRef getArgument() const final {
38 return "test-mesh-all-slice-op-lowering";
39 }
40 StringRef getDescription() const final {
41 return "Test lowering of all-slice.";
42 }
43};
44
45struct TestMultiIndexOpLoweringPass
46 : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
47 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
48
49 void runOnOperation() override {
50 RewritePatternSet patterns(&getContext());
51 SymbolTableCollection symbolTableCollection;
52 mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
53 symbolTableCollection);
54 LogicalResult status =
55 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
56 (void)status;
57 assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
58 }
59 void getDependentDialects(DialectRegistry &registry) const override {
60 mesh::registerProcessMultiIndexOpLoweringDialects(registry);
61 }
62 StringRef getArgument() const final {
63 return "test-mesh-process-multi-index-op-lowering";
64 }
65 StringRef getDescription() const final {
66 return "Test lowering of mesh.process_multi_index op.";
67 }
68};
69
70} // namespace
71
72namespace mlir {
73namespace test {
74void registerTestOpLoweringPasses() {
75 PassRegistration<TestAllSliceOpLoweringPass>();
76 PassRegistration<TestMultiIndexOpLoweringPass>();
77}
78} // namespace test
79} // namespace mlir
80

source code of mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp