1 | //===- DebugExtensionOps.cpp - Debug extension for the Transform dialect --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.h" |
10 | |
11 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
12 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
13 | #include "mlir/IR/OpImplementation.h" |
14 | #include "llvm/Support/InterleavedRange.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | #define GET_OP_CLASSES |
19 | #include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc" |
20 | |
21 | DiagnosedSilenceableFailure |
22 | transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter, |
23 | transform::TransformResults &results, |
24 | transform::TransformState &state) { |
25 | if (isa<TransformHandleTypeInterface>(getAt().getType())) { |
26 | auto payload = state.getPayloadOps(getAt()); |
27 | for (Operation *op : payload) |
28 | op->emitRemark() << getMessage(); |
29 | return DiagnosedSilenceableFailure::success(); |
30 | } |
31 | |
32 | assert(isa<transform::TransformValueHandleTypeInterface>(getAt().getType()) && |
33 | "unhandled kind of transform type" ); |
34 | |
35 | auto describeValue = [](Diagnostic &os, Value value) { |
36 | os << "value handle points to " ; |
37 | if (auto arg = llvm::dyn_cast<BlockArgument>(value)) { |
38 | os << "a block argument #" << arg.getArgNumber() << " in block #" |
39 | << std::distance(arg.getOwner()->getParent()->begin(), |
40 | arg.getOwner()->getIterator()) |
41 | << " in region #" << arg.getOwner()->getParent()->getRegionNumber(); |
42 | } else { |
43 | os << "an op result #" << llvm::cast<OpResult>(value).getResultNumber(); |
44 | } |
45 | }; |
46 | |
47 | for (Value value : state.getPayloadValues(getAt())) { |
48 | InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage(); |
49 | describeValue(diag.attachNote(), value); |
50 | } |
51 | |
52 | return DiagnosedSilenceableFailure::success(); |
53 | } |
54 | |
55 | DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply( |
56 | transform::TransformRewriter &rewriter, |
57 | transform::TransformResults &results, transform::TransformState &state) { |
58 | std::string str; |
59 | llvm::raw_string_ostream os(str); |
60 | if (getMessage()) |
61 | os << *getMessage() << " " ; |
62 | os << llvm::interleaved(state.getParams(getParam())); |
63 | if (!getAnchor()) { |
64 | emitRemark() << str; |
65 | return DiagnosedSilenceableFailure::success(); |
66 | } |
67 | for (Operation *payload : state.getPayloadOps(getAnchor())) |
68 | ::mlir::emitRemark(payload->getLoc()) << str; |
69 | return DiagnosedSilenceableFailure::success(); |
70 | } |
71 | |