1//===- TestShapeFunctions.cpp - Passes to test shape function ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <queue>
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/Shape/IR/Shape.h"
13#include "mlir/IR/BuiltinDialect.h"
14#include "mlir/Interfaces/InferTypeOpInterface.h"
15#include "mlir/Pass/Pass.h"
16
17using namespace mlir;
18
19namespace {
20/// This is a pass that reports shape functions associated with ops.
21struct ReportShapeFnPass
22 : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass)
24
25 void runOnOperation() override;
26 StringRef getArgument() const final { return "test-shape-function-report"; }
27 StringRef getDescription() const final {
28 return "Test pass to report associated shape functions";
29 }
30};
31} // namespace
32
33void ReportShapeFnPass::runOnOperation() {
34 auto module = getOperation();
35
36 // Report the shape function available to refine the op.
37 auto shapeFnId = StringAttr::get(&getContext(), "shape.function");
38 auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) {
39 if (op->hasTrait<OpTrait::IsTerminator>())
40 return true;
41 if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
42 op->emitRemark() << "implements InferType op interface";
43 return true;
44 }
45 if (auto fn = shapeFnLib.getShapeFunction(op)) {
46 op->emitRemark() << "associated shape function: " << fn.getName();
47 return true;
48 }
49 if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
50 auto fn =
51 cast<shape::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
52 op->emitRemark() << "associated shape function: " << fn.getName();
53 return true;
54 }
55 return false;
56 };
57
58 // Lookup shape function library.
59 SmallVector<shape::FunctionLibraryOp, 4> libraries;
60 auto attr = module->getDiscardableAttr("shape.lib");
61 if (attr) {
62 auto lookup = [&](Attribute attr) {
63 return cast<shape::FunctionLibraryOp>(
64 SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr)));
65 };
66 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
67 libraries.reserve(arrayAttr.size());
68 for (auto attr : arrayAttr)
69 libraries.push_back(lookup(attr));
70 } else {
71 libraries.reserve(1);
72 libraries.push_back(lookup(attr));
73 }
74 }
75
76 module.getBodyRegion().walk([&](func::FuncOp func) {
77 // Skip ops in the shape function library.
78 if (isa<shape::FunctionLibraryOp>(func->getParentOp()))
79 return;
80
81 func.walk([&](Operation *op) {
82 bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) {
83 return remarkShapeFn(lib, op);
84 });
85 if (!found)
86 op->emitRemark() << "no associated way to refine shape";
87 });
88 });
89}
90
91namespace mlir {
92void registerShapeFunctionTestPasses() {
93 PassRegistration<ReportShapeFnPass>();
94}
95} // namespace mlir
96

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp