1//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Analysis/CallGraph.h"
11#include "mlir/Dialect/Transform/IR/TransformOps.h"
12#include "mlir/Dialect/Transform/IR/TransformTypes.h"
13#include "mlir/Dialect/Transform/IR/Utils.h"
14#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15#include "mlir/IR/DialectImplementation.h"
16#include "llvm/ADT/SCCIterator.h"
17
18using namespace mlir;
19
20#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
21
22#ifndef NDEBUG
23void transform::detail::checkImplementsTransformOpInterface(
24 StringRef name, MLIRContext *context) {
25 // Since the operation is being inserted into the Transform dialect and the
26 // dialect does not implement the interface fallback, only check for the op
27 // itself having the interface implementation.
28 RegisteredOperationName opName =
29 *RegisteredOperationName::lookup(name, ctx: context);
30 assert((opName.hasInterface<TransformOpInterface>() ||
31 opName.hasInterface<PatternDescriptorOpInterface>() ||
32 opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
33 opName.hasInterface<TypeConverterBuilderOpInterface>() ||
34 opName.hasTrait<OpTrait::IsTerminator>()) &&
35 "non-terminator ops injected into the transform dialect must "
36 "implement TransformOpInterface or PatternDescriptorOpInterface or "
37 "ConversionPatternDescriptorOpInterface");
38 if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
39 !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
40 !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
41 assert(opName.hasInterface<MemoryEffectOpInterface>() &&
42 "ops injected into the transform dialect must implement "
43 "MemoryEffectsOpInterface");
44 }
45}
46
47void transform::detail::checkImplementsTransformHandleTypeInterface(
48 TypeID typeID, MLIRContext *context) {
49 const auto &abstractType = AbstractType::lookup(typeID, context);
50 assert((abstractType.hasInterface(
51 TransformHandleTypeInterface::getInterfaceID()) ||
52 abstractType.hasInterface(
53 TransformParamTypeInterface::getInterfaceID()) ||
54 abstractType.hasInterface(
55 TransformValueHandleTypeInterface::getInterfaceID())) &&
56 "expected Transform dialect type to implement one of the three "
57 "interfaces");
58}
59#endif // NDEBUG
60
61void transform::TransformDialect::initialize() {
62 // Using the checked versions to enable the same assertions as for the ops
63 // from extensions.
64 addOperationsChecked<
65#define GET_OP_LIST
66#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
67 >();
68 initializeTypes();
69 initializeLibraryModule();
70}
71
72Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
73 StringRef keyword;
74 SMLoc loc = parser.getCurrentLocation();
75 if (failed(parser.parseKeyword(&keyword)))
76 return nullptr;
77
78 auto it = typeParsingHooks.find(keyword);
79 if (it == typeParsingHooks.end()) {
80 parser.emitError(loc) << "unknown type mnemonic: " << keyword;
81 return nullptr;
82 }
83
84 return it->getValue()(parser);
85}
86
87void transform::TransformDialect::printType(Type type,
88 DialectAsmPrinter &printer) const {
89 auto it = typePrintingHooks.find(type.getTypeID());
90 assert(it != typePrintingHooks.end() && "printing unknown type");
91 it->getSecond()(type, printer);
92}
93
94LogicalResult transform::TransformDialect::loadIntoLibraryModule(
95 ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
96 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
97}
98
99void transform::TransformDialect::initializeLibraryModule() {
100 MLIRContext *context = getContext();
101 auto loc =
102 FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
103 libraryModule = ModuleOp::create(loc, "__transform_library");
104 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
105 UnitAttr::get(context));
106}
107
108void transform::TransformDialect::reportDuplicateTypeRegistration(
109 StringRef mnemonic) {
110 std::string buffer;
111 llvm::raw_string_ostream msg(buffer);
112 msg << "extensible dialect type '" << mnemonic
113 << "' is already registered with a different implementation";
114 msg.flush();
115 llvm::report_fatal_error(StringRef(buffer));
116}
117
118void transform::TransformDialect::reportDuplicateOpRegistration(
119 StringRef opName) {
120 std::string buffer;
121 llvm::raw_string_ostream msg(buffer);
122 msg << "extensible dialect operation '" << opName
123 << "' is already registered with a mismatching TypeID";
124 msg.flush();
125 llvm::report_fatal_error(StringRef(buffer));
126}
127
128LogicalResult transform::TransformDialect::verifyOperationAttribute(
129 Operation *op, NamedAttribute attribute) {
130 if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
131 if (!op->hasTrait<OpTrait::SymbolTable>()) {
132 return emitError(op->getLoc()) << attribute.getName()
133 << " attribute can only be attached to "
134 "operations with symbol tables";
135 }
136
137 const mlir::CallGraph callgraph(op);
138 for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
139 if (!scc.hasCycle())
140 continue;
141
142 // Need to check this here additionally because this verification may run
143 // before we check the nested operations.
144 if ((*scc->begin())->isExternal())
145 return op->emitOpError() << "contains a call to an external operation, "
146 "which is not allowed";
147
148 Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
149 InFlightDiagnostic diag = emitError(first->getLoc())
150 << "recursion not allowed in named sequences";
151 for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
152 // Need to check this here additionally because this verification may
153 // run before we check the nested operations.
154 if ((*it)->isExternal()) {
155 return op->emitOpError() << "contains a call to an external "
156 "operation, which is not allowed";
157 }
158
159 Operation *current = (*it)->getCallableRegion()->getParentOp();
160 diag.attachNote(current->getLoc()) << "operation on recursion stack";
161 }
162 return diag;
163 }
164 return success();
165 }
166 if (attribute.getName().getValue() == kTargetTagAttrName) {
167 if (!llvm::isa<StringAttr>(attribute.getValue())) {
168 return op->emitError()
169 << attribute.getName() << " attribute must be a string";
170 }
171 return success();
172 }
173 if (attribute.getName().getValue() == kArgConsumedAttrName ||
174 attribute.getName().getValue() == kArgReadOnlyAttrName) {
175 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
176 return op->emitError()
177 << attribute.getName() << " must be a unit attribute";
178 }
179 return success();
180 }
181 if (attribute.getName().getValue() ==
182 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
183 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
184 return op->emitError()
185 << attribute.getName() << " must be a unit attribute";
186 }
187 return success();
188 }
189 return emitError(op->getLoc())
190 << "unknown attribute: " << attribute.getName();
191}
192

source code of mlir/lib/Dialect/Transform/IR/TransformDialect.cpp