1//===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/IR/Attributes.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/Pass/Pass.h"
16
17#include <string>
18
19using namespace mlir;
20
21namespace {
22std::string getImplementationStatus(Type type) {
23 if (isa<bufferization::TensorLikeType>(Val: type)) {
24 return "is_tensor_like";
25 }
26 if (isa<bufferization::BufferLikeType>(Val: type)) {
27 return "is_buffer_like";
28 }
29 return {};
30}
31
32DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) {
33 llvm::SmallVector<NamedAttribute> attributes;
34
35 const auto funcType = funcOp.getFunctionType();
36 for (auto [index, inputType] : llvm::enumerate(First: funcType.getInputs())) {
37 const auto status = getImplementationStatus(type: inputType);
38 if (status.empty()) {
39 continue;
40 }
41
42 attributes.push_back(
43 Elt: NamedAttribute(StringAttr::get(context: funcOp.getContext(),
44 bytes: "operand_" + std::to_string(val: index)),
45 StringAttr::get(context: funcOp.getContext(), bytes: status)));
46 }
47
48 for (auto [index, resultType] : llvm::enumerate(First: funcType.getResults())) {
49 const auto status = getImplementationStatus(type: resultType);
50 if (status.empty()) {
51 continue;
52 }
53
54 attributes.push_back(Elt: NamedAttribute(
55 StringAttr::get(context: funcOp.getContext(), bytes: "result_" + std::to_string(val: index)),
56 StringAttr::get(context: funcOp.getContext(), bytes: status)));
57 }
58
59 return mlir::DictionaryAttr::get(context: funcOp.getContext(), value: attributes);
60}
61
62/// This pass tests whether specified types implement TensorLike and (or)
63/// BufferLike type interfaces defined in bufferization.
64///
65/// The pass analyses operation signature. When the aforementioned interface
66/// implementation found, an attribute is added to the operation, signifying the
67/// associated operand / result.
68struct TestTensorLikeAndBufferLikePass
69 : public PassWrapper<TestTensorLikeAndBufferLikePass,
70 OperationPass<ModuleOp>> {
71 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass)
72
73 void getDependentDialects(DialectRegistry &registry) const override {
74 registry.insert<bufferization::BufferizationDialect, test::TestDialect>();
75 }
76 StringRef getArgument() const final { return "test-tensorlike-bufferlike"; }
77 StringRef getDescription() const final {
78 return "Module pass to test custom types that implement TensorLike / "
79 "BufferLike interfaces";
80 }
81
82 void runOnOperation() override {
83 auto op = getOperation();
84
85 op.walk(callback: [](func::FuncOp funcOp) {
86 const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp);
87 if (!dict.empty()) {
88 funcOp->setAttr(name: "found", value: dict);
89 }
90 });
91 }
92};
93} // namespace
94
95namespace mlir::test {
96void registerTestTensorLikeAndBufferLikePass() {
97 PassRegistration<TestTensorLikeAndBufferLikePass>();
98}
99} // namespace mlir::test
100

source code of mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp