| 1 | //===- TestTensorLikeAndBufferLike.cpp - Bufferization Test -----*- c++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "TestDialect.h" |
| 10 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 11 | #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" |
| 12 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 13 | #include "mlir/IR/Attributes.h" |
| 14 | #include "mlir/IR/BuiltinAttributes.h" |
| 15 | #include "mlir/Pass/Pass.h" |
| 16 | |
| 17 | #include <string> |
| 18 | |
| 19 | using namespace mlir; |
| 20 | |
| 21 | namespace { |
| 22 | std::string getImplementationStatus(Type type) { |
| 23 | if (isa<bufferization::TensorLikeType>(type)) { |
| 24 | return "is_tensor_like" ; |
| 25 | } |
| 26 | if (isa<bufferization::BufferLikeType>(type)) { |
| 27 | return "is_buffer_like" ; |
| 28 | } |
| 29 | return {}; |
| 30 | } |
| 31 | |
| 32 | DictionaryAttr findAllImplementeesOfTensorOrBufferLike(func::FuncOp funcOp) { |
| 33 | llvm::SmallVector<NamedAttribute> attributes; |
| 34 | |
| 35 | const auto funcType = funcOp.getFunctionType(); |
| 36 | for (auto [index, inputType] : llvm::enumerate(funcType.getInputs())) { |
| 37 | const auto status = getImplementationStatus(inputType); |
| 38 | if (status.empty()) { |
| 39 | continue; |
| 40 | } |
| 41 | |
| 42 | attributes.push_back( |
| 43 | NamedAttribute(StringAttr::get(funcOp.getContext(), |
| 44 | "operand_" + std::to_string(index)), |
| 45 | StringAttr::get(funcOp.getContext(), status))); |
| 46 | } |
| 47 | |
| 48 | for (auto [index, resultType] : llvm::enumerate(funcType.getResults())) { |
| 49 | const auto status = getImplementationStatus(resultType); |
| 50 | if (status.empty()) { |
| 51 | continue; |
| 52 | } |
| 53 | |
| 54 | attributes.push_back(NamedAttribute( |
| 55 | StringAttr::get(funcOp.getContext(), "result_" + std::to_string(index)), |
| 56 | StringAttr::get(funcOp.getContext(), status))); |
| 57 | } |
| 58 | |
| 59 | return mlir::DictionaryAttr::get(funcOp.getContext(), attributes); |
| 60 | } |
| 61 | |
| 62 | /// This pass tests whether specified types implement TensorLike and (or) |
| 63 | /// BufferLike type interfaces defined in bufferization. |
| 64 | /// |
| 65 | /// The pass analyses operation signature. When the aforementioned interface |
| 66 | /// implementation found, an attribute is added to the operation, signifying the |
| 67 | /// associated operand / result. |
| 68 | struct TestTensorLikeAndBufferLikePass |
| 69 | : public PassWrapper<TestTensorLikeAndBufferLikePass, |
| 70 | OperationPass<ModuleOp>> { |
| 71 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndBufferLikePass) |
| 72 | |
| 73 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 74 | registry.insert<bufferization::BufferizationDialect, test::TestDialect>(); |
| 75 | } |
| 76 | StringRef getArgument() const final { return "test-tensorlike-bufferlike" ; } |
| 77 | StringRef getDescription() const final { |
| 78 | return "Module pass to test custom types that implement TensorLike / " |
| 79 | "BufferLike interfaces" ; |
| 80 | } |
| 81 | |
| 82 | void runOnOperation() override { |
| 83 | auto op = getOperation(); |
| 84 | |
| 85 | op.walk([](func::FuncOp funcOp) { |
| 86 | const auto dict = findAllImplementeesOfTensorOrBufferLike(funcOp); |
| 87 | if (!dict.empty()) { |
| 88 | funcOp->setAttr("found" , dict); |
| 89 | } |
| 90 | }); |
| 91 | } |
| 92 | }; |
| 93 | } // namespace |
| 94 | |
| 95 | namespace mlir::test { |
| 96 | void registerTestTensorLikeAndBufferLikePass() { |
| 97 | PassRegistration<TestTensorLikeAndBufferLikePass>(); |
| 98 | } |
| 99 | } // namespace mlir::test |
| 100 | |