1 | //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "TestTypes.h" |
12 | #include "mlir/Bytecode/BytecodeImplementation.h" |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
16 | #include "mlir/IR/AsmState.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/IR/BuiltinOps.h" |
19 | #include "mlir/IR/Diagnostics.h" |
20 | #include "mlir/IR/ExtensibleDialect.h" |
21 | #include "mlir/IR/MLIRContext.h" |
22 | #include "mlir/IR/ODSSupport.h" |
23 | #include "mlir/IR/OperationSupport.h" |
24 | #include "mlir/IR/PatternMatch.h" |
25 | #include "mlir/IR/TypeUtilities.h" |
26 | #include "mlir/IR/Verifier.h" |
27 | #include "mlir/Interfaces/CallInterfaces.h" |
28 | #include "mlir/Interfaces/FunctionImplementation.h" |
29 | #include "mlir/Interfaces/InferIntRangeInterface.h" |
30 | #include "mlir/Support/LLVM.h" |
31 | #include "mlir/Support/LogicalResult.h" |
32 | #include "mlir/Transforms/FoldUtils.h" |
33 | #include "mlir/Transforms/InliningUtils.h" |
34 | #include "llvm/ADT/STLFunctionalExtras.h" |
35 | #include "llvm/ADT/SmallString.h" |
36 | #include "llvm/ADT/StringExtras.h" |
37 | #include "llvm/ADT/StringSwitch.h" |
38 | #include "llvm/Support/Base64.h" |
39 | #include "llvm/Support/Casting.h" |
40 | |
41 | #include "mlir/Dialect/Arith/IR/Arith.h" |
42 | #include "mlir/Dialect/DLTI/DLTI.h" |
43 | #include "mlir/Interfaces/FoldInterfaces.h" |
44 | #include "mlir/Reducer/ReductionPatternInterface.h" |
45 | #include "mlir/Transforms/InliningUtils.h" |
46 | #include <cstdint> |
47 | #include <numeric> |
48 | #include <optional> |
49 | |
50 | // Include this before the using namespace lines below to test that we don't |
51 | // have namespace dependencies. |
52 | #include "TestOpsDialect.cpp.inc" |
53 | |
54 | using namespace mlir; |
55 | using namespace test; |
56 | |
57 | //===----------------------------------------------------------------------===// |
58 | // PropertiesWithCustomPrint |
59 | //===----------------------------------------------------------------------===// |
60 | |
61 | LogicalResult |
62 | test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, |
63 | Attribute attr, |
64 | function_ref<InFlightDiagnostic()> emitError) { |
65 | DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr); |
66 | if (!dict) { |
67 | emitError() << "expected DictionaryAttr to set TestProperties" ; |
68 | return failure(); |
69 | } |
70 | auto label = dict.getAs<mlir::StringAttr>("label" ); |
71 | if (!label) { |
72 | emitError() << "expected StringAttr for key `label`" ; |
73 | return failure(); |
74 | } |
75 | auto valueAttr = dict.getAs<IntegerAttr>("value" ); |
76 | if (!valueAttr) { |
77 | emitError() << "expected IntegerAttr for key `value`" ; |
78 | return failure(); |
79 | } |
80 | |
81 | prop.label = std::make_shared<std::string>(label.getValue()); |
82 | prop.value = valueAttr.getValue().getSExtValue(); |
83 | return success(); |
84 | } |
85 | |
86 | DictionaryAttr |
87 | test::getPropertiesAsAttribute(MLIRContext *ctx, |
88 | const PropertiesWithCustomPrint &prop) { |
89 | SmallVector<NamedAttribute> attrs; |
90 | Builder b{ctx}; |
91 | attrs.push_back(Elt: b.getNamedAttr("label" , b.getStringAttr(*prop.label))); |
92 | attrs.push_back(Elt: b.getNamedAttr("value" , b.getI32IntegerAttr(prop.value))); |
93 | return b.getDictionaryAttr(attrs); |
94 | } |
95 | |
96 | llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) { |
97 | return llvm::hash_combine(args: prop.value, args: StringRef(*prop.label)); |
98 | } |
99 | |
100 | void test::customPrintProperties(OpAsmPrinter &p, |
101 | const PropertiesWithCustomPrint &prop) { |
102 | p.printKeywordOrString(keyword: *prop.label); |
103 | p << " is " << prop.value; |
104 | } |
105 | |
106 | ParseResult test::customParseProperties(OpAsmParser &parser, |
107 | PropertiesWithCustomPrint &prop) { |
108 | std::string label; |
109 | if (parser.parseKeywordOrString(result: &label) || parser.parseKeyword(keyword: "is" ) || |
110 | parser.parseInteger(result&: prop.value)) |
111 | return failure(); |
112 | prop.label = std::make_shared<std::string>(args: std::move(label)); |
113 | return success(); |
114 | } |
115 | |
116 | //===----------------------------------------------------------------------===// |
117 | // MyPropStruct |
118 | //===----------------------------------------------------------------------===// |
119 | |
120 | Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const { |
121 | return StringAttr::get(ctx, content); |
122 | } |
123 | |
124 | LogicalResult |
125 | MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr, |
126 | function_ref<InFlightDiagnostic()> emitError) { |
127 | StringAttr strAttr = dyn_cast<StringAttr>(attr); |
128 | if (!strAttr) { |
129 | emitError() << "Expect StringAttr but got " << attr; |
130 | return failure(); |
131 | } |
132 | prop.content = strAttr.getValue(); |
133 | return success(); |
134 | } |
135 | |
136 | llvm::hash_code MyPropStruct::hash() const { |
137 | return hash_value(S: StringRef(content)); |
138 | } |
139 | |
140 | LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader, |
141 | MyPropStruct &prop) { |
142 | StringRef str; |
143 | if (failed(result: reader.readString(result&: str))) |
144 | return failure(); |
145 | prop.content = str.str(); |
146 | return success(); |
147 | } |
148 | |
149 | void test::writeToMlirBytecode(DialectBytecodeWriter &writer, |
150 | MyPropStruct &prop) { |
151 | writer.writeOwnedString(str: prop.content); |
152 | } |
153 | |
154 | //===----------------------------------------------------------------------===// |
155 | // VersionedProperties |
156 | //===----------------------------------------------------------------------===// |
157 | |
158 | LogicalResult |
159 | test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr, |
160 | function_ref<InFlightDiagnostic()> emitError) { |
161 | DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr); |
162 | if (!dict) { |
163 | emitError() << "expected DictionaryAttr to set VersionedProperties" ; |
164 | return failure(); |
165 | } |
166 | auto value1Attr = dict.getAs<IntegerAttr>("value1" ); |
167 | if (!value1Attr) { |
168 | emitError() << "expected IntegerAttr for key `value1`" ; |
169 | return failure(); |
170 | } |
171 | auto value2Attr = dict.getAs<IntegerAttr>("value2" ); |
172 | if (!value2Attr) { |
173 | emitError() << "expected IntegerAttr for key `value2`" ; |
174 | return failure(); |
175 | } |
176 | |
177 | prop.value1 = value1Attr.getValue().getSExtValue(); |
178 | prop.value2 = value2Attr.getValue().getSExtValue(); |
179 | return success(); |
180 | } |
181 | |
182 | DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx, |
183 | const VersionedProperties &prop) { |
184 | SmallVector<NamedAttribute> attrs; |
185 | Builder b{ctx}; |
186 | attrs.push_back(Elt: b.getNamedAttr("value1" , b.getI32IntegerAttr(prop.value1))); |
187 | attrs.push_back(Elt: b.getNamedAttr("value2" , b.getI32IntegerAttr(prop.value2))); |
188 | return b.getDictionaryAttr(attrs); |
189 | } |
190 | |
191 | llvm::hash_code test::computeHash(const VersionedProperties &prop) { |
192 | return llvm::hash_combine(args: prop.value1, args: prop.value2); |
193 | } |
194 | |
195 | void test::customPrintProperties(OpAsmPrinter &p, |
196 | const VersionedProperties &prop) { |
197 | p << prop.value1 << " | " << prop.value2; |
198 | } |
199 | |
200 | ParseResult test::customParseProperties(OpAsmParser &parser, |
201 | VersionedProperties &prop) { |
202 | if (parser.parseInteger(result&: prop.value1) || parser.parseVerticalBar() || |
203 | parser.parseInteger(result&: prop.value2)) |
204 | return failure(); |
205 | return success(); |
206 | } |
207 | |
208 | //===----------------------------------------------------------------------===// |
209 | // Bytecode Support |
210 | //===----------------------------------------------------------------------===// |
211 | |
212 | LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader, |
213 | MutableArrayRef<int64_t> prop) { |
214 | uint64_t size; |
215 | if (failed(result: reader.readVarInt(result&: size))) |
216 | return failure(); |
217 | if (size != prop.size()) |
218 | return reader.emitError(msg: "array size mismach when reading properties: " ) |
219 | << size << " vs expected " << prop.size(); |
220 | for (auto &elt : prop) { |
221 | uint64_t value; |
222 | if (failed(result: reader.readVarInt(result&: value))) |
223 | return failure(); |
224 | elt = value; |
225 | } |
226 | return success(); |
227 | } |
228 | |
229 | void test::writeToMlirBytecode(DialectBytecodeWriter &writer, |
230 | ArrayRef<int64_t> prop) { |
231 | writer.writeVarInt(value: prop.size()); |
232 | for (auto elt : prop) |
233 | writer.writeVarInt(value: elt); |
234 | } |
235 | |
236 | //===----------------------------------------------------------------------===// |
237 | // Dynamic operations |
238 | //===----------------------------------------------------------------------===// |
239 | |
240 | std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) { |
241 | return DynamicOpDefinition::get( |
242 | "dynamic_generic" , dialect, [](Operation *op) { return success(); }, |
243 | [](Operation *op) { return success(); }); |
244 | } |
245 | |
246 | std::unique_ptr<DynamicOpDefinition> |
247 | getDynamicOneOperandTwoResultsOp(TestDialect *dialect) { |
248 | return DynamicOpDefinition::get( |
249 | "dynamic_one_operand_two_results" , dialect, |
250 | [](Operation *op) { |
251 | if (op->getNumOperands() != 1) { |
252 | op->emitOpError() |
253 | << "expected 1 operand, but had " << op->getNumOperands(); |
254 | return failure(); |
255 | } |
256 | if (op->getNumResults() != 2) { |
257 | op->emitOpError() |
258 | << "expected 2 results, but had " << op->getNumResults(); |
259 | return failure(); |
260 | } |
261 | return success(); |
262 | }, |
263 | [](Operation *op) { return success(); }); |
264 | } |
265 | |
266 | std::unique_ptr<DynamicOpDefinition> |
267 | getDynamicCustomParserPrinterOp(TestDialect *dialect) { |
268 | auto verifier = [](Operation *op) { |
269 | if (op->getNumOperands() == 0 && op->getNumResults() == 0) |
270 | return success(); |
271 | op->emitError() << "operation should have no operands and no results" ; |
272 | return failure(); |
273 | }; |
274 | auto regionVerifier = [](Operation *op) { return success(); }; |
275 | |
276 | auto parser = [](OpAsmParser &parser, OperationState &state) { |
277 | return parser.parseKeyword(keyword: "custom_keyword" ); |
278 | }; |
279 | |
280 | auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { |
281 | printer << op->getName() << " custom_keyword" ; |
282 | }; |
283 | |
284 | return DynamicOpDefinition::get("dynamic_custom_parser_printer" , dialect, |
285 | verifier, regionVerifier, parser, printer); |
286 | } |
287 | |
288 | //===----------------------------------------------------------------------===// |
289 | // TestDialect |
290 | //===----------------------------------------------------------------------===// |
291 | |
292 | void test::registerTestDialect(DialectRegistry ®istry) { |
293 | registry.insert<TestDialect>(); |
294 | } |
295 | |
296 | void test::testSideEffectOpGetEffect( |
297 | Operation *op, |
298 | SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> |
299 | &effects) { |
300 | auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter" ); |
301 | if (!effectsAttr) |
302 | return; |
303 | |
304 | effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); |
305 | } |
306 | |
307 | // This is the implementation of a dialect fallback for `TestEffectOpInterface`. |
308 | struct TestOpEffectInterfaceFallback |
309 | : public TestEffectOpInterface::FallbackModel< |
310 | TestOpEffectInterfaceFallback> { |
311 | static bool classof(Operation *op) { |
312 | bool isSupportedOp = |
313 | op->getName().getStringRef() == "test.unregistered_side_effect_op" ; |
314 | assert(isSupportedOp && "Unexpected dispatch" ); |
315 | return isSupportedOp; |
316 | } |
317 | |
318 | void |
319 | getEffects(Operation *op, |
320 | SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> |
321 | &effects) const { |
322 | testSideEffectOpGetEffect(op, effects); |
323 | } |
324 | }; |
325 | |
/// Populates the test dialect. In order, it registers:
///  - its attributes and types;
///  - the ODS-generated operations listed in TestOps.cpp.inc (GET_OP_LIST);
///  - the op syntax (registerOpsSyntax) and the hand-written
///    ManualCppOpWithFold op;
///  - the three dynamic ops defined above;
///  - the dialect interfaces.
/// It also calls allowUnknownOperations(), so unregistered ops of this
/// dialect (e.g. `test.unregistered_side_effect_op`) are accepted.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  registerOpsSyntax();
  addOperations<ManualCppOpWithFold>();
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));
  registerInterfaces();
  allowUnknownOperations();

  // Create the fallback TestEffectOpInterface implementation that
  // getRegisteredInterfaceForOp hands out for
  // `test.unregistered_side_effect_op`. The dialect owns this allocation;
  // ~TestDialect deletes it.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
345 | |
346 | TestDialect::~TestDialect() { |
347 | delete static_cast<TestOpEffectInterfaceFallback *>( |
348 | fallbackEffectOpInterfaces); |
349 | } |
350 | |
351 | Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, |
352 | Type type, Location loc) { |
353 | return builder.create<TestOpConstant>(loc, type, value); |
354 | } |
355 | |
356 | void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, |
357 | OperationName opName) { |
358 | if (opName.getIdentifier() == "test.unregistered_side_effect_op" && |
359 | typeID == TypeID::get<TestEffectOpInterface>()) |
360 | return fallbackEffectOpInterfaces; |
361 | return nullptr; |
362 | } |
363 | |
364 | LogicalResult TestDialect::verifyOperationAttribute(Operation *op, |
365 | NamedAttribute namedAttr) { |
366 | if (namedAttr.getName() == "test.invalid_attr" ) |
367 | return op->emitError() << "invalid to use 'test.invalid_attr'" ; |
368 | return success(); |
369 | } |
370 | |
371 | LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, |
372 | unsigned regionIndex, |
373 | unsigned argIndex, |
374 | NamedAttribute namedAttr) { |
375 | if (namedAttr.getName() == "test.invalid_attr" ) |
376 | return op->emitError() << "invalid to use 'test.invalid_attr'" ; |
377 | return success(); |
378 | } |
379 | |
380 | LogicalResult |
381 | TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, |
382 | unsigned resultIndex, |
383 | NamedAttribute namedAttr) { |
384 | if (namedAttr.getName() == "test.invalid_attr" ) |
385 | return op->emitError() << "invalid to use 'test.invalid_attr'" ; |
386 | return success(); |
387 | } |
388 | |
389 | std::optional<Dialect::ParseOpHook> |
390 | TestDialect::getParseOperationHook(StringRef opName) const { |
391 | if (opName == "test.dialect_custom_printer" ) { |
392 | return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { |
393 | return parser.parseKeyword("custom_format" ); |
394 | }}; |
395 | } |
396 | if (opName == "test.dialect_custom_format_fallback" ) { |
397 | return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { |
398 | return parser.parseKeyword("custom_format_fallback" ); |
399 | }}; |
400 | } |
401 | if (opName == "test.dialect_custom_printer.with.dot" ) { |
402 | return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { |
403 | return ParseResult::success(); |
404 | }}; |
405 | } |
406 | return std::nullopt; |
407 | } |
408 | |
409 | llvm::unique_function<void(Operation *, OpAsmPrinter &)> |
410 | TestDialect::getOperationPrinter(Operation *op) const { |
411 | StringRef opName = op->getName().getStringRef(); |
412 | if (opName == "test.dialect_custom_printer" ) { |
413 | return [](Operation *op, OpAsmPrinter &printer) { |
414 | printer.getStream() << " custom_format" ; |
415 | }; |
416 | } |
417 | if (opName == "test.dialect_custom_format_fallback" ) { |
418 | return [](Operation *op, OpAsmPrinter &printer) { |
419 | printer.getStream() << " custom_format_fallback" ; |
420 | }; |
421 | } |
422 | return {}; |
423 | } |
424 | |
425 | static LogicalResult |
426 | dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, |
427 | PatternRewriter &rewriter) { |
428 | rewriter.replaceOpWithNewOp<arith::ConstantOp>( |
429 | op, rewriter.getI32IntegerAttr(42)); |
430 | return success(); |
431 | } |
432 | |
433 | void TestDialect::getCanonicalizationPatterns( |
434 | RewritePatternSet &results) const { |
435 | results.add(&dialectCanonicalizationPattern); |
436 | } |
437 | |