1//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10#include "mlir/Analysis/CallGraph.h"
11#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
12#include "mlir/Dialect/Transform/IR/TransformOps.h"
13#include "mlir/Dialect/Transform/IR/TransformTypes.h"
14#include "mlir/Dialect/Transform/IR/Utils.h"
15#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
16#include "mlir/IR/DialectImplementation.h"
17#include "llvm/ADT/SCCIterator.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21
22#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
23
24#define GET_ATTRDEF_CLASSES
25#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
26
27#ifndef NDEBUG
28void transform::detail::checkImplementsTransformOpInterface(
29 StringRef name, MLIRContext *context) {
30 // Since the operation is being inserted into the Transform dialect and the
31 // dialect does not implement the interface fallback, only check for the op
32 // itself having the interface implementation.
33 RegisteredOperationName opName =
34 *RegisteredOperationName::lookup(name, ctx: context);
35 assert((opName.hasInterface<TransformOpInterface>() ||
36 opName.hasInterface<PatternDescriptorOpInterface>() ||
37 opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
38 opName.hasInterface<TypeConverterBuilderOpInterface>() ||
39 opName.hasTrait<OpTrait::IsTerminator>()) &&
40 "non-terminator ops injected into the transform dialect must "
41 "implement TransformOpInterface or PatternDescriptorOpInterface or "
42 "ConversionPatternDescriptorOpInterface");
43 if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
44 !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
45 !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
46 assert(opName.hasInterface<MemoryEffectOpInterface>() &&
47 "ops injected into the transform dialect must implement "
48 "MemoryEffectsOpInterface");
49 }
50}
51
52void transform::detail::checkImplementsTransformHandleTypeInterface(
53 TypeID typeID, MLIRContext *context) {
54 const auto &abstractType = AbstractType::lookup(typeID, context);
55 assert((abstractType.hasInterface(
56 TransformHandleTypeInterface::getInterfaceID()) ||
57 abstractType.hasInterface(
58 TransformParamTypeInterface::getInterfaceID()) ||
59 abstractType.hasInterface(
60 TransformValueHandleTypeInterface::getInterfaceID())) &&
61 "expected Transform dialect type to implement one of the three "
62 "interfaces");
63}
64#endif // NDEBUG
65
/// Dialect initialization hook.
///
/// Registers the dialect's own operations (through the checked path), its
/// types and its attributes. It then creates the library module owned by the
/// dialect, which starts out empty.
void transform::TransformDialect::initialize() {
  // Using the checked versions to enable the same assertions as for the ops
  // from extensions.
  addOperationsChecked<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
      >();
  // Type registration lives in a separate member function that is not defined
  // in this chunk (presumably in the types implementation file -- verify).
  initializeTypes();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
      >();
  // Creates the module that loadIntoLibraryModule() later merges symbols into.
  initializeLibraryModule();
}
80
81Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
82 StringRef keyword;
83 SMLoc loc = parser.getCurrentLocation();
84 if (failed(parser.parseKeyword(&keyword)))
85 return nullptr;
86
87 auto it = typeParsingHooks.find(keyword);
88 if (it == typeParsingHooks.end()) {
89 parser.emitError(loc) << "unknown type mnemonic: " << keyword;
90 return nullptr;
91 }
92
93 return it->getValue()(parser);
94}
95
96void transform::TransformDialect::printType(Type type,
97 DialectAsmPrinter &printer) const {
98 auto it = typePrintingHooks.find(type.getTypeID());
99 assert(it != typePrintingHooks.end() && "printing unknown type");
100 it->getSecond()(type, printer);
101}
102
103LogicalResult transform::TransformDialect::loadIntoLibraryModule(
104 ::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
105 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
106}
107
108void transform::TransformDialect::initializeLibraryModule() {
109 MLIRContext *context = getContext();
110 auto loc =
111 FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
112 libraryModule = ModuleOp::create(loc, "__transform_library");
113 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
114 UnitAttr::get(context));
115}
116
117void transform::TransformDialect::reportDuplicateTypeRegistration(
118 StringRef mnemonic) {
119 std::string buffer;
120 llvm::raw_string_ostream msg(buffer);
121 msg << "extensible dialect type '" << mnemonic
122 << "' is already registered with a different implementation";
123 llvm::report_fatal_error(StringRef(buffer));
124}
125
126void transform::TransformDialect::reportDuplicateOpRegistration(
127 StringRef opName) {
128 std::string buffer;
129 llvm::raw_string_ostream msg(buffer);
130 msg << "extensible dialect operation '" << opName
131 << "' is already registered with a mismatching TypeID";
132 llvm::report_fatal_error(StringRef(buffer));
133}
134
135LogicalResult transform::TransformDialect::verifyOperationAttribute(
136 Operation *op, NamedAttribute attribute) {
137 if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
138 if (!op->hasTrait<OpTrait::SymbolTable>()) {
139 return emitError(op->getLoc()) << attribute.getName()
140 << " attribute can only be attached to "
141 "operations with symbol tables";
142 }
143
144 const mlir::CallGraph callgraph(op);
145 for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
146 if (!scc.hasCycle())
147 continue;
148
149 // Need to check this here additionally because this verification may run
150 // before we check the nested operations.
151 if ((*scc->begin())->isExternal())
152 return op->emitOpError() << "contains a call to an external operation, "
153 "which is not allowed";
154
155 Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
156 InFlightDiagnostic diag = emitError(first->getLoc())
157 << "recursion not allowed in named sequences";
158 for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
159 // Need to check this here additionally because this verification may
160 // run before we check the nested operations.
161 if ((*it)->isExternal()) {
162 return op->emitOpError() << "contains a call to an external "
163 "operation, which is not allowed";
164 }
165
166 Operation *current = (*it)->getCallableRegion()->getParentOp();
167 diag.attachNote(current->getLoc()) << "operation on recursion stack";
168 }
169 return diag;
170 }
171 return success();
172 }
173 if (attribute.getName().getValue() == kTargetTagAttrName) {
174 if (!llvm::isa<StringAttr>(attribute.getValue())) {
175 return op->emitError()
176 << attribute.getName() << " attribute must be a string";
177 }
178 return success();
179 }
180 if (attribute.getName().getValue() == kArgConsumedAttrName ||
181 attribute.getName().getValue() == kArgReadOnlyAttrName) {
182 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
183 return op->emitError()
184 << attribute.getName() << " must be a unit attribute";
185 }
186 return success();
187 }
188 if (attribute.getName().getValue() ==
189 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
190 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
191 return op->emitError()
192 << attribute.getName() << " must be a unit attribute";
193 }
194 return success();
195 }
196 return emitError(op->getLoc())
197 << "unknown attribute: " << attribute.getName();
198}
199

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Transform/IR/TransformDialect.cpp