1 | |
2 | //===- DLTITransformOps.cpp - Implementation of DLTI transform ops --------===// |
3 | // |
4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" |
11 | |
12 | #include "mlir/Dialect/DLTI/DLTI.h" |
13 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
14 | #include "mlir/Dialect/Transform/Utils/Utils.h" |
15 | #include "mlir/Interfaces/DataLayoutInterfaces.h" |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::transform; |
19 | |
20 | #define DEBUG_TYPE "dlti-transforms" |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // QueryOp |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | void transform::QueryOp::getEffects( |
27 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
28 | onlyReadsHandle(getTargetMutable(), effects); |
29 | producesHandle(getOperation()->getOpResults(), effects); |
30 | onlyReadsPayload(effects); |
31 | } |
32 | |
33 | DiagnosedSilenceableFailure transform::QueryOp::applyToOne( |
34 | transform::TransformRewriter &rewriter, Operation *target, |
35 | transform::ApplyToEachResultList &results, TransformState &state) { |
36 | SmallVector<DataLayoutEntryKey> keys; |
37 | for (Attribute key : getKeys()) { |
38 | if (auto strKey = dyn_cast<StringAttr>(key)) |
39 | keys.push_back(strKey); |
40 | else if (auto typeKey = dyn_cast<TypeAttr>(key)) |
41 | keys.push_back(typeKey.getValue()); |
42 | else |
43 | return emitDefiniteFailure("'transform.dlti.query' keys of wrong type: " |
44 | "only StringAttr and TypeAttr are allowed" ); |
45 | } |
46 | |
47 | FailureOr<Attribute> result = dlti::query(target, keys, /*emitError=*/true); |
48 | |
49 | if (failed(result)) |
50 | return emitSilenceableFailure(getLoc(), |
51 | "'transform.dlti.query' op failed to apply" ); |
52 | |
53 | results.push_back(*result); |
54 | return DiagnosedSilenceableFailure::success(); |
55 | } |
56 | |
57 | //===----------------------------------------------------------------------===// |
58 | // Transform op registration |
59 | //===----------------------------------------------------------------------===// |
60 | |
61 | namespace { |
62 | class DLTITransformDialectExtension |
63 | : public transform::TransformDialectExtension< |
64 | DLTITransformDialectExtension> { |
65 | public: |
66 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTITransformDialectExtension) |
67 | |
68 | using Base::Base; |
69 | |
70 | void init() { |
71 | registerTransformOps< |
72 | #define GET_OP_LIST |
73 | #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc" |
74 | >(); |
75 | } |
76 | }; |
77 | } // namespace |
78 | |
79 | #define GET_OP_CLASSES |
80 | #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.cpp.inc" |
81 | |
82 | void mlir::dlti::registerTransformDialectExtension(DialectRegistry ®istry) { |
83 | registry.addExtensions<DLTITransformDialectExtension>(); |
84 | } |
85 | |