1 | //===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "TestOps.h" |
11 | #include "mlir/Bytecode/BytecodeReader.h" |
12 | #include "mlir/Bytecode/BytecodeWriter.h" |
13 | #include "mlir/IR/BuiltinOps.h" |
14 | #include "mlir/IR/OperationSupport.h" |
15 | #include "mlir/Parser/Parser.h" |
16 | #include "mlir/Pass/Pass.h" |
17 | #include "llvm/Support/CommandLine.h" |
18 | #include "llvm/Support/MemoryBufferRef.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | #include <list> |
21 | |
22 | using namespace mlir; |
23 | using namespace llvm; |
24 | |
25 | namespace { |
26 | class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> { |
27 | public: |
28 | TestDialectVersionParser(cl::Option &o) |
29 | : cl::parser<test::TestDialectVersion>(o) {} |
30 | |
31 | bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg, |
32 | test::TestDialectVersion &v) { |
33 | long long major, minor; |
34 | if (getAsSignedInteger(Str: arg.split(Separator: "." ).first, Radix: 10, Result&: major)) |
35 | return o.error(Message: "Invalid argument '" + arg); |
36 | if (getAsSignedInteger(Str: arg.split(Separator: "." ).second, Radix: 10, Result&: minor)) |
37 | return o.error(Message: "Invalid argument '" + arg); |
38 | v = test::TestDialectVersion(major, minor); |
39 | // Returns true on error. |
40 | return false; |
41 | } |
42 | static void print(raw_ostream &os, const test::TestDialectVersion &v) { |
43 | os << v.major_ << "." << v.minor_; |
44 | }; |
45 | }; |
46 | |
47 | /// This is a test pass which uses callbacks to encode attributes and types in a |
48 | /// custom fashion. |
49 | struct TestBytecodeRoundtripPass |
50 | : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> { |
51 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass) |
52 | |
53 | StringRef getArgument() const final { return "test-bytecode-roundtrip" ; } |
54 | StringRef getDescription() const final { |
55 | return "Test pass to implement bytecode roundtrip tests." ; |
56 | } |
57 | void getDependentDialects(DialectRegistry ®istry) const override { |
58 | registry.insert<test::TestDialect>(); |
59 | } |
60 | TestBytecodeRoundtripPass() = default; |
61 | TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {} |
62 | |
63 | LogicalResult initialize(MLIRContext *context) override { |
64 | testDialect = context->getOrLoadDialect<test::TestDialect>(); |
65 | return success(); |
66 | } |
67 | |
68 | void runOnOperation() override { |
69 | switch (testKind) { |
70 | // Tests 0-5 implement a custom roundtrip with callbacks. |
71 | case (0): |
72 | return runTest0(getOperation()); |
73 | case (1): |
74 | return runTest1(getOperation()); |
75 | case (2): |
76 | return runTest2(getOperation()); |
77 | case (3): |
78 | return runTest3(getOperation()); |
79 | case (4): |
80 | return runTest4(getOperation()); |
81 | case (5): |
82 | return runTest5(getOperation()); |
83 | case (6): |
84 | // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from |
85 | // `targetVersion`. |
86 | return runTest6(getOperation()); |
87 | default: |
88 | llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass" ); |
89 | } |
90 | } |
91 | |
92 | mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser> |
93 | targetVersion{*this, "test-dialect-version" , |
94 | llvm::cl::desc( |
95 | "Specifies the test dialect version to emit and parse" ), |
96 | cl::init(Val: test::TestDialectVersion())}; |
97 | |
98 | mlir::Pass::Option<int> testKind{ |
99 | *this, "test-kind" , llvm::cl::desc("Specifies the test kind to execute" ), |
100 | cl::init(Val: 0)}; |
101 | |
102 | private: |
103 | void doRoundtripWithConfigs(Operation *op, |
104 | const BytecodeWriterConfig &writeConfig, |
105 | const ParserConfig &parseConfig) { |
106 | std::string bytecode; |
107 | llvm::raw_string_ostream os(bytecode); |
108 | if (failed(result: writeBytecodeToFile(op, os, config: writeConfig))) { |
109 | op->emitError() << "failed to write bytecode\n" ; |
110 | signalPassFailure(); |
111 | return; |
112 | } |
113 | auto newModuleOp = parseSourceString(sourceStr: StringRef(bytecode), config: parseConfig); |
114 | if (!newModuleOp.get()) { |
115 | op->emitError() << "failed to read bytecode\n" ; |
116 | signalPassFailure(); |
117 | return; |
118 | } |
119 | // Print the module to the output stream, so that we can filecheck the |
120 | // result. |
121 | newModuleOp->print(os&: llvm::outs()); |
122 | } |
123 | |
124 | // Test0: let's assume that versions older than 2.0 were relying on a special |
125 | // integer attribute of a deprecated dialect called "funky". Assume that its |
126 | // encoding was made by two varInts, the first was the ID (999) and the second |
127 | // contained width and signedness info. We can emit it using a callback |
128 | // writing a custom encoding for the "funky" dialect group, and parse it back |
129 | // with a custom parser reading the same encoding in the same dialect group. |
130 | // Note that the ID 999 does not correspond to a valid integer type in the |
131 | // current encodings of builtin types. |
132 | void runTest0(Operation *op) { |
133 | auto newCtx = std::make_shared<MLIRContext>(); |
134 | test::TestDialectVersion targetEmissionVersion = targetVersion; |
135 | BytecodeWriterConfig writeConfig; |
136 | // Set the emission version for the test dialect. |
137 | writeConfig.setDialectVersion<test::TestDialect>( |
138 | std::make_unique<test::TestDialectVersion>(args&: targetEmissionVersion)); |
139 | writeConfig.attachTypeCallback( |
140 | emitFn: [&](Type entryValue, std::optional<StringRef> &dialectGroupName, |
141 | DialectBytecodeWriter &writer) -> LogicalResult { |
142 | // Do not override anything if version greater than 2.0. |
143 | auto versionOr = writer.getDialectVersion<test::TestDialect>(); |
144 | assert(succeeded(versionOr) && "expected reader to be able to access " |
145 | "the version for test dialect" ); |
146 | const auto *version = |
147 | reinterpret_cast<const test::TestDialectVersion *>(*versionOr); |
148 | if (version->major_ >= 2) |
149 | return failure(); |
150 | |
151 | // For version less than 2.0, override the encoding of IntegerType. |
152 | if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) { |
153 | llvm::outs() << "Overriding IntegerType encoding...\n" ; |
154 | dialectGroupName = StringLiteral("funky" ); |
155 | writer.writeVarInt(/* IntegerType */ value: 999); |
156 | writer.writeVarInt(value: type.getWidth() << 2 | type.getSignedness()); |
157 | return success(); |
158 | } |
159 | return failure(); |
160 | }); |
161 | newCtx->appendDialectRegistry(registry: op->getContext()->getDialectRegistry()); |
162 | newCtx->allowUnregisteredDialects(); |
163 | ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true); |
164 | parseConfig.getBytecodeReaderConfig().attachTypeCallback( |
165 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
166 | Type &entry) -> LogicalResult { |
167 | // Get test dialect version from the version map. |
168 | auto versionOr = reader.getDialectVersion<test::TestDialect>(); |
169 | assert(succeeded(versionOr) && "expected reader to be able to access " |
170 | "the version for test dialect" ); |
171 | const auto *version = |
172 | reinterpret_cast<const test::TestDialectVersion *>(*versionOr); |
173 | if (version->major_ >= 2) |
174 | return success(); |
175 | |
176 | // `dialectName` is the name of the group we have the opportunity to |
177 | // override. In this case, override only the dialect group "funky", |
178 | // for which does not exist in memory. |
179 | if (dialectName != StringLiteral("funky" )) |
180 | return success(); |
181 | |
182 | uint64_t encoding; |
183 | if (failed(result: reader.readVarInt(result&: encoding)) || encoding != 999) |
184 | return success(); |
185 | llvm::outs() << "Overriding parsing of IntegerType encoding...\n" ; |
186 | uint64_t widthAndSignedness, width; |
187 | IntegerType::SignednessSemantics signedness; |
188 | if (succeeded(reader.readVarInt(widthAndSignedness)) && |
189 | ((width = widthAndSignedness >> 2), true) && |
190 | ((signedness = static_cast<IntegerType::SignednessSemantics>( |
191 | widthAndSignedness & 0x3)), |
192 | true)) |
193 | entry = IntegerType::get(reader.getContext(), width, signedness); |
194 | // Return nullopt to fall through the rest of the parsing code path. |
195 | return success(); |
196 | }); |
197 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
198 | } |
199 | |
200 | // Test1: When writing bytecode, we override the encoding of TestI32Type with |
201 | // the encoding of builtin IntegerType. We can natively parse this without |
202 | // the use of a callback, relying on the existing builtin reader mechanism. |
203 | void runTest1(Operation *op) { |
204 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
205 | BytecodeDialectInterface *iface = |
206 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
207 | BytecodeWriterConfig writeConfig; |
208 | writeConfig.attachTypeCallback( |
209 | emitFn: [&](Type entryValue, std::optional<StringRef> &dialectGroupName, |
210 | DialectBytecodeWriter &writer) -> LogicalResult { |
211 | // Emit TestIntegerType using the builtin dialect encoding. |
212 | if (llvm::isa<test::TestI32Type>(entryValue)) { |
213 | llvm::outs() << "Overriding TestI32Type encoding...\n" ; |
214 | auto builtinI32Type = |
215 | IntegerType::get(op->getContext(), 32, |
216 | IntegerType::SignednessSemantics::Signless); |
217 | // Specify that this type will need to be written as part of the |
218 | // builtin group. This will override the default dialect group of |
219 | // the attribute (test). |
220 | dialectGroupName = StringLiteral("builtin" ); |
221 | if (succeeded(iface->writeType(type: builtinI32Type, writer))) |
222 | return success(); |
223 | } |
224 | return failure(); |
225 | }); |
226 | // We natively parse the attribute as a builtin, so no callback needed. |
227 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); |
228 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
229 | } |
230 | |
231 | // Test2: When writing bytecode, we write standard builtin IntegerTypes. At |
232 | // parsing, we use the encoding of IntegerType to intercept all i32. Then, |
233 | // instead of creating i32s, we assemble TestI32Type and return it. |
234 | void runTest2(Operation *op) { |
235 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
236 | BytecodeDialectInterface *iface = |
237 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
238 | BytecodeWriterConfig writeConfig; |
239 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); |
240 | parseConfig.getBytecodeReaderConfig().attachTypeCallback( |
241 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
242 | Type &entry) -> LogicalResult { |
243 | if (dialectName != StringLiteral("builtin" )) |
244 | return success(); |
245 | Type builtinAttr = iface->readType(reader); |
246 | if (auto integerType = |
247 | llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) { |
248 | if (integerType.getWidth() == 32 && integerType.isSignless()) { |
249 | llvm::outs() << "Overriding parsing of TestI32Type encoding...\n" ; |
250 | entry = test::TestI32Type::get(reader.getContext()); |
251 | } |
252 | } |
253 | return success(); |
254 | }); |
255 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
256 | } |
257 | |
258 | // Test3: When writing bytecode, we override the encoding of |
259 | // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We |
260 | // can natively parse this without the use of a callback, relying on the |
261 | // existing builtin reader mechanism. |
262 | void runTest3(Operation *op) { |
263 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
264 | BytecodeDialectInterface *iface = |
265 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
266 | auto i32Type = IntegerType::get(op->getContext(), 32, |
267 | IntegerType::SignednessSemantics::Signless); |
268 | BytecodeWriterConfig writeConfig; |
269 | writeConfig.attachAttributeCallback( |
270 | emitFn: [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName, |
271 | DialectBytecodeWriter &writer) -> LogicalResult { |
272 | // Emit TestIntegerType using the builtin dialect encoding. |
273 | if (auto testParamAttrs = |
274 | llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) { |
275 | llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n" ; |
276 | // Specify that this attribute will need to be written as part of |
277 | // the builtin group. This will override the default dialect group |
278 | // of the attribute (test). |
279 | dialectGroupName = StringLiteral("builtin" ); |
280 | auto denseAttr = DenseIntElementsAttr::get( |
281 | RankedTensorType::get({2}, i32Type), |
282 | {testParamAttrs.getV0(), testParamAttrs.getV1()}); |
283 | if (succeeded(iface->writeAttribute(attr: denseAttr, writer))) |
284 | return success(); |
285 | } |
286 | return failure(); |
287 | }); |
288 | // We natively parse the attribute as a builtin, so no callback needed. |
289 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); |
290 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
291 | } |
292 | |
293 | // Test4: When writing bytecode, we write standard builtin |
294 | // DenseIntElementsAttr. At parsing, we use the encoding of |
295 | // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of |
296 | // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble |
297 | // TestAttrParamsAttr and return it. |
298 | void runTest4(Operation *op) { |
299 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
300 | BytecodeDialectInterface *iface = |
301 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
302 | auto i32Type = IntegerType::get(op->getContext(), 32, |
303 | IntegerType::SignednessSemantics::Signless); |
304 | BytecodeWriterConfig writeConfig; |
305 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); |
306 | parseConfig.getBytecodeReaderConfig().attachAttributeCallback( |
307 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
308 | Attribute &entry) -> LogicalResult { |
309 | // Override only the case where the return type of the builtin reader |
310 | // is an i32 and fall through on all the other cases, since we want to |
311 | // still use TestDialect normal codepath to parse the other types. |
312 | Attribute builtinAttr = iface->readAttribute(reader); |
313 | if (auto denseAttr = |
314 | llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) { |
315 | if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) && |
316 | denseAttr.getElementType() == i32Type) { |
317 | llvm::outs() |
318 | << "Overriding parsing of TestAttrParamsAttr encoding...\n" ; |
319 | int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt(); |
320 | int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt(); |
321 | entry = |
322 | test::TestAttrParamsAttr::get(reader.getContext(), v0, v1); |
323 | } |
324 | } |
325 | return success(); |
326 | }); |
327 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
328 | } |
329 | |
330 | // Test5: When writing bytecode, we want TestDialect to use nothing else than |
331 | // the builtin types and attributes and take full control of the encoding, |
332 | // returning failure if any type or attribute is not part of builtin. |
333 | void runTest5(Operation *op) { |
334 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
335 | BytecodeDialectInterface *iface = |
336 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
337 | BytecodeWriterConfig writeConfig; |
338 | writeConfig.attachAttributeCallback( |
339 | emitFn: [&](Attribute attr, std::optional<StringRef> &dialectGroupName, |
340 | DialectBytecodeWriter &writer) -> LogicalResult { |
341 | return iface->writeAttribute(attr, writer); |
342 | }); |
343 | writeConfig.attachTypeCallback( |
344 | emitFn: [&](Type type, std::optional<StringRef> &dialectGroupName, |
345 | DialectBytecodeWriter &writer) -> LogicalResult { |
346 | return iface->writeType(type, writer); |
347 | }); |
348 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); |
349 | parseConfig.getBytecodeReaderConfig().attachAttributeCallback( |
350 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
351 | Attribute &entry) -> LogicalResult { |
352 | Attribute builtinAttr = iface->readAttribute(reader); |
353 | if (!builtinAttr) |
354 | return failure(); |
355 | entry = builtinAttr; |
356 | return success(); |
357 | }); |
358 | parseConfig.getBytecodeReaderConfig().attachTypeCallback( |
359 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
360 | Type &entry) -> LogicalResult { |
361 | Type builtinType = iface->readType(reader); |
362 | if (!builtinType) { |
363 | return failure(); |
364 | } |
365 | entry = builtinType; |
366 | return success(); |
367 | }); |
368 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
369 | } |
370 | |
371 | LogicalResult downgradeToVersion(Operation *op, |
372 | const test::TestDialectVersion &version) { |
373 | if ((version.major_ == 2) && (version.minor_ == 0)) |
374 | return success(); |
375 | if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) { |
376 | return op->emitError() << "current test dialect version is 2.0, " |
377 | "can't downgrade to version: " |
378 | << version.major_ << "." << version.minor_; |
379 | } |
380 | // Prior version 2.0, the old op supported only a single attribute called |
381 | // "dimensions". We need to check that the modifier is false, otherwise we |
382 | // can't do the downgrade. |
383 | auto status = op->walk(callback: [&](test::TestVersionedOpA op) { |
384 | auto &prop = op.getProperties(); |
385 | if (prop.modifier.getValue()) { |
386 | op->emitOpError() << "cannot downgrade to version " << version.major_ |
387 | << "." << version.minor_ |
388 | << " since the modifier is not compatible" ; |
389 | return WalkResult::interrupt(); |
390 | } |
391 | llvm::outs() << "downgrading op...\n" ; |
392 | return WalkResult::advance(); |
393 | }); |
394 | return failure(isFailure: status.wasInterrupted()); |
395 | } |
396 | |
397 | // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and |
398 | // upgrade IR when back in memory. The module is expected to be unmodified at |
399 | // the end of the function. |
400 | void runTest6(Operation *op) { |
401 | test::TestDialectVersion targetEmissionVersion = targetVersion; |
402 | |
403 | // Downgrade IR constructs before writing the IR to bytecode. |
404 | auto status = downgradeToVersion(op, version: targetEmissionVersion); |
405 | assert(succeeded(status) && "expected the downgrade to succeed" ); |
406 | (void)status; |
407 | |
408 | BytecodeWriterConfig writeConfig; |
409 | writeConfig.setDialectVersion<test::TestDialect>( |
410 | std::make_unique<test::TestDialectVersion>(args&: targetEmissionVersion)); |
411 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); |
412 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
413 | } |
414 | |
415 | test::TestDialect *testDialect; |
416 | }; |
417 | } // namespace |
418 | |
419 | namespace mlir { |
420 | void registerTestBytecodeRoundtripPasses() { |
421 | PassRegistration<TestBytecodeRoundtripPass>(); |
422 | } |
423 | } // namespace mlir |
424 | |