1//===- TestTypes.cpp - Test passes for MLIR types -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestTypes.h"
10#include "TestDialect.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Pass/Pass.h"
13
14using namespace mlir;
15using namespace test;
16
17namespace {
18struct TestRecursiveTypesPass
19 : public PassWrapper<TestRecursiveTypesPass, OperationPass<func::FuncOp>> {
20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRecursiveTypesPass)
21
22 LogicalResult createIRWithTypes();
23
24 StringRef getArgument() const final { return "test-recursive-types"; }
25 StringRef getDescription() const final {
26 return "Test support for recursive types";
27 }
28 void runOnOperation() override {
29 func::FuncOp func = getOperation();
30
31 // Just make sure recursive types are printed and parsed.
32 if (func.getName() == "roundtrip")
33 return;
34
35 // Create a recursive type and print it as a part of a dummy op.
36 if (func.getName() == "create") {
37 if (failed(result: createIRWithTypes()))
38 signalPassFailure();
39 return;
40 }
41
42 // Unknown key.
43 func.emitOpError() << "unexpected function name";
44 signalPassFailure();
45 }
46};
47} // namespace
48
49LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
50 MLIRContext *ctx = &getContext();
51 func::FuncOp func = getOperation();
52 auto type = TestRecursiveType::get(ctx, name: "some_long_and_unique_name");
53 if (failed(type.setBody(type)))
54 return func.emitError("expected to be able to set the type body");
55
56 // Setting the same body is fine.
57 if (failed(type.setBody(type)))
58 return func.emitError(
59 "expected to be able to set the type body to the same value");
60
61 // Setting a different body is not.
62 if (succeeded(type.setBody(IndexType::get(ctx))))
63 return func.emitError(
64 "not expected to be able to change function body more than once");
65
66 // Expecting to get the same type for the same name.
67 auto other = TestRecursiveType::get(ctx, name: "some_long_and_unique_name");
68 if (type != other)
69 return func.emitError("expected type name to be the uniquing key");
70
71 // Create the op to check how the type is printed.
72 OperationState state(func.getLoc(), "test.dummy_type_test_op");
73 state.addTypes(type);
74 func.getBody().front().push_front(Operation::create(state));
75
76 return success();
77}
78
79namespace mlir {
80namespace test {
81
82void registerTestRecursiveTypesPass() {
83 PassRegistration<TestRecursiveTypesPass>();
84}
85
86} // namespace test
87} // namespace mlir
88

source code of mlir/test/lib/IR/TestTypes.cpp