1//===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/IR/Attributes.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/Pass/Pass.h"
16
17#include <string>
18
19using namespace mlir;
20
21namespace {
22std::string getImplementationStatus(Type type) {
23 if (isa<bufferization::TensorLikeType>(type)) {
24 return "is_tensor_like";
25 }
26 if (isa<bufferization::BufferLikeType>(type)) {
27 return "is_buffer_like";
28 }
29 return {};
30}
31
32DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) {
33 llvm::SmallVector<NamedAttribute> attributes;
34
35 const auto funcType = funcOp.getFunctionType();
36 for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) {
37 const auto status = getImplementationStatus(inputType);
38 if (status.empty()) {
39 continue;
40 }
41
42 attributes.push_back(
43 NamedAttribute(StringAttr::get(funcOp.getContext(),
44 "operand_" + std::to_string(index)),
45 StringAttr::get(funcOp.getContext(), status)));
46 }
47
48 for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) {
49 const auto status = getImplementationStatus(resultType);
50 if (status.empty()) {
51 continue;
52 }
53
54 attributes.push_back(NamedAttribute(
55 StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)),
56 StringAttr::get(funcOp.getContext(), status)));
57 }
58
59 return mlir::DictionaryAttr::get(funcOp.getContext(), attributes);
60}
61
62/// This pass tests whether specified types implement TensorLike and (or)
63/// BufferLike type interfaces defined in bufferization.
64///
65/// The pass analyses operation signature. When the aforementioned interface
66/// implementation found, an attribute is added to the operation, signifying the
67/// associated operand / result.
68struct TestTensorLikeAndBufferLikePass
69 : public PassWrapper<TestTensorLikeAndBufferLikePass,
70 OperationPass<ModuleOp>> {
71 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass)
72
73 void getDependentDialects(DialectRegistry &registry) const override {
74 registry.insert<bufferization::BufferizationDialect, test::TestDialect>();
75 }
76 StringRef getArgument() const final { return "test-tensorlike-bufferlike"; }
77 StringRef getDescription() const final {
78 return "Module pass to test custom types that implement TensorLike / "
79 "BufferLike interfaces";
80 }
81
82 void runOnOperation() override {
83 auto op = getOperation();
84
85 op.walk([](func::FuncOp funcOp) {
86 const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp);
87 if (!dict.empty()) {
88 funcOp->setAttr("found", dict);
89 }
90 });
91 }
92};
93} // namespace
94
95namespace mlir::test {
96void registerTestTensorLikeAndBufferLikePass() {
97 PassRegistration<TestTensorLikeAndBufferLikePass>();
98}
99} // namespace mlir::test
100

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndBufferLike.cpp