//===- TestPassStateExtensionCommunication.cpp ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a test pass that showcases how communication can be
// conducted between a regular MLIR pass and transform ops through the
// transform state extension stateInitializer and stateExporter mechanism.
//
//===----------------------------------------------------------------------===//
| 14 | |
| 15 | #include "TestTransformStateExtension.h" |
| 16 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| 17 | #include "mlir/IR/BuiltinOps.h" |
| 18 | #include "mlir/Pass/Pass.h" |
| 19 | |
| 20 | using namespace llvm; |
| 21 | using namespace mlir; |
| 22 | using namespace mlir::test; |
| 23 | |
| 24 | namespace { |
| 25 | template <typename Derived> |
| 26 | class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {}; |
| 27 | |
| 28 | struct TestPassStateExtensionCommunication |
| 29 | : public PassWrapper<TestPassStateExtensionCommunication, |
| 30 | OperationPass<ModuleOp>> { |
| 31 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( |
| 32 | TestPassStateExtensionCommunication) |
| 33 | |
| 34 | StringRef getArgument() const final { |
| 35 | return "test-pass-state-extension-communication" ; |
| 36 | } |
| 37 | |
| 38 | StringRef getDescription() const final { |
| 39 | return "test state communciation between a mlir pass and transform ops" ; |
| 40 | } |
| 41 | |
| 42 | static void printVector(const SmallVector<std::string> &opCollection, |
| 43 | const std::string & = {}) { |
| 44 | outs() << "Printing opCollection" << extraMessage |
| 45 | << ", size: " << opCollection.size() << "\n" ; |
| 46 | for (const auto &subVector : opCollection) { |
| 47 | outs() << subVector << " " ; |
| 48 | } |
| 49 | outs() << "\n" ; |
| 50 | } |
| 51 | |
| 52 | void runOnOperation() override { |
| 53 | ModuleOp module = getOperation(); |
| 54 | |
| 55 | // Create an opCollection vector. |
| 56 | SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS " }; |
| 57 | printVector(opCollection, extraMessage: " before processing transform ops" ); |
| 58 | |
| 59 | auto stateInitializer = |
| 60 | [&opCollection](mlir::transform::TransformState &state) -> void { |
| 61 | TransformStateInitializerExtension *ext = |
| 62 | state.getExtension<TransformStateInitializerExtension>(); |
| 63 | if (!ext) |
| 64 | state.addExtension<TransformStateInitializerExtension>(args: 0, args&: opCollection); |
| 65 | }; |
| 66 | |
| 67 | auto stateExporter = |
| 68 | [&opCollection]( |
| 69 | mlir::transform::TransformState &state) -> LogicalResult { |
| 70 | TransformStateInitializerExtension *ext = |
| 71 | state.getExtension<TransformStateInitializerExtension>(); |
| 72 | if (!ext) { |
| 73 | errs() << "Target transform state extension not found!\n" ; |
| 74 | return failure(); |
| 75 | } |
| 76 | opCollection.clear(); |
| 77 | opCollection = ext->getRegisteredOps(); |
| 78 | return success(); |
| 79 | }; |
| 80 | |
| 81 | // Process transform ops with stateInitializer and stateExporter. |
| 82 | for (auto op : module.getBody()->getOps<transform::TransformOpInterface>()) |
| 83 | if (failed(transform::applyTransforms( |
| 84 | module, op, {}, mlir::transform::TransformOptions(), false, |
| 85 | stateInitializer, stateExporter))) |
| 86 | return signalPassFailure(); |
| 87 | |
| 88 | // Print the opCollection vector after processing transform ops. |
| 89 | printVector(opCollection, extraMessage: " after processing transform ops" ); |
| 90 | } |
| 91 | }; |
| 92 | } // namespace |
| 93 | |
| 94 | namespace mlir { |
| 95 | namespace test { |
| 96 | /// Registers the test pass here. |
| 97 | void registerTestPassStateExtensionCommunication() { |
| 98 | PassRegistration<TestPassStateExtensionCommunication> reg; |
| 99 | } |
| 100 | } // namespace test |
| 101 | } // namespace mlir |
| 102 | |