1//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "TestTypes.h"
12#include "mlir/Bytecode/BytecodeImplementation.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/AsmState.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/ExtensibleDialect.h"
21#include "mlir/IR/MLIRContext.h"
22#include "mlir/IR/ODSSupport.h"
23#include "mlir/IR/OperationSupport.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/TypeUtilities.h"
26#include "mlir/IR/Verifier.h"
27#include "mlir/Interfaces/CallInterfaces.h"
28#include "mlir/Interfaces/FunctionImplementation.h"
29#include "mlir/Interfaces/InferIntRangeInterface.h"
30#include "mlir/Support/LLVM.h"
31#include "mlir/Support/LogicalResult.h"
32#include "mlir/Transforms/FoldUtils.h"
33#include "mlir/Transforms/InliningUtils.h"
34#include "llvm/ADT/STLFunctionalExtras.h"
35#include "llvm/ADT/SmallString.h"
36#include "llvm/ADT/StringExtras.h"
37#include "llvm/ADT/StringSwitch.h"
38#include "llvm/Support/Base64.h"
39#include "llvm/Support/Casting.h"
40
41#include "mlir/Dialect/Arith/IR/Arith.h"
42#include "mlir/Dialect/DLTI/DLTI.h"
43#include "mlir/Interfaces/FoldInterfaces.h"
44#include "mlir/Reducer/ReductionPatternInterface.h"
45#include "mlir/Transforms/InliningUtils.h"
46#include <cstdint>
47#include <numeric>
48#include <optional>
49
50// Include this before the using namespace lines below to test that we don't
51// have namespace dependencies.
52#include "TestOpsDialect.cpp.inc"
53
54using namespace mlir;
55using namespace test;
56
57//===----------------------------------------------------------------------===//
58// PropertiesWithCustomPrint
59//===----------------------------------------------------------------------===//
60
61LogicalResult
62test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
63 Attribute attr,
64 function_ref<InFlightDiagnostic()> emitError) {
65 DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
66 if (!dict) {
67 emitError() << "expected DictionaryAttr to set TestProperties";
68 return failure();
69 }
70 auto label = dict.getAs<mlir::StringAttr>("label");
71 if (!label) {
72 emitError() << "expected StringAttr for key `label`";
73 return failure();
74 }
75 auto valueAttr = dict.getAs<IntegerAttr>("value");
76 if (!valueAttr) {
77 emitError() << "expected IntegerAttr for key `value`";
78 return failure();
79 }
80
81 prop.label = std::make_shared<std::string>(label.getValue());
82 prop.value = valueAttr.getValue().getSExtValue();
83 return success();
84}
85
86DictionaryAttr
87test::getPropertiesAsAttribute(MLIRContext *ctx,
88 const PropertiesWithCustomPrint &prop) {
89 SmallVector<NamedAttribute> attrs;
90 Builder b{ctx};
91 attrs.push_back(Elt: b.getNamedAttr("label", b.getStringAttr(*prop.label)));
92 attrs.push_back(Elt: b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
93 return b.getDictionaryAttr(attrs);
94}
95
96llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
97 return llvm::hash_combine(args: prop.value, args: StringRef(*prop.label));
98}
99
100void test::customPrintProperties(OpAsmPrinter &p,
101 const PropertiesWithCustomPrint &prop) {
102 p.printKeywordOrString(keyword: *prop.label);
103 p << " is " << prop.value;
104}
105
106ParseResult test::customParseProperties(OpAsmParser &parser,
107 PropertiesWithCustomPrint &prop) {
108 std::string label;
109 if (parser.parseKeywordOrString(result: &label) || parser.parseKeyword(keyword: "is") ||
110 parser.parseInteger(result&: prop.value))
111 return failure();
112 prop.label = std::make_shared<std::string>(args: std::move(label));
113 return success();
114}
115
116//===----------------------------------------------------------------------===//
117// MyPropStruct
118//===----------------------------------------------------------------------===//
119
120Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
121 return StringAttr::get(ctx, content);
122}
123
124LogicalResult
125MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
126 function_ref<InFlightDiagnostic()> emitError) {
127 StringAttr strAttr = dyn_cast<StringAttr>(attr);
128 if (!strAttr) {
129 emitError() << "Expect StringAttr but got " << attr;
130 return failure();
131 }
132 prop.content = strAttr.getValue();
133 return success();
134}
135
136llvm::hash_code MyPropStruct::hash() const {
137 return hash_value(S: StringRef(content));
138}
139
140LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
141 MyPropStruct &prop) {
142 StringRef str;
143 if (failed(result: reader.readString(result&: str)))
144 return failure();
145 prop.content = str.str();
146 return success();
147}
148
149void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
150 MyPropStruct &prop) {
151 writer.writeOwnedString(str: prop.content);
152}
153
154//===----------------------------------------------------------------------===//
155// VersionedProperties
156//===----------------------------------------------------------------------===//
157
158LogicalResult
159test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
160 function_ref<InFlightDiagnostic()> emitError) {
161 DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
162 if (!dict) {
163 emitError() << "expected DictionaryAttr to set VersionedProperties";
164 return failure();
165 }
166 auto value1Attr = dict.getAs<IntegerAttr>("value1");
167 if (!value1Attr) {
168 emitError() << "expected IntegerAttr for key `value1`";
169 return failure();
170 }
171 auto value2Attr = dict.getAs<IntegerAttr>("value2");
172 if (!value2Attr) {
173 emitError() << "expected IntegerAttr for key `value2`";
174 return failure();
175 }
176
177 prop.value1 = value1Attr.getValue().getSExtValue();
178 prop.value2 = value2Attr.getValue().getSExtValue();
179 return success();
180}
181
182DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
183 const VersionedProperties &prop) {
184 SmallVector<NamedAttribute> attrs;
185 Builder b{ctx};
186 attrs.push_back(Elt: b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
187 attrs.push_back(Elt: b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
188 return b.getDictionaryAttr(attrs);
189}
190
191llvm::hash_code test::computeHash(const VersionedProperties &prop) {
192 return llvm::hash_combine(args: prop.value1, args: prop.value2);
193}
194
195void test::customPrintProperties(OpAsmPrinter &p,
196 const VersionedProperties &prop) {
197 p << prop.value1 << " | " << prop.value2;
198}
199
200ParseResult test::customParseProperties(OpAsmParser &parser,
201 VersionedProperties &prop) {
202 if (parser.parseInteger(result&: prop.value1) || parser.parseVerticalBar() ||
203 parser.parseInteger(result&: prop.value2))
204 return failure();
205 return success();
206}
207
208//===----------------------------------------------------------------------===//
209// Bytecode Support
210//===----------------------------------------------------------------------===//
211
212LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
213 MutableArrayRef<int64_t> prop) {
214 uint64_t size;
215 if (failed(result: reader.readVarInt(result&: size)))
216 return failure();
217 if (size != prop.size())
218 return reader.emitError(msg: "array size mismach when reading properties: ")
219 << size << " vs expected " << prop.size();
220 for (auto &elt : prop) {
221 uint64_t value;
222 if (failed(result: reader.readVarInt(result&: value)))
223 return failure();
224 elt = value;
225 }
226 return success();
227}
228
229void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
230 ArrayRef<int64_t> prop) {
231 writer.writeVarInt(value: prop.size());
232 for (auto elt : prop)
233 writer.writeVarInt(value: elt);
234}
235
236//===----------------------------------------------------------------------===//
237// Dynamic operations
238//===----------------------------------------------------------------------===//
239
240std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
241 return DynamicOpDefinition::get(
242 "dynamic_generic", dialect, [](Operation *op) { return success(); },
243 [](Operation *op) { return success(); });
244}
245
246std::unique_ptr<DynamicOpDefinition>
247getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
248 return DynamicOpDefinition::get(
249 "dynamic_one_operand_two_results", dialect,
250 [](Operation *op) {
251 if (op->getNumOperands() != 1) {
252 op->emitOpError()
253 << "expected 1 operand, but had " << op->getNumOperands();
254 return failure();
255 }
256 if (op->getNumResults() != 2) {
257 op->emitOpError()
258 << "expected 2 results, but had " << op->getNumResults();
259 return failure();
260 }
261 return success();
262 },
263 [](Operation *op) { return success(); });
264}
265
266std::unique_ptr<DynamicOpDefinition>
267getDynamicCustomParserPrinterOp(TestDialect *dialect) {
268 auto verifier = [](Operation *op) {
269 if (op->getNumOperands() == 0 && op->getNumResults() == 0)
270 return success();
271 op->emitError() << "operation should have no operands and no results";
272 return failure();
273 };
274 auto regionVerifier = [](Operation *op) { return success(); };
275
276 auto parser = [](OpAsmParser &parser, OperationState &state) {
277 return parser.parseKeyword(keyword: "custom_keyword");
278 };
279
280 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
281 printer << op->getName() << " custom_keyword";
282 };
283
284 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
285 verifier, regionVerifier, parser, printer);
286}
287
288//===----------------------------------------------------------------------===//
289// TestDialect
290//===----------------------------------------------------------------------===//
291
292void test::registerTestDialect(DialectRegistry &registry) {
293 registry.insert<TestDialect>();
294}
295
296void test::testSideEffectOpGetEffect(
297 Operation *op,
298 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
299 &effects) {
300 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
301 if (!effectsAttr)
302 return;
303
304 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
305}
306
307// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
308struct TestOpEffectInterfaceFallback
309 : public TestEffectOpInterface::FallbackModel<
310 TestOpEffectInterfaceFallback> {
311 static bool classof(Operation *op) {
312 bool isSupportedOp =
313 op->getName().getStringRef() == "test.unregistered_side_effect_op";
314 assert(isSupportedOp && "Unexpected dispatch");
315 return isSupportedOp;
316 }
317
318 void
319 getEffects(Operation *op,
320 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
321 &effects) const {
322 testSideEffectOpGetEffect(op, effects);
323 }
324};
325
/// Registers everything the test dialect provides:
/// - attributes and types;
/// - the operations from the generated `GET_OP_LIST`, plus the custom op
///   syntax;
/// - the hand-written ManualCppOpWithFold;
/// - three dynamic operations;
/// - the dialect interfaces.
///
/// Unknown operations are also allowed. Finally, the effect-interface fallback
/// model used for `test.unregistered_side_effect_op` is allocated.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // Operations listed by the generated TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  registerOpsSyntax();
  // Operation written directly in C++ rather than listed in TestOps.cpp.inc.
  addOperations<ManualCppOpWithFold>();
  // Operations defined at runtime by the getDynamic*Op builders in this file.
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));
  registerInterfaces();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. The dialect owns this allocation: ~TestDialect casts it
  // back to TestOpEffectInterfaceFallback and deletes it.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
345
346TestDialect::~TestDialect() {
347 delete static_cast<TestOpEffectInterfaceFallback *>(
348 fallbackEffectOpInterfaces);
349}
350
351Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
352 Type type, Location loc) {
353 return builder.create<TestOpConstant>(loc, type, value);
354}
355
356void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
357 OperationName opName) {
358 if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
359 typeID == TypeID::get<TestEffectOpInterface>())
360 return fallbackEffectOpInterfaces;
361 return nullptr;
362}
363
364LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
365 NamedAttribute namedAttr) {
366 if (namedAttr.getName() == "test.invalid_attr")
367 return op->emitError() << "invalid to use 'test.invalid_attr'";
368 return success();
369}
370
371LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
372 unsigned regionIndex,
373 unsigned argIndex,
374 NamedAttribute namedAttr) {
375 if (namedAttr.getName() == "test.invalid_attr")
376 return op->emitError() << "invalid to use 'test.invalid_attr'";
377 return success();
378}
379
380LogicalResult
381TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
382 unsigned resultIndex,
383 NamedAttribute namedAttr) {
384 if (namedAttr.getName() == "test.invalid_attr")
385 return op->emitError() << "invalid to use 'test.invalid_attr'";
386 return success();
387}
388
389std::optional<Dialect::ParseOpHook>
390TestDialect::getParseOperationHook(StringRef opName) const {
391 if (opName == "test.dialect_custom_printer") {
392 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
393 return parser.parseKeyword("custom_format");
394 }};
395 }
396 if (opName == "test.dialect_custom_format_fallback") {
397 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
398 return parser.parseKeyword("custom_format_fallback");
399 }};
400 }
401 if (opName == "test.dialect_custom_printer.with.dot") {
402 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
403 return ParseResult::success();
404 }};
405 }
406 return std::nullopt;
407}
408
409llvm::unique_function<void(Operation *, OpAsmPrinter &)>
410TestDialect::getOperationPrinter(Operation *op) const {
411 StringRef opName = op->getName().getStringRef();
412 if (opName == "test.dialect_custom_printer") {
413 return [](Operation *op, OpAsmPrinter &printer) {
414 printer.getStream() << " custom_format";
415 };
416 }
417 if (opName == "test.dialect_custom_format_fallback") {
418 return [](Operation *op, OpAsmPrinter &printer) {
419 printer.getStream() << " custom_format_fallback";
420 };
421 }
422 return {};
423}
424
425static LogicalResult
426dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
427 PatternRewriter &rewriter) {
428 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
429 op, rewriter.getI32IntegerAttr(42));
430 return success();
431}
432
433void TestDialect::getCanonicalizationPatterns(
434 RewritePatternSet &results) const {
435 results.add(&dialectCanonicalizationPattern);
436}
437

source code of mlir/test/lib/Dialect/Test/TestDialect.cpp