1//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "TestTypes.h"
12#include "mlir/Bytecode/BytecodeImplementation.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/IR/AsmState.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/ExtensibleDialect.h"
21#include "mlir/IR/MLIRContext.h"
22#include "mlir/IR/ODSSupport.h"
23#include "mlir/IR/OperationSupport.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/TypeUtilities.h"
26#include "mlir/IR/Verifier.h"
27#include "mlir/Interfaces/CallInterfaces.h"
28#include "mlir/Interfaces/FunctionImplementation.h"
29#include "mlir/Interfaces/InferIntRangeInterface.h"
30#include "mlir/Support/LLVM.h"
31#include "mlir/Transforms/FoldUtils.h"
32#include "mlir/Transforms/InliningUtils.h"
33#include "llvm/ADT/STLFunctionalExtras.h"
34#include "llvm/ADT/SmallString.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/StringSwitch.h"
37#include "llvm/Support/Base64.h"
38#include "llvm/Support/Casting.h"
39
40#include "mlir/Dialect/Arith/IR/Arith.h"
41#include "mlir/Dialect/DLTI/DLTI.h"
42#include "mlir/Interfaces/FoldInterfaces.h"
43#include "mlir/Reducer/ReductionPatternInterface.h"
44#include "mlir/Transforms/InliningUtils.h"
45#include <cstdint>
46#include <numeric>
47#include <optional>
48
49// Include this before the using namespace lines below to test that we don't
50// have namespace dependencies.
51#include "TestOpsDialect.cpp.inc"
52
53using namespace mlir;
54using namespace test;
55
56//===----------------------------------------------------------------------===//
57// PropertiesWithCustomPrint
58//===----------------------------------------------------------------------===//
59
60LogicalResult
61test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
62 Attribute attr,
63 function_ref<InFlightDiagnostic()> emitError) {
64 DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
65 if (!dict) {
66 emitError() << "expected DictionaryAttr to set TestProperties";
67 return failure();
68 }
69 auto label = dict.getAs<mlir::StringAttr>("label");
70 if (!label) {
71 emitError() << "expected StringAttr for key `label`";
72 return failure();
73 }
74 auto valueAttr = dict.getAs<IntegerAttr>("value");
75 if (!valueAttr) {
76 emitError() << "expected IntegerAttr for key `value`";
77 return failure();
78 }
79
80 prop.label = std::make_shared<std::string>(label.getValue());
81 prop.value = valueAttr.getValue().getSExtValue();
82 return success();
83}
84
85DictionaryAttr
86test::getPropertiesAsAttribute(MLIRContext *ctx,
87 const PropertiesWithCustomPrint &prop) {
88 SmallVector<NamedAttribute> attrs;
89 Builder b{ctx};
90 attrs.push_back(Elt: b.getNamedAttr("label", b.getStringAttr(*prop.label)));
91 attrs.push_back(Elt: b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
92 return b.getDictionaryAttr(attrs);
93}
94
95llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
96 return llvm::hash_combine(args: prop.value, args: StringRef(*prop.label));
97}
98
99void test::customPrintProperties(OpAsmPrinter &p,
100 const PropertiesWithCustomPrint &prop) {
101 p.printKeywordOrString(keyword: *prop.label);
102 p << " is " << prop.value;
103}
104
105ParseResult test::customParseProperties(OpAsmParser &parser,
106 PropertiesWithCustomPrint &prop) {
107 std::string label;
108 if (parser.parseKeywordOrString(result: &label) || parser.parseKeyword(keyword: "is") ||
109 parser.parseInteger(result&: prop.value))
110 return failure();
111 prop.label = std::make_shared<std::string>(args: std::move(label));
112 return success();
113}
114
115//===----------------------------------------------------------------------===//
116// MyPropStruct
117//===----------------------------------------------------------------------===//
118
119Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
120 return StringAttr::get(ctx, content);
121}
122
123LogicalResult
124MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
125 function_ref<InFlightDiagnostic()> emitError) {
126 StringAttr strAttr = dyn_cast<StringAttr>(attr);
127 if (!strAttr) {
128 emitError() << "Expect StringAttr but got " << attr;
129 return failure();
130 }
131 prop.content = strAttr.getValue();
132 return success();
133}
134
135llvm::hash_code MyPropStruct::hash() const {
136 return hash_value(S: StringRef(content));
137}
138
139LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
140 MyPropStruct &prop) {
141 StringRef str;
142 if (failed(Result: reader.readString(result&: str)))
143 return failure();
144 prop.content = str.str();
145 return success();
146}
147
148void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
149 MyPropStruct &prop) {
150 writer.writeOwnedString(str: prop.content);
151}
152
153//===----------------------------------------------------------------------===//
154// VersionedProperties
155//===----------------------------------------------------------------------===//
156
157LogicalResult
158test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
159 function_ref<InFlightDiagnostic()> emitError) {
160 DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
161 if (!dict) {
162 emitError() << "expected DictionaryAttr to set VersionedProperties";
163 return failure();
164 }
165 auto value1Attr = dict.getAs<IntegerAttr>("value1");
166 if (!value1Attr) {
167 emitError() << "expected IntegerAttr for key `value1`";
168 return failure();
169 }
170 auto value2Attr = dict.getAs<IntegerAttr>("value2");
171 if (!value2Attr) {
172 emitError() << "expected IntegerAttr for key `value2`";
173 return failure();
174 }
175
176 prop.value1 = value1Attr.getValue().getSExtValue();
177 prop.value2 = value2Attr.getValue().getSExtValue();
178 return success();
179}
180
181DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
182 const VersionedProperties &prop) {
183 SmallVector<NamedAttribute> attrs;
184 Builder b{ctx};
185 attrs.push_back(Elt: b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
186 attrs.push_back(Elt: b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
187 return b.getDictionaryAttr(attrs);
188}
189
190llvm::hash_code test::computeHash(const VersionedProperties &prop) {
191 return llvm::hash_combine(args: prop.value1, args: prop.value2);
192}
193
194void test::customPrintProperties(OpAsmPrinter &p,
195 const VersionedProperties &prop) {
196 p << prop.value1 << " | " << prop.value2;
197}
198
199ParseResult test::customParseProperties(OpAsmParser &parser,
200 VersionedProperties &prop) {
201 if (parser.parseInteger(result&: prop.value1) || parser.parseVerticalBar() ||
202 parser.parseInteger(result&: prop.value2))
203 return failure();
204 return success();
205}
206
207//===----------------------------------------------------------------------===//
208// Bytecode Support
209//===----------------------------------------------------------------------===//
210
211LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
212 MutableArrayRef<int64_t> prop) {
213 uint64_t size;
214 if (failed(Result: reader.readVarInt(result&: size)))
215 return failure();
216 if (size != prop.size())
217 return reader.emitError(msg: "array size mismach when reading properties: ")
218 << size << " vs expected " << prop.size();
219 for (auto &elt : prop) {
220 uint64_t value;
221 if (failed(Result: reader.readVarInt(result&: value)))
222 return failure();
223 elt = value;
224 }
225 return success();
226}
227
228void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
229 ArrayRef<int64_t> prop) {
230 writer.writeVarInt(value: prop.size());
231 for (auto elt : prop)
232 writer.writeVarInt(value: elt);
233}
234
235//===----------------------------------------------------------------------===//
236// Dynamic operations
237//===----------------------------------------------------------------------===//
238
239std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
240 return DynamicOpDefinition::get(
241 "dynamic_generic", dialect, [](Operation *op) { return success(); },
242 [](Operation *op) { return success(); });
243}
244
245std::unique_ptr<DynamicOpDefinition>
246getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
247 return DynamicOpDefinition::get(
248 "dynamic_one_operand_two_results", dialect,
249 [](Operation *op) {
250 if (op->getNumOperands() != 1) {
251 op->emitOpError()
252 << "expected 1 operand, but had " << op->getNumOperands();
253 return failure();
254 }
255 if (op->getNumResults() != 2) {
256 op->emitOpError()
257 << "expected 2 results, but had " << op->getNumResults();
258 return failure();
259 }
260 return success();
261 },
262 [](Operation *op) { return success(); });
263}
264
265std::unique_ptr<DynamicOpDefinition>
266getDynamicCustomParserPrinterOp(TestDialect *dialect) {
267 auto verifier = [](Operation *op) {
268 if (op->getNumOperands() == 0 && op->getNumResults() == 0)
269 return success();
270 op->emitError() << "operation should have no operands and no results";
271 return failure();
272 };
273 auto regionVerifier = [](Operation *op) { return success(); };
274
275 auto parser = [](OpAsmParser &parser, OperationState &state) {
276 return parser.parseKeyword(keyword: "custom_keyword");
277 };
278
279 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
280 printer << op->getName() << " custom_keyword";
281 };
282
283 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
284 verifier, regionVerifier, parser, printer);
285}
286
287//===----------------------------------------------------------------------===//
288// TestDialect
289//===----------------------------------------------------------------------===//
290
291void test::registerTestDialect(DialectRegistry &registry) {
292 registry.insert<TestDialect>();
293}
294
295void test::testSideEffectOpGetEffect(
296 Operation *op,
297 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
298 &effects) {
299 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
300 if (!effectsAttr)
301 return;
302
303 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
304}
305
306// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
307struct TestOpEffectInterfaceFallback
308 : public TestEffectOpInterface::FallbackModel<
309 TestOpEffectInterfaceFallback> {
310 static bool classof(Operation *op) {
311 bool isSupportedOp =
312 op->getName().getStringRef() == "test.unregistered_side_effect_op";
313 assert(isSupportedOp && "Unexpected dispatch");
314 return isSupportedOp;
315 }
316
317 void
318 getEffects(Operation *op,
319 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
320 &effects) const {
321 testSideEffectOpGetEffect(op, effects);
322 }
323};
324
/// Registers everything the test dialect provides: attributes, types, the
/// custom op syntax, the hand-written `ManualCppOpWithFold` op, the ops added
/// by `registerTestDialectOperations`, three dynamic ops, and dialect
/// interfaces. It also allows unknown operations and allocates the fallback
/// `TestEffectOpInterface` model.
///
/// NOTE(review): the registration order is preserved as-is; it is not clear
/// from this file whether later steps depend on earlier ones.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  registerOpsSyntax();
  addOperations<ManualCppOpWithFold>();
  registerTestDialectOperations(this);
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));
  registerInterfaces();
  allowUnknownOperations();

  // Instantiate the fallback op interface model that getRegisteredInterfaceForOp
  // hands out for the unregistered `test.unregistered_side_effect_op`. The
  // dialect owns this allocation and frees it in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
341
342TestDialect::~TestDialect() {
343 delete static_cast<TestOpEffectInterfaceFallback *>(
344 fallbackEffectOpInterfaces);
345}
346
347Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
348 Type type, Location loc) {
349 return builder.create<TestOpConstant>(loc, type, value);
350}
351
352void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
353 OperationName opName) {
354 if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
355 typeID == TypeID::get<TestEffectOpInterface>())
356 return fallbackEffectOpInterfaces;
357 return nullptr;
358}
359
360LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
361 NamedAttribute namedAttr) {
362 if (namedAttr.getName() == "test.invalid_attr")
363 return op->emitError() << "invalid to use 'test.invalid_attr'";
364 return success();
365}
366
367LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
368 unsigned regionIndex,
369 unsigned argIndex,
370 NamedAttribute namedAttr) {
371 if (namedAttr.getName() == "test.invalid_attr")
372 return op->emitError() << "invalid to use 'test.invalid_attr'";
373 return success();
374}
375
376LogicalResult
377TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
378 unsigned resultIndex,
379 NamedAttribute namedAttr) {
380 if (namedAttr.getName() == "test.invalid_attr")
381 return op->emitError() << "invalid to use 'test.invalid_attr'";
382 return success();
383}
384
385std::optional<Dialect::ParseOpHook>
386TestDialect::getParseOperationHook(StringRef opName) const {
387 if (opName == "test.dialect_custom_printer") {
388 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
389 return parser.parseKeyword("custom_format");
390 }};
391 }
392 if (opName == "test.dialect_custom_format_fallback") {
393 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
394 return parser.parseKeyword("custom_format_fallback");
395 }};
396 }
397 if (opName == "test.dialect_custom_printer.with.dot") {
398 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
399 return ParseResult::success();
400 }};
401 }
402 return std::nullopt;
403}
404
405llvm::unique_function<void(Operation *, OpAsmPrinter &)>
406TestDialect::getOperationPrinter(Operation *op) const {
407 StringRef opName = op->getName().getStringRef();
408 if (opName == "test.dialect_custom_printer") {
409 return [](Operation *op, OpAsmPrinter &printer) {
410 printer.getStream() << " custom_format";
411 };
412 }
413 if (opName == "test.dialect_custom_format_fallback") {
414 return [](Operation *op, OpAsmPrinter &printer) {
415 printer.getStream() << " custom_format_fallback";
416 };
417 }
418 return {};
419}
420
421static LogicalResult
422dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
423 PatternRewriter &rewriter) {
424 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
425 op, rewriter.getI32IntegerAttr(42));
426 return success();
427}
428
429void TestDialect::getCanonicalizationPatterns(
430 RewritePatternSet &results) const {
431 results.add(&dialectCanonicalizationPattern);
432}
433

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Test/TestDialect.cpp