1 | //===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" |
12 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
13 | #include "mlir/IR/Attributes.h" |
14 | #include "mlir/IR/BuiltinAttributes.h" |
15 | #include "mlir/Pass/Pass.h" |
16 | |
17 | #include <string> |
18 | |
19 | using namespace mlir; |
20 | |
21 | namespace { |
22 | std::string getImplementationStatus(Type type) { |
23 | if (isa<bufferization::TensorLikeType>(type)) { |
24 | return "is_tensor_like" ; |
25 | } |
26 | if (isa<bufferization::BufferLikeType>(type)) { |
27 | return "is_buffer_like" ; |
28 | } |
29 | return {}; |
30 | } |
31 | |
32 | DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) { |
33 | llvm::SmallVector<NamedAttribute> attributes; |
34 | |
35 | const auto funcType = funcOp.getFunctionType(); |
36 | for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) { |
37 | const auto status = getImplementationStatus(inputType); |
38 | if (status.empty()) { |
39 | continue; |
40 | } |
41 | |
42 | attributes.push_back( |
43 | NamedAttribute(StringAttr::get(funcOp.getContext(), |
44 | "operand_" + std::to_string(index)), |
45 | StringAttr::get(funcOp.getContext(), status))); |
46 | } |
47 | |
48 | for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) { |
49 | const auto status = getImplementationStatus(resultType); |
50 | if (status.empty()) { |
51 | continue; |
52 | } |
53 | |
54 | attributes.push_back(NamedAttribute( |
55 | StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)), |
56 | StringAttr::get(funcOp.getContext(), status))); |
57 | } |
58 | |
59 | return mlir::DictionaryAttr::get(funcOp.getContext(), attributes); |
60 | } |
61 | |
62 | /// This pass tests whether specified types implement TensorLike and (or) |
63 | /// BufferLike type interfaces defined in bufferization. |
64 | /// |
65 | /// The pass analyses operation signature. When the aforementioned interface |
66 | /// implementation found, an attribute is added to the operation, signifying the |
67 | /// associated operand / result. |
68 | struct TestTensorLikeAndBufferLikePass |
69 | : public PassWrapper<TestTensorLikeAndBufferLikePass, |
70 | OperationPass<ModuleOp>> { |
71 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass) |
72 | |
73 | void getDependentDialects(DialectRegistry ®istry) const override { |
74 | registry.insert<bufferization::BufferizationDialect, test::TestDialect>(); |
75 | } |
76 | StringRef getArgument() const final { return "test-tensorlike-bufferlike" ; } |
77 | StringRef getDescription() const final { |
78 | return "Module pass to test custom types that implement TensorLike / " |
79 | "BufferLike interfaces" ; |
80 | } |
81 | |
82 | void runOnOperation() override { |
83 | auto op = getOperation(); |
84 | |
85 | op.walk([](func::FuncOp funcOp) { |
86 | const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp); |
87 | if (!dict.empty()) { |
88 | funcOp->setAttr("found" , dict); |
89 | } |
90 | }); |
91 | } |
92 | }; |
93 | } // namespace |
94 | |
95 | namespace mlir::test { |
96 | void registerTestTensorLikeAndBufferLikePass() { |
97 | PassRegistration<TestTensorLikeAndBufferLikePass>(); |
98 | } |
99 | } // namespace mlir::test |
100 | |