1//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Bytecode/BytecodeReader.h"
12#include "mlir/Bytecode/BytecodeWriter.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/OperationSupport.h"
15#include "mlir/Parser/Parser.h"
16#include "mlir/Pass/Pass.h"
17#include "llvm/Support/CommandLine.h"
18#include "llvm/Support/MemoryBufferRef.h"
19#include "llvm/Support/raw_ostream.h"
20#include <list>
21
22using namespace mlir;
23using namespace llvm;
24
25namespace {
26class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
27public:
28 TestDialectVersionParser(cl::Option &o)
29 : cl::parser<test::TestDialectVersion>(o) {}
30
31 bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
32 test::TestDialectVersion &v) {
33 long long major, minor;
34 if (getAsSignedInteger(Str: arg.split(Separator: ".").first, Radix: 10, Result&: major))
35 return o.error(Message: "Invalid argument '" + arg);
36 if (getAsSignedInteger(Str: arg.split(Separator: ".").second, Radix: 10, Result&: minor))
37 return o.error(Message: "Invalid argument '" + arg);
38 v = test::TestDialectVersion(major, minor);
39 // Returns true on error.
40 return false;
41 }
42 static void print(raw_ostream &os, const test::TestDialectVersion &v) {
43 os << v.major_ << "." << v.minor_;
44 };
45};
46
47/// This is a test pass which uses callbacks to encode attributes and types in a
48/// custom fashion.
49struct TestBytecodeRoundtripPass
50 : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
51 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)
52
53 StringRef getArgument() const final { return "test-bytecode-roundtrip"; }
54 StringRef getDescription() const final {
55 return "Test pass to implement bytecode roundtrip tests.";
56 }
57 void getDependentDialects(DialectRegistry &registry) const override {
58 registry.insert<test::TestDialect>();
59 }
60 TestBytecodeRoundtripPass() = default;
61 TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {}
62
63 LogicalResult initialize(MLIRContext *context) override {
64 testDialect = context->getOrLoadDialect<test::TestDialect>();
65 return success();
66 }
67
68 void runOnOperation() override {
69 switch (testKind) {
70 // Tests 0-5 implement a custom roundtrip with callbacks.
71 case (0):
72 return runTest0(getOperation());
73 case (1):
74 return runTest1(getOperation());
75 case (2):
76 return runTest2(getOperation());
77 case (3):
78 return runTest3(getOperation());
79 case (4):
80 return runTest4(getOperation());
81 case (5):
82 return runTest5(getOperation());
83 case (6):
84 // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from
85 // `targetVersion`.
86 return runTest6(getOperation());
87 default:
88 llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
89 }
90 }
91
92 mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
93 targetVersion{*this, "test-dialect-version",
94 llvm::cl::desc(
95 "Specifies the test dialect version to emit and parse"),
96 cl::init(Val: test::TestDialectVersion())};
97
98 mlir::Pass::Option<int> testKind{
99 *this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
100 cl::init(Val: 0)};
101
102private:
103 void doRoundtripWithConfigs(Operation *op,
104 const BytecodeWriterConfig &writeConfig,
105 const ParserConfig &parseConfig) {
106 std::string bytecode;
107 llvm::raw_string_ostream os(bytecode);
108 if (failed(result: writeBytecodeToFile(op, os, config: writeConfig))) {
109 op->emitError() << "failed to write bytecode\n";
110 signalPassFailure();
111 return;
112 }
113 auto newModuleOp = parseSourceString(sourceStr: StringRef(bytecode), config: parseConfig);
114 if (!newModuleOp.get()) {
115 op->emitError() << "failed to read bytecode\n";
116 signalPassFailure();
117 return;
118 }
119 // Print the module to the output stream, so that we can filecheck the
120 // result.
121 newModuleOp->print(os&: llvm::outs());
122 }
123
124 // Test0: let's assume that versions older than 2.0 were relying on a special
125 // integer attribute of a deprecated dialect called "funky". Assume that its
126 // encoding was made by two varInts, the first was the ID (999) and the second
127 // contained width and signedness info. We can emit it using a callback
128 // writing a custom encoding for the "funky" dialect group, and parse it back
129 // with a custom parser reading the same encoding in the same dialect group.
130 // Note that the ID 999 does not correspond to a valid integer type in the
131 // current encodings of builtin types.
132 void runTest0(Operation *op) {
133 auto newCtx = std::make_shared<MLIRContext>();
134 test::TestDialectVersion targetEmissionVersion = targetVersion;
135 BytecodeWriterConfig writeConfig;
136 // Set the emission version for the test dialect.
137 writeConfig.setDialectVersion<test::TestDialect>(
138 std::make_unique<test::TestDialectVersion>(args&: targetEmissionVersion));
139 writeConfig.attachTypeCallback(
140 emitFn: [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
141 DialectBytecodeWriter &writer) -> LogicalResult {
142 // Do not override anything if version greater than 2.0.
143 auto versionOr = writer.getDialectVersion<test::TestDialect>();
144 assert(succeeded(versionOr) && "expected reader to be able to access "
145 "the version for test dialect");
146 const auto *version =
147 reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
148 if (version->major_ >= 2)
149 return failure();
150
151 // For version less than 2.0, override the encoding of IntegerType.
152 if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
153 llvm::outs() << "Overriding IntegerType encoding...\n";
154 dialectGroupName = StringLiteral("funky");
155 writer.writeVarInt(/* IntegerType */ value: 999);
156 writer.writeVarInt(value: type.getWidth() << 2 | type.getSignedness());
157 return success();
158 }
159 return failure();
160 });
161 newCtx->appendDialectRegistry(registry: op->getContext()->getDialectRegistry());
162 newCtx->allowUnregisteredDialects();
163 ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true);
164 parseConfig.getBytecodeReaderConfig().attachTypeCallback(
165 parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName,
166 Type &entry) -> LogicalResult {
167 // Get test dialect version from the version map.
168 auto versionOr = reader.getDialectVersion<test::TestDialect>();
169 assert(succeeded(versionOr) && "expected reader to be able to access "
170 "the version for test dialect");
171 const auto *version =
172 reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
173 if (version->major_ >= 2)
174 return success();
175
176 // `dialectName` is the name of the group we have the opportunity to
177 // override. In this case, override only the dialect group "funky",
178 // for which does not exist in memory.
179 if (dialectName != StringLiteral("funky"))
180 return success();
181
182 uint64_t encoding;
183 if (failed(result: reader.readVarInt(result&: encoding)) || encoding != 999)
184 return success();
185 llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
186 uint64_t widthAndSignedness, width;
187 IntegerType::SignednessSemantics signedness;
188 if (succeeded(reader.readVarInt(widthAndSignedness)) &&
189 ((width = widthAndSignedness >> 2), true) &&
190 ((signedness = static_cast<IntegerType::SignednessSemantics>(
191 widthAndSignedness & 0x3)),
192 true))
193 entry = IntegerType::get(reader.getContext(), width, signedness);
194 // Return nullopt to fall through the rest of the parsing code path.
195 return success();
196 });
197 doRoundtripWithConfigs(op, writeConfig, parseConfig);
198 }
199
200 // Test1: When writing bytecode, we override the encoding of TestI32Type with
201 // the encoding of builtin IntegerType. We can natively parse this without
202 // the use of a callback, relying on the existing builtin reader mechanism.
203 void runTest1(Operation *op) {
204 auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
205 BytecodeDialectInterface *iface =
206 builtin->getRegisteredInterface<BytecodeDialectInterface>();
207 BytecodeWriterConfig writeConfig;
208 writeConfig.attachTypeCallback(
209 emitFn: [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
210 DialectBytecodeWriter &writer) -> LogicalResult {
211 // Emit TestIntegerType using the builtin dialect encoding.
212 if (llvm::isa<test::TestI32Type>(entryValue)) {
213 llvm::outs() << "Overriding TestI32Type encoding...\n";
214 auto builtinI32Type =
215 IntegerType::get(op->getContext(), 32,
216 IntegerType::SignednessSemantics::Signless);
217 // Specify that this type will need to be written as part of the
218 // builtin group. This will override the default dialect group of
219 // the attribute (test).
220 dialectGroupName = StringLiteral("builtin");
221 if (succeeded(iface->writeType(type: builtinI32Type, writer)))
222 return success();
223 }
224 return failure();
225 });
226 // We natively parse the attribute as a builtin, so no callback needed.
227 ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
228 doRoundtripWithConfigs(op, writeConfig, parseConfig);
229 }
230
231 // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
232 // parsing, we use the encoding of IntegerType to intercept all i32. Then,
233 // instead of creating i32s, we assemble TestI32Type and return it.
234 void runTest2(Operation *op) {
235 auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
236 BytecodeDialectInterface *iface =
237 builtin->getRegisteredInterface<BytecodeDialectInterface>();
238 BytecodeWriterConfig writeConfig;
239 ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
240 parseConfig.getBytecodeReaderConfig().attachTypeCallback(
241 parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName,
242 Type &entry) -> LogicalResult {
243 if (dialectName != StringLiteral("builtin"))
244 return success();
245 Type builtinAttr = iface->readType(reader);
246 if (auto integerType =
247 llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) {
248 if (integerType.getWidth() == 32 && integerType.isSignless()) {
249 llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
250 entry = test::TestI32Type::get(reader.getContext());
251 }
252 }
253 return success();
254 });
255 doRoundtripWithConfigs(op, writeConfig, parseConfig);
256 }
257
258 // Test3: When writing bytecode, we override the encoding of
259 // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We
260 // can natively parse this without the use of a callback, relying on the
261 // existing builtin reader mechanism.
262 void runTest3(Operation *op) {
263 auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
264 BytecodeDialectInterface *iface =
265 builtin->getRegisteredInterface<BytecodeDialectInterface>();
266 auto i32Type = IntegerType::get(op->getContext(), 32,
267 IntegerType::SignednessSemantics::Signless);
268 BytecodeWriterConfig writeConfig;
269 writeConfig.attachAttributeCallback(
270 emitFn: [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
271 DialectBytecodeWriter &writer) -> LogicalResult {
272 // Emit TestIntegerType using the builtin dialect encoding.
273 if (auto testParamAttrs =
274 llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
275 llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
276 // Specify that this attribute will need to be written as part of
277 // the builtin group. This will override the default dialect group
278 // of the attribute (test).
279 dialectGroupName = StringLiteral("builtin");
280 auto denseAttr = DenseIntElementsAttr::get(
281 RankedTensorType::get({2}, i32Type),
282 {testParamAttrs.getV0(), testParamAttrs.getV1()});
283 if (succeeded(iface->writeAttribute(attr: denseAttr, writer)))
284 return success();
285 }
286 return failure();
287 });
288 // We natively parse the attribute as a builtin, so no callback needed.
289 ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
290 doRoundtripWithConfigs(op, writeConfig, parseConfig);
291 }
292
293 // Test4: When writing bytecode, we write standard builtin
294 // DenseIntElementsAttr. At parsing, we use the encoding of
295 // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
296 // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
297 // TestAttrParamsAttr and return it.
298 void runTest4(Operation *op) {
299 auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
300 BytecodeDialectInterface *iface =
301 builtin->getRegisteredInterface<BytecodeDialectInterface>();
302 auto i32Type = IntegerType::get(op->getContext(), 32,
303 IntegerType::SignednessSemantics::Signless);
304 BytecodeWriterConfig writeConfig;
305 ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
306 parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
307 parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName,
308 Attribute &entry) -> LogicalResult {
309 // Override only the case where the return type of the builtin reader
310 // is an i32 and fall through on all the other cases, since we want to
311 // still use TestDialect normal codepath to parse the other types.
312 Attribute builtinAttr = iface->readAttribute(reader);
313 if (auto denseAttr =
314 llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
315 if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
316 denseAttr.getElementType() == i32Type) {
317 llvm::outs()
318 << "Overriding parsing of TestAttrParamsAttr encoding...\n";
319 int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
320 int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
321 entry =
322 test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
323 }
324 }
325 return success();
326 });
327 doRoundtripWithConfigs(op, writeConfig, parseConfig);
328 }
329
330 // Test5: When writing bytecode, we want TestDialect to use nothing else than
331 // the builtin types and attributes and take full control of the encoding,
332 // returning failure if any type or attribute is not part of builtin.
333 void runTest5(Operation *op) {
334 auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>();
335 BytecodeDialectInterface *iface =
336 builtin->getRegisteredInterface<BytecodeDialectInterface>();
337 BytecodeWriterConfig writeConfig;
338 writeConfig.attachAttributeCallback(
339 emitFn: [&](Attribute attr, std::optional<StringRef> &dialectGroupName,
340 DialectBytecodeWriter &writer) -> LogicalResult {
341 return iface->writeAttribute(attr, writer);
342 });
343 writeConfig.attachTypeCallback(
344 emitFn: [&](Type type, std::optional<StringRef> &dialectGroupName,
345 DialectBytecodeWriter &writer) -> LogicalResult {
346 return iface->writeType(type, writer);
347 });
348 ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
349 parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
350 parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName,
351 Attribute &entry) -> LogicalResult {
352 Attribute builtinAttr = iface->readAttribute(reader);
353 if (!builtinAttr)
354 return failure();
355 entry = builtinAttr;
356 return success();
357 });
358 parseConfig.getBytecodeReaderConfig().attachTypeCallback(
359 parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName,
360 Type &entry) -> LogicalResult {
361 Type builtinType = iface->readType(reader);
362 if (!builtinType) {
363 return failure();
364 }
365 entry = builtinType;
366 return success();
367 });
368 doRoundtripWithConfigs(op, writeConfig, parseConfig);
369 }
370
371 LogicalResult downgradeToVersion(Operation *op,
372 const test::TestDialectVersion &version) {
373 if ((version.major_ == 2) && (version.minor_ == 0))
374 return success();
375 if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
376 return op->emitError() << "current test dialect version is 2.0, "
377 "can't downgrade to version: "
378 << version.major_ << "." << version.minor_;
379 }
380 // Prior version 2.0, the old op supported only a single attribute called
381 // "dimensions". We need to check that the modifier is false, otherwise we
382 // can't do the downgrade.
383 auto status = op->walk(callback: [&](test::TestVersionedOpA op) {
384 auto &prop = op.getProperties();
385 if (prop.modifier.getValue()) {
386 op->emitOpError() << "cannot downgrade to version " << version.major_
387 << "." << version.minor_
388 << " since the modifier is not compatible";
389 return WalkResult::interrupt();
390 }
391 llvm::outs() << "downgrading op...\n";
392 return WalkResult::advance();
393 });
394 return failure(isFailure: status.wasInterrupted());
395 }
396
397 // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and
398 // upgrade IR when back in memory. The module is expected to be unmodified at
399 // the end of the function.
400 void runTest6(Operation *op) {
401 test::TestDialectVersion targetEmissionVersion = targetVersion;
402
403 // Downgrade IR constructs before writing the IR to bytecode.
404 auto status = downgradeToVersion(op, version: targetEmissionVersion);
405 assert(succeeded(status) && "expected the downgrade to succeed");
406 (void)status;
407
408 BytecodeWriterConfig writeConfig;
409 writeConfig.setDialectVersion<test::TestDialect>(
410 std::make_unique<test::TestDialectVersion>(args&: targetEmissionVersion));
411 ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
412 doRoundtripWithConfigs(op, writeConfig, parseConfig);
413 }
414
415 test::TestDialect *testDialect;
416};
417} // namespace
418
419namespace mlir {
420void registerTestBytecodeRoundtripPasses() {
421 PassRegistration<TestBytecodeRoundtripPass>();
422}
423} // namespace mlir
424

source code of mlir/test/lib/IR/TestBytecodeRoundtrip.cpp