1//===- InferEffects.cpp - Infer memory effects for named symbols ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Dialect/Transform/Transforms/Passes.h"
11
12#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13#include "mlir/IR/Visitors.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
15#include "mlir/Interfaces/SideEffectInterfaces.h"
16#include "llvm/ADT/DenseSet.h"
17
18using namespace mlir;
19
20namespace mlir {
21namespace transform {
22#define GEN_PASS_DEF_INFEREFFECTSPASS
23#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
24} // namespace transform
25} // namespace mlir
26
27static LogicalResult inferSideEffectAnnotations(Operation *op) {
28 if (!isa<transform::TransformOpInterface>(op))
29 return success();
30
31 auto func = dyn_cast<FunctionOpInterface>(op);
32 if (!func || func.isExternal())
33 return success();
34
35 if (!func.getFunctionBody().hasOneBlock()) {
36 return op->emitError()
37 << "only single-block operations are currently supported";
38 }
39
40 // Note that there can't be an inclusion of an unannotated symbol because it
41 // wouldn't have passed the verifier, so recursion isn't necessary here.
42 llvm::SmallDenseSet<unsigned> consumedArguments;
43 transform::getConsumedBlockArguments(block&: func.getFunctionBody().front(),
44 consumedArguments&: consumedArguments);
45
46 for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
47 func.setArgAttr(i,
48 consumedArguments.contains(i)
49 ? transform::TransformDialect::kArgConsumedAttrName
50 : transform::TransformDialect::kArgReadOnlyAttrName,
51 UnitAttr::get(op->getContext()));
52 }
53 return success();
54}
55
56namespace {
57class InferEffectsPass
58 : public transform::impl::InferEffectsPassBase<InferEffectsPass> {
59public:
60 void runOnOperation() override {
61 WalkResult result = getOperation()->walk([](Operation *op) {
62 return failed(result: inferSideEffectAnnotations(op)) ? WalkResult::interrupt()
63 : WalkResult::advance();
64 });
65 if (result.wasInterrupted())
66 return signalPassFailure();
67 }
68};
69} // namespace
70

source code of mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp