//===- TestOpLowering.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
11#include "mlir/Dialect/Utils/IndexingUtils.h"
12#include "mlir/IR/SymbolTable.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15
16using namespace mlir;
17
18namespace {
19
20struct TestAllSliceOpLoweringPass
21 : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> {
22 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass)
23
24 void runOnOperation() override {
25 RewritePatternSet patterns(&getContext());
26 SymbolTableCollection symbolTableCollection;
27 mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
28 LogicalResult status =
29 applyPatternsGreedily(getOperation(), std::move(patterns));
30 (void)status;
31 assert(succeeded(status) && "applyPatternsGreedily failed.");
32 }
33 void getDependentDialects(DialectRegistry &registry) const override {
34 mesh::registerAllSliceOpLoweringDialects(registry);
35 }
36 StringRef getArgument() const final {
37 return "test-mesh-all-slice-op-lowering";
38 }
39 StringRef getDescription() const final {
40 return "Test lowering of all-slice.";
41 }
42};
43
44struct TestMultiIndexOpLoweringPass
45 : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
46 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
47
48 void runOnOperation() override {
49 RewritePatternSet patterns(&getContext());
50 SymbolTableCollection symbolTableCollection;
51 mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
52 symbolTableCollection);
53 LogicalResult status =
54 applyPatternsGreedily(getOperation(), std::move(patterns));
55 (void)status;
56 assert(succeeded(status) && "applyPatternsGreedily failed.");
57 }
58 void getDependentDialects(DialectRegistry &registry) const override {
59 mesh::registerProcessMultiIndexOpLoweringDialects(registry);
60 }
61 StringRef getArgument() const final {
62 return "test-mesh-process-multi-index-op-lowering";
63 }
64 StringRef getDescription() const final {
65 return "Test lowering of mesh.process_multi_index op.";
66 }
67};
68
69} // namespace
70
71namespace mlir {
72namespace test {
73void registerTestOpLoweringPasses() {
74 PassRegistration<TestAllSliceOpLoweringPass>();
75 PassRegistration<TestMultiIndexOpLoweringPass>();
76}
77} // namespace test
78} // namespace mlir
79

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp