1 | //===- TestSlicing.cpp - Testing slice functionality ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a simple testing pass for slicing. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Analysis/SliceAnalysis.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
16 | #include "mlir/IR/BuiltinOps.h" |
17 | #include "mlir/IR/IRMapping.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | #include "mlir/Support/LLVM.h" |
21 | |
22 | using namespace mlir; |
23 | |
24 | /// Create a function with the same signature as the parent function of `op` |
25 | /// with name being the function name and a `suffix`. |
26 | static LogicalResult createBackwardSliceFunction(Operation *op, |
27 | StringRef suffix, |
28 | bool omitBlockArguments) { |
29 | func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>(); |
30 | OpBuilder builder(parentFuncOp); |
31 | Location loc = op->getLoc(); |
32 | std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); |
33 | func::FuncOp clonedFuncOp = builder.create<func::FuncOp>( |
34 | loc, clonedFuncOpName, parentFuncOp.getFunctionType()); |
35 | IRMapping mapper; |
36 | builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); |
37 | for (const auto &arg : enumerate(parentFuncOp.getArguments())) |
38 | mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index())); |
39 | SetVector<Operation *> slice; |
40 | BackwardSliceOptions options; |
41 | options.omitBlockArguments = omitBlockArguments; |
42 | // TODO: Make this default. |
43 | options.omitUsesFromAbove = false; |
44 | LogicalResult result = getBackwardSlice(op, backwardSlice: &slice, options); |
45 | assert(result.succeeded() && "expected a backward slice" ); |
46 | (void)result; |
47 | for (Operation *slicedOp : slice) |
48 | builder.clone(op&: *slicedOp, mapper); |
49 | builder.create<func::ReturnOp>(loc); |
50 | return success(); |
51 | } |
52 | |
53 | namespace { |
54 | /// Pass to test slice generated from slice analysis. |
55 | struct SliceAnalysisTestPass |
56 | : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> { |
57 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SliceAnalysisTestPass) |
58 | |
59 | StringRef getArgument() const final { return "slice-analysis-test" ; } |
60 | StringRef getDescription() const final { |
61 | return "Test Slice analysis functionality." ; |
62 | } |
63 | |
64 | Option<bool> omitBlockArguments{ |
65 | *this, "omit-block-arguments" , |
66 | llvm::cl::desc("Test Slice analysis with multiple blocks but slice " |
67 | "omiting block arguments" ), |
68 | llvm::cl::init(Val: true)}; |
69 | |
70 | void runOnOperation() override; |
71 | SliceAnalysisTestPass() = default; |
72 | SliceAnalysisTestPass(const SliceAnalysisTestPass &) {} |
73 | }; |
74 | } // namespace |
75 | |
76 | void SliceAnalysisTestPass::runOnOperation() { |
77 | ModuleOp module = getOperation(); |
78 | auto funcOps = module.getOps<func::FuncOp>(); |
79 | unsigned opNum = 0; |
80 | for (auto funcOp : funcOps) { |
81 | // TODO: For now this is just looking for Linalg ops. It can be generalized |
82 | // to look for other ops using flags. |
83 | funcOp.walk([&](Operation *op) { |
84 | if (!isa<linalg::LinalgOp>(op)) |
85 | return WalkResult::advance(); |
86 | std::string append = |
87 | std::string("__backward_slice__" ) + std::to_string(opNum); |
88 | (void)createBackwardSliceFunction(op, append, omitBlockArguments); |
89 | opNum++; |
90 | return WalkResult::advance(); |
91 | }); |
92 | } |
93 | } |
94 | |
95 | namespace mlir { |
96 | void registerSliceAnalysisTestPass() { |
97 | PassRegistration<SliceAnalysisTestPass>(); |
98 | } |
99 | } // namespace mlir |
100 | |