1//===- TestPassStateExtensionCommunication.cpp ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines a test pass that showcases how communication can be
10// conducted between a regular mlir pass and transform ops through the
11// transform state extension stateInitializer and stateExporter mechanism.
12//
13//===----------------------------------------------------------------------===//
14
15#include "TestTransformStateExtension.h"
16#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/Pass/Pass.h"
19
20using namespace llvm;
21using namespace mlir;
22using namespace mlir::test;
23
24namespace {
25template <typename Derived>
26class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
27
28struct TestPassStateExtensionCommunication
29 : public PassWrapper<TestPassStateExtensionCommunication,
30 OperationPass<ModuleOp>> {
31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
32 TestPassStateExtensionCommunication)
33
34 StringRef getArgument() const final {
35 return "test-pass-state-extension-communication";
36 }
37
38 StringRef getDescription() const final {
39 return "test state communciation between a mlir pass and transform ops";
40 }
41
42 static void printVector(const SmallVector<std::string> &opCollection,
43 const std::string &extraMessage = {}) {
44 outs() << "Printing opCollection" << extraMessage
45 << ", size: " << opCollection.size() << "\n";
46 for (const auto &subVector : opCollection) {
47 outs() << subVector << " ";
48 }
49 outs() << "\n";
50 }
51
52 void runOnOperation() override {
53 ModuleOp module = getOperation();
54
55 // Create an opCollection vector.
56 SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS "};
57 printVector(opCollection, extraMessage: " before processing transform ops");
58
59 auto stateInitializer =
60 [&opCollection](mlir::transform::TransformState &state) -> void {
61 TransformStateInitializerExtension *ext =
62 state.getExtension<TransformStateInitializerExtension>();
63 if (!ext)
64 state.addExtension<TransformStateInitializerExtension>(args: 0, args&: opCollection);
65 };
66
67 auto stateExporter =
68 [&opCollection](
69 mlir::transform::TransformState &state) -> LogicalResult {
70 TransformStateInitializerExtension *ext =
71 state.getExtension<TransformStateInitializerExtension>();
72 if (!ext) {
73 errs() << "Target transform state extension not found!\n";
74 return failure();
75 }
76 opCollection.clear();
77 opCollection = ext->getRegisteredOps();
78 return success();
79 };
80
81 // Process transform ops with stateInitializer and stateExporter.
82 for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
83 if (failed(transform::applyTransforms(
84 module, op, {}, mlir::transform::TransformOptions(), false,
85 stateInitializer, stateExporter)))
86 return signalPassFailure();
87
88 // Print the opCollection vector after processing transform ops.
89 printVector(opCollection, extraMessage: " after processing transform ops");
90 }
91};
92} // namespace
93
94namespace mlir {
95namespace test {
96/// Registers the test pass here.
97void registerTestPassStateExtensionCommunication() {
98 PassRegistration<TestPassStateExtensionCommunication> reg;
99}
100} // namespace test
101} // namespace mlir
102

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp