1//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Analysis/CallGraph.h"
11#include "mlir/Dialect/Transform/IR/TransformOps.h"
12#include "mlir/Dialect/Transform/IR/TransformTypes.h"
13#include "mlir/Dialect/Transform/IR/Utils.h"
14#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15#include "mlir/IR/DialectImplementation.h"
16#include "llvm/ADT/SCCIterator.h"
17#include "llvm/ADT/TypeSwitch.h"
18
19using namespace mlir;
20
21#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
22
23#define GET_ATTRDEF_CLASSES
24#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
25
26#ifndef NDEBUG
27void transform::detail::checkImplementsTransformOpInterface(
28 StringRef name, MLIRContext *context) {
29 // Since the operation is being inserted into the Transform dialect and the
30 // dialect does not implement the interface fallback, only check for the op
31 // itself having the interface implementation.
32 RegisteredOperationName opName =
33 *RegisteredOperationName::lookup(name, context);
34 assert((opName.hasInterface<TransformOpInterface>() ||
35 opName.hasInterface<PatternDescriptorOpInterface>() ||
36 opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
37 opName.hasInterface<TypeConverterBuilderOpInterface>() ||
38 opName.hasTrait<OpTrait::IsTerminator>()) &&
39 "non-terminator ops injected into the transform dialect must "
40 "implement TransformOpInterface or PatternDescriptorOpInterface or "
41 "ConversionPatternDescriptorOpInterface");
42 if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
43 !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
44 !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
45 assert(opName.hasInterface<MemoryEffectOpInterface>() &&
46 "ops injected into the transform dialect must implement "
47 "MemoryEffectsOpInterface");
48 }
49}
50
51void transform::detail::checkImplementsTransformHandleTypeInterface(
52 TypeID typeID, MLIRContext *context) {
53 const auto &abstractType = AbstractType::lookup(typeID, context);
54 assert((abstractType.hasInterface(
55 TransformHandleTypeInterface::getInterfaceID()) ||
56 abstractType.hasInterface(
57 TransformParamTypeInterface::getInterfaceID()) ||
58 abstractType.hasInterface(
59 TransformValueHandleTypeInterface::getInterfaceID())) &&
60 "expected Transform dialect type to implement one of the three "
61 "interfaces");
62}
63#endif // NDEBUG
64
65void transform::TransformDialect::initialize() {
66 // Using the checked versions to enable the same assertions as for the ops
67 // from extensions.
68 addOperationsChecked<
69#define GET_OP_LIST
70#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
71 >();
72 initializeTypes();
73 addAttributes<
74#define GET_ATTRDEF_LIST
75#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
76 >();
77 initializeLibraryModule();
78}
79
80Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
81 StringRef keyword;
82 SMLoc loc = parser.getCurrentLocation();
83 if (failed(Result: parser.parseKeyword(keyword: &keyword)))
84 return nullptr;
85
86 auto it = typeParsingHooks.find(Key: keyword);
87 if (it == typeParsingHooks.end()) {
88 parser.emitError(loc) << "unknown type mnemonic: " << keyword;
89 return nullptr;
90 }
91
92 return it->getValue()(parser);
93}
94
95void transform::TransformDialect::printType(Type type,
96 DialectAsmPrinter &printer) const {
97 auto it = typePrintingHooks.find(Val: type.getTypeID());
98 assert(it != typePrintingHooks.end() && "printing unknown type");
99 it->getSecond()(type, printer);
100}
101
102LogicalResult transform::TransformDialect::loadIntoLibraryModule(
103 ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
104 return detail::mergeSymbolsInto(target: getLibraryModule(), other: std::move(library));
105}
106
107void transform::TransformDialect::initializeLibraryModule() {
108 MLIRContext *context = getContext();
109 auto loc =
110 FileLineColLoc::get(context, fileName: "<transform-dialect-library-module>", line: 0, column: 0);
111 libraryModule = ModuleOp::create(loc, name: "__transform_library");
112 libraryModule.get()->setAttr(name: TransformDialect::kWithNamedSequenceAttrName,
113 value: UnitAttr::get(context));
114}
115
116void transform::TransformDialect::reportDuplicateTypeRegistration(
117 StringRef mnemonic) {
118 std::string buffer;
119 llvm::raw_string_ostream msg(buffer);
120 msg << "extensible dialect type '" << mnemonic
121 << "' is already registered with a different implementation";
122 llvm::report_fatal_error(reason: StringRef(buffer));
123}
124
125void transform::TransformDialect::reportDuplicateOpRegistration(
126 StringRef opName) {
127 std::string buffer;
128 llvm::raw_string_ostream msg(buffer);
129 msg << "extensible dialect operation '" << opName
130 << "' is already registered with a mismatching TypeID";
131 llvm::report_fatal_error(reason: StringRef(buffer));
132}
133
134LogicalResult transform::TransformDialect::verifyOperationAttribute(
135 Operation *op, NamedAttribute attribute) {
136 if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
137 if (!op->hasTrait<OpTrait::SymbolTable>()) {
138 return emitError(loc: op->getLoc()) << attribute.getName()
139 << " attribute can only be attached to "
140 "operations with symbol tables";
141 }
142
143 const mlir::CallGraph callgraph(op);
144 for (auto scc = llvm::scc_begin(G: &callgraph); !scc.isAtEnd(); ++scc) {
145 if (!scc.hasCycle())
146 continue;
147
148 // Need to check this here additionally because this verification may run
149 // before we check the nested operations.
150 if ((*scc->begin())->isExternal())
151 return op->emitOpError() << "contains a call to an external operation, "
152 "which is not allowed";
153
154 Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
155 InFlightDiagnostic diag = emitError(loc: first->getLoc())
156 << "recursion not allowed in named sequences";
157 for (auto it = std::next(x: scc->begin()); it != scc->end(); ++it) {
158 // Need to check this here additionally because this verification may
159 // run before we check the nested operations.
160 if ((*it)->isExternal()) {
161 return op->emitOpError() << "contains a call to an external "
162 "operation, which is not allowed";
163 }
164
165 Operation *current = (*it)->getCallableRegion()->getParentOp();
166 diag.attachNote(noteLoc: current->getLoc()) << "operation on recursion stack";
167 }
168 return diag;
169 }
170 return success();
171 }
172 if (attribute.getName().getValue() == kTargetTagAttrName) {
173 if (!llvm::isa<StringAttr>(Val: attribute.getValue())) {
174 return op->emitError()
175 << attribute.getName() << " attribute must be a string";
176 }
177 return success();
178 }
179 if (attribute.getName().getValue() == kArgConsumedAttrName ||
180 attribute.getName().getValue() == kArgReadOnlyAttrName) {
181 if (!llvm::isa<UnitAttr>(Val: attribute.getValue())) {
182 return op->emitError()
183 << attribute.getName() << " must be a unit attribute";
184 }
185 return success();
186 }
187 if (attribute.getName().getValue() ==
188 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
189 if (!llvm::isa<UnitAttr>(Val: attribute.getValue())) {
190 return op->emitError()
191 << attribute.getName() << " must be a unit attribute";
192 }
193 return success();
194 }
195 return emitError(loc: op->getLoc())
196 << "unknown attribute: " << attribute.getName();
197}
198

source code of mlir/lib/Dialect/Transform/IR/TransformDialect.cpp