| 1 | //===- TestSlicing.cpp - Testing slice functionality ----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a simple testing pass for slicing. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Analysis/SliceAnalysis.h" |
| 14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 16 | #include "mlir/IR/BuiltinOps.h" |
| 17 | #include "mlir/IR/IRMapping.h" |
| 18 | #include "mlir/IR/PatternMatch.h" |
| 19 | #include "mlir/Pass/Pass.h" |
| 20 | #include "mlir/Support/LLVM.h" |
| 21 | |
| 22 | using namespace mlir; |
| 23 | |
| 24 | /// Create a function with the same signature as the parent function of `op` |
| 25 | /// with name being the function name and a `suffix`. |
| 26 | static LogicalResult createBackwardSliceFunction(Operation *op, |
| 27 | StringRef suffix, |
| 28 | bool omitBlockArguments) { |
| 29 | func::FuncOp parentFuncOp = op->getParentOfType<func::FuncOp>(); |
| 30 | OpBuilder builder(parentFuncOp); |
| 31 | Location loc = op->getLoc(); |
| 32 | std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); |
| 33 | func::FuncOp clonedFuncOp = builder.create<func::FuncOp>( |
| 34 | loc, clonedFuncOpName, parentFuncOp.getFunctionType()); |
| 35 | IRMapping mapper; |
| 36 | builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); |
| 37 | for (const auto &arg : enumerate(parentFuncOp.getArguments())) |
| 38 | mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index())); |
| 39 | SetVector<Operation *> slice; |
| 40 | BackwardSliceOptions options; |
| 41 | options.omitBlockArguments = omitBlockArguments; |
| 42 | // TODO: Make this default. |
| 43 | options.omitUsesFromAbove = false; |
| 44 | LogicalResult result = getBackwardSlice(op, backwardSlice: &slice, options); |
| 45 | assert(result.succeeded() && "expected a backward slice" ); |
| 46 | (void)result; |
| 47 | for (Operation *slicedOp : slice) |
| 48 | builder.clone(op&: *slicedOp, mapper); |
| 49 | builder.create<func::ReturnOp>(loc); |
| 50 | return success(); |
| 51 | } |
| 52 | |
| 53 | namespace { |
| 54 | /// Pass to test slice generated from slice analysis. |
| 55 | struct SliceAnalysisTestPass |
| 56 | : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> { |
| 57 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SliceAnalysisTestPass) |
| 58 | |
| 59 | StringRef getArgument() const final { return "slice-analysis-test" ; } |
| 60 | StringRef getDescription() const final { |
| 61 | return "Test Slice analysis functionality." ; |
| 62 | } |
| 63 | |
| 64 | Option<bool> omitBlockArguments{ |
| 65 | *this, "omit-block-arguments" , |
| 66 | llvm::cl::desc("Test Slice analysis with multiple blocks but slice " |
| 67 | "omiting block arguments" ), |
| 68 | llvm::cl::init(Val: true)}; |
| 69 | |
| 70 | void runOnOperation() override; |
| 71 | SliceAnalysisTestPass() = default; |
| 72 | SliceAnalysisTestPass(const SliceAnalysisTestPass &) {} |
| 73 | }; |
| 74 | } // namespace |
| 75 | |
| 76 | void SliceAnalysisTestPass::runOnOperation() { |
| 77 | ModuleOp module = getOperation(); |
| 78 | auto funcOps = module.getOps<func::FuncOp>(); |
| 79 | unsigned opNum = 0; |
| 80 | for (auto funcOp : funcOps) { |
| 81 | // TODO: For now this is just looking for Linalg ops. It can be generalized |
| 82 | // to look for other ops using flags. |
| 83 | funcOp.walk([&](Operation *op) { |
| 84 | if (!isa<linalg::LinalgOp>(op)) |
| 85 | return WalkResult::advance(); |
| 86 | std::string append = |
| 87 | std::string("__backward_slice__" ) + std::to_string(opNum); |
| 88 | (void)createBackwardSliceFunction(op, append, omitBlockArguments); |
| 89 | opNum++; |
| 90 | return WalkResult::advance(); |
| 91 | }); |
| 92 | } |
| 93 | } |
| 94 | |
| 95 | namespace mlir { |
| 96 | void registerSliceAnalysisTestPass() { |
| 97 | PassRegistration<SliceAnalysisTestPass>(); |
| 98 | } |
| 99 | } // namespace mlir |
| 100 | |