//===- Preload.cpp - Test transform dialect library preloading ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
10#include "mlir/Dialect/Transform/IR/TransformDialect.h"
11#include "mlir/Dialect/Transform/IR/Utils.h"
12#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
13#include "mlir/IR/AsmState.h"
14#include "mlir/IR/DialectRegistry.h"
15#include "mlir/IR/Verifier.h"
16#include "mlir/Parser/Parser.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Pass/PassManager.h"
19#include "mlir/Support/FileUtilities.h"
20#include "mlir/Support/TypeID.h"
21#include "mlir/Tools/mlir-opt/MlirOptMain.h"
22#include "llvm/Support/MemoryBuffer.h"
23#include "llvm/Support/raw_ostream.h"
24#include "gtest/gtest.h"
25
26using namespace mlir;
27
28namespace mlir {
29namespace test {
30std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
31} // namespace test
32} // namespace mlir
33
34const static llvm::StringLiteral library = R"MLIR(
35module attributes {transform.with_named_sequence} {
36 transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
37 transform.debug.emit_remark_at %arg0, "from external symbol" : !transform.any_op
38 transform.yield
39 }
40})MLIR";
41
42const static llvm::StringLiteral input = R"MLIR(
43module attributes {transform.with_named_sequence} {
44 transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
45
46 transform.sequence failures(propagate) {
47 ^bb0(%arg0: !transform.any_op):
48 include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
49 }
50})MLIR";
51
52TEST(Preload, ContextPreloadConstructedLibrary) {
53 registerPassManagerCLOptions();
54
55 MLIRContext context;
56 auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
57 DialectRegistry registry;
58 mlir::transform::registerDebugExtension(dialectRegistry&: registry);
59 registry.applyExtensions(ctx: &context);
60 ParserConfig parserConfig(&context);
61
62 OwningOpRef<ModuleOp> inputModule =
63 parseSourceString<ModuleOp>(sourceStr: input, config: parserConfig, sourceName: "<input>");
64 EXPECT_TRUE(inputModule) << "failed to parse input module";
65
66 OwningOpRef<ModuleOp> transformLibrary =
67 parseSourceString<ModuleOp>(sourceStr: library, config: parserConfig, sourceName: "<transform-library>");
68 EXPECT_TRUE(transformLibrary) << "failed to parse transform module";
69 LogicalResult diag =
70 dialect->loadIntoLibraryModule(std::move(transformLibrary));
71 EXPECT_TRUE(succeeded(diag));
72
73 ModuleOp retrievedTransformLibrary =
74 transform::detail::getPreloadedTransformModule(&context);
75 EXPECT_TRUE(retrievedTransformLibrary)
76 << "failed to retrieve transform module";
77
78 OwningOpRef<Operation *> clonedTransformModule(
79 retrievedTransformLibrary->clone());
80
81 LogicalResult res = transform::detail::mergeSymbolsInto(
82 inputModule->getOperation(), std::move(clonedTransformModule));
83 EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols";
84
85 transform::TransformOpInterface entryPoint =
86 transform::detail::findTransformEntryPoint(inputModule->getOperation(),
87 retrievedTransformLibrary);
88 EXPECT_TRUE(entryPoint) << "failed to find entry point";
89
90 transform::TransformOptions options;
91 res = transform::applyTransformNamedSequence(
92 inputModule->getOperation(), entryPoint, retrievedTransformLibrary,
93 options);
94 EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
95}
96

// Source: mlir/unittests/Dialect/Transform/Preload.cpp