1//===- InferEffects.cpp - Infer memory effects for named symbols ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Dialect/Transform/Transforms/Passes.h"
11
12#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13#include "mlir/IR/Visitors.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
15#include "llvm/ADT/DenseSet.h"
16
17using namespace mlir;
18
19namespace mlir {
20namespace transform {
21#define GEN_PASS_DEF_INFEREFFECTSPASS
22#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
23} // namespace transform
24} // namespace mlir
25
26static LogicalResult inferSideEffectAnnotations(Operation *op) {
27 if (!isa<transform::TransformOpInterface>(Val: op))
28 return success();
29
30 auto func = dyn_cast<FunctionOpInterface>(Val: op);
31 if (!func || func.isExternal())
32 return success();
33
34 if (!func.getFunctionBody().hasOneBlock()) {
35 return op->emitError()
36 << "only single-block operations are currently supported";
37 }
38
39 // Note that there can't be an inclusion of an unannotated symbol because it
40 // wouldn't have passed the verifier, so recursion isn't necessary here.
41 llvm::SmallDenseSet<unsigned> consumedArguments;
42 transform::getConsumedBlockArguments(block&: func.getFunctionBody().front(),
43 consumedArguments);
44
45 for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
46 func.setArgAttr(index: i,
47 name: consumedArguments.contains(V: i)
48 ? transform::TransformDialect::kArgConsumedAttrName
49 : transform::TransformDialect::kArgReadOnlyAttrName,
50 value: UnitAttr::get(context: op->getContext()));
51 }
52 return success();
53}
54
55namespace {
56class InferEffectsPass
57 : public transform::impl::InferEffectsPassBase<InferEffectsPass> {
58public:
59 void runOnOperation() override {
60 WalkResult result = getOperation()->walk(callback: [](Operation *op) {
61 return failed(Result: inferSideEffectAnnotations(op)) ? WalkResult::interrupt()
62 : WalkResult::advance();
63 });
64 if (result.wasInterrupted())
65 return signalPassFailure();
66 }
67};
68} // namespace
69

source code of mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp