1 | //===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file defines TransformState extensions for the purpose of testing the
// relevant APIs.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H |
15 | #define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H |
16 | |
17 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
18 | |
19 | using namespace mlir; |
20 | |
21 | namespace mlir { |
22 | namespace test { |
23 | class TestTransformStateExtension |
24 | : public transform::TransformState::Extension { |
25 | public: |
26 | TestTransformStateExtension(transform::TransformState &state, |
27 | StringAttr message) |
28 | : Extension(state), message(message) {} |
29 | |
30 | StringRef getMessage() const { return message.getValue(); } |
31 | |
32 | LogicalResult updateMapping(Operation *previous, Operation *updated); |
33 | |
34 | private: |
35 | StringAttr message; |
36 | }; |
37 | |
38 | class TransformStateInitializerExtension |
39 | : public transform::TransformState::Extension { |
40 | public: |
41 | TransformStateInitializerExtension(transform::TransformState &state, |
42 | int numOp, |
43 | SmallVector<std::string> ®isteredOps) |
44 | : Extension(state), numOp(numOp), registeredOps(registeredOps) {} |
45 | |
46 | int getNumOp() { return numOp; } |
47 | void setNumOp(int num) { numOp = num; } |
48 | SmallVector<std::string> getRegisteredOps() { return registeredOps; } |
49 | void pushRegisteredOps(const std::string &newOp) { |
50 | registeredOps.push_back(Elt: newOp); |
51 | } |
52 | std::string printMessage() const { |
53 | std::string message = "Registered transformOps are: " ; |
54 | for (const auto &op : registeredOps) { |
55 | message += op + " | " ; |
56 | } |
57 | return message; |
58 | } |
59 | |
60 | private: |
61 | int numOp; |
62 | SmallVector<std::string> registeredOps; |
63 | }; |
64 | |
65 | } // namespace test |
66 | } // namespace mlir |
67 | |
68 | #endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H |
69 | |