| 1 | //===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "TestDialect.h" |
| 10 | #include "TestOps.h" |
| 11 | #include "mlir/Bytecode/BytecodeReader.h" |
| 12 | #include "mlir/Bytecode/BytecodeWriter.h" |
| 13 | #include "mlir/IR/BuiltinOps.h" |
| 14 | #include "mlir/IR/OperationSupport.h" |
| 15 | #include "mlir/Parser/Parser.h" |
| 16 | #include "mlir/Pass/Pass.h" |
| 17 | #include "llvm/Support/CommandLine.h" |
| 18 | #include "llvm/Support/MemoryBufferRef.h" |
| 19 | #include "llvm/Support/raw_ostream.h" |
| 20 | #include <list> |
| 21 | |
| 22 | using namespace mlir; |
| 23 | using namespace llvm; |
| 24 | |
| 25 | namespace { |
| 26 | class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> { |
| 27 | public: |
| 28 | TestDialectVersionParser(cl::Option &o) |
| 29 | : cl::parser<test::TestDialectVersion>(o) {} |
| 30 | |
| 31 | bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg, |
| 32 | test::TestDialectVersion &v) { |
| 33 | long long major, minor; |
| 34 | if (getAsSignedInteger(Str: arg.split(Separator: "." ).first, Radix: 10, Result&: major)) |
| 35 | return o.error(Message: "Invalid argument '" + arg); |
| 36 | if (getAsSignedInteger(Str: arg.split(Separator: "." ).second, Radix: 10, Result&: minor)) |
| 37 | return o.error(Message: "Invalid argument '" + arg); |
| 38 | v = test::TestDialectVersion(major, minor); |
| 39 | // Returns true on error. |
| 40 | return false; |
| 41 | } |
| 42 | static void print(raw_ostream &os, const test::TestDialectVersion &v) { |
| 43 | os << v.major_ << "." << v.minor_; |
| 44 | }; |
| 45 | }; |
| 46 | |
| 47 | /// This is a test pass which uses callbacks to encode attributes and types in a |
| 48 | /// custom fashion. |
| 49 | struct TestBytecodeRoundtripPass |
| 50 | : public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> { |
| 51 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass) |
| 52 | |
| 53 | StringRef getArgument() const final { return "test-bytecode-roundtrip" ; } |
| 54 | StringRef getDescription() const final { |
| 55 | return "Test pass to implement bytecode roundtrip tests." ; |
| 56 | } |
| 57 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 58 | registry.insert<test::TestDialect>(); |
| 59 | } |
| 60 | TestBytecodeRoundtripPass() = default; |
| 61 | TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {} |
| 62 | |
| 63 | LogicalResult initialize(MLIRContext *context) override { |
| 64 | testDialect = context->getOrLoadDialect<test::TestDialect>(); |
| 65 | return success(); |
| 66 | } |
| 67 | |
| 68 | void runOnOperation() override { |
| 69 | switch (testKind) { |
| 70 | // Tests 0-5 implement a custom roundtrip with callbacks. |
| 71 | case (0): |
| 72 | return runTest0(getOperation()); |
| 73 | case (1): |
| 74 | return runTest1(getOperation()); |
| 75 | case (2): |
| 76 | return runTest2(getOperation()); |
| 77 | case (3): |
| 78 | return runTest3(getOperation()); |
| 79 | case (4): |
| 80 | return runTest4(getOperation()); |
| 81 | case (5): |
| 82 | return runTest5(getOperation()); |
| 83 | case (6): |
| 84 | // test-kind 6 is a plain roundtrip with downgrade/upgrade to/from |
| 85 | // `targetVersion`. |
| 86 | return runTest6(getOperation()); |
| 87 | default: |
| 88 | llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass" ); |
| 89 | } |
| 90 | } |
| 91 | |
| 92 | mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser> |
| 93 | targetVersion{*this, "test-dialect-version" , |
| 94 | llvm::cl::desc( |
| 95 | "Specifies the test dialect version to emit and parse" ), |
| 96 | cl::init(Val: test::TestDialectVersion())}; |
| 97 | |
| 98 | mlir::Pass::Option<int> testKind{ |
| 99 | *this, "test-kind" , llvm::cl::desc("Specifies the test kind to execute" ), |
| 100 | cl::init(Val: 0)}; |
| 101 | |
| 102 | private: |
| 103 | void doRoundtripWithConfigs(Operation *op, |
| 104 | const BytecodeWriterConfig &writeConfig, |
| 105 | const ParserConfig &parseConfig) { |
| 106 | std::string bytecode; |
| 107 | llvm::raw_string_ostream os(bytecode); |
| 108 | if (failed(Result: writeBytecodeToFile(op, os, config: writeConfig))) { |
| 109 | op->emitError() << "failed to write bytecode\n" ; |
| 110 | signalPassFailure(); |
| 111 | return; |
| 112 | } |
| 113 | auto newModuleOp = parseSourceString(sourceStr: StringRef(bytecode), config: parseConfig); |
| 114 | if (!newModuleOp.get()) { |
| 115 | op->emitError() << "failed to read bytecode\n" ; |
| 116 | signalPassFailure(); |
| 117 | return; |
| 118 | } |
| 119 | // Print the module to the output stream, so that we can filecheck the |
| 120 | // result. |
| 121 | newModuleOp->print(os&: llvm::outs()); |
| 122 | } |
| 123 | |
| 124 | // Test0: let's assume that versions older than 2.0 were relying on a special |
| 125 | // integer attribute of a deprecated dialect called "funky". Assume that its |
| 126 | // encoding was made by two varInts, the first was the ID (999) and the second |
| 127 | // contained width and signedness info. We can emit it using a callback |
| 128 | // writing a custom encoding for the "funky" dialect group, and parse it back |
| 129 | // with a custom parser reading the same encoding in the same dialect group. |
| 130 | // Note that the ID 999 does not correspond to a valid integer type in the |
| 131 | // current encodings of builtin types. |
| 132 | void runTest0(Operation *op) { |
| 133 | auto newCtx = std::make_shared<MLIRContext>(); |
| 134 | test::TestDialectVersion targetEmissionVersion = targetVersion; |
| 135 | BytecodeWriterConfig writeConfig; |
| 136 | // Set the emission version for the test dialect. |
| 137 | writeConfig.setDialectVersion<test::TestDialect>( |
| 138 | std::make_unique<test::TestDialectVersion>(args&: targetEmissionVersion)); |
| 139 | writeConfig.attachTypeCallback( |
| 140 | emitFn: [&](Type entryValue, std::optional<StringRef> &dialectGroupName, |
| 141 | DialectBytecodeWriter &writer) -> LogicalResult { |
| 142 | // Do not override anything if version greater than 2.0. |
| 143 | auto versionOr = writer.getDialectVersion<test::TestDialect>(); |
| 144 | assert(succeeded(versionOr) && "expected reader to be able to access " |
| 145 | "the version for test dialect" ); |
| 146 | const auto *version = |
| 147 | reinterpret_cast<const test::TestDialectVersion *>(*versionOr); |
| 148 | if (version->major_ >= 2) |
| 149 | return failure(); |
| 150 | |
| 151 | // For version less than 2.0, override the encoding of IntegerType. |
| 152 | if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) { |
| 153 | llvm::outs() << "Overriding IntegerType encoding...\n" ; |
| 154 | dialectGroupName = StringLiteral("funky" ); |
| 155 | writer.writeVarInt(/* IntegerType */ value: 999); |
| 156 | writer.writeVarInt(value: type.getWidth() << 2 | type.getSignedness()); |
| 157 | return success(); |
| 158 | } |
| 159 | return failure(); |
| 160 | }); |
| 161 | newCtx->appendDialectRegistry(registry: op->getContext()->getDialectRegistry()); |
| 162 | newCtx->allowUnregisteredDialects(); |
| 163 | ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true); |
| 164 | parseConfig.getBytecodeReaderConfig().attachTypeCallback( |
| 165 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
| 166 | Type &entry) -> LogicalResult { |
| 167 | // Get test dialect version from the version map. |
| 168 | auto versionOr = reader.getDialectVersion<test::TestDialect>(); |
| 169 | assert(succeeded(versionOr) && "expected reader to be able to access " |
| 170 | "the version for test dialect" ); |
| 171 | const auto *version = |
| 172 | reinterpret_cast<const test::TestDialectVersion *>(*versionOr); |
| 173 | if (version->major_ >= 2) |
| 174 | return success(); |
| 175 | |
| 176 | // `dialectName` is the name of the group we have the opportunity to |
| 177 | // override. In this case, override only the dialect group "funky", |
| 178 | // for which does not exist in memory. |
| 179 | if (dialectName != StringLiteral("funky" )) |
| 180 | return success(); |
| 181 | |
| 182 | uint64_t encoding; |
| 183 | if (failed(Result: reader.readVarInt(result&: encoding)) || encoding != 999) |
| 184 | return success(); |
| 185 | llvm::outs() << "Overriding parsing of IntegerType encoding...\n" ; |
| 186 | uint64_t widthAndSignedness, width; |
| 187 | IntegerType::SignednessSemantics signedness; |
| 188 | if (succeeded(reader.readVarInt(widthAndSignedness)) && |
| 189 | ((width = widthAndSignedness >> 2), true) && |
| 190 | ((signedness = static_cast<IntegerType::SignednessSemantics>( |
| 191 | widthAndSignedness & 0x3)), |
| 192 | true)) |
| 193 | entry = IntegerType::get(reader.getContext(), width, signedness); |
| 194 | // Return nullopt to fall through the rest of the parsing code path. |
| 195 | return success(); |
| 196 | }); |
| 197 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
| 198 | } |
| 199 | |
| 200 | // Test1: When writing bytecode, we override the encoding of TestI32Type with |
| 201 | // the encoding of builtin IntegerType. We can natively parse this without |
| 202 | // the use of a callback, relying on the existing builtin reader mechanism. |
| 203 | void runTest1(Operation *op) { |
| 204 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
| 205 | BytecodeDialectInterface *iface = |
| 206 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
| 207 | BytecodeWriterConfig writeConfig; |
| 208 | writeConfig.attachTypeCallback( |
| 209 | emitFn: [&](Type entryValue, std::optional<StringRef> &dialectGroupName, |
| 210 | DialectBytecodeWriter &writer) -> LogicalResult { |
| 211 | // Emit TestIntegerType using the builtin dialect encoding. |
| 212 | if (llvm::isa<test::TestI32Type>(entryValue)) { |
| 213 | llvm::outs() << "Overriding TestI32Type encoding...\n" ; |
| 214 | auto builtinI32Type = |
| 215 | IntegerType::get(op->getContext(), 32, |
| 216 | IntegerType::SignednessSemantics::Signless); |
| 217 | // Specify that this type will need to be written as part of the |
| 218 | // builtin group. This will override the default dialect group of |
| 219 | // the attribute (test). |
| 220 | dialectGroupName = StringLiteral("builtin" ); |
| 221 | if (succeeded(iface->writeType(type: builtinI32Type, writer))) |
| 222 | return success(); |
| 223 | } |
| 224 | return failure(); |
| 225 | }); |
| 226 | // We natively parse the attribute as a builtin, so no callback needed. |
| 227 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); |
| 228 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
| 229 | } |
| 230 | |
| 231 | // Test2: When writing bytecode, we write standard builtin IntegerTypes. At |
| 232 | // parsing, we use the encoding of IntegerType to intercept all i32. Then, |
| 233 | // instead of creating i32s, we assemble TestI32Type and return it. |
| 234 | void runTest2(Operation *op) { |
| 235 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
| 236 | BytecodeDialectInterface *iface = |
| 237 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
| 238 | BytecodeWriterConfig writeConfig; |
| 239 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); |
| 240 | parseConfig.getBytecodeReaderConfig().attachTypeCallback( |
| 241 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
| 242 | Type &entry) -> LogicalResult { |
| 243 | if (dialectName != StringLiteral("builtin" )) |
| 244 | return success(); |
| 245 | Type builtinAttr = iface->readType(reader); |
| 246 | if (auto integerType = |
| 247 | llvm::dyn_cast_or_null<IntegerType>(builtinAttr)) { |
| 248 | if (integerType.getWidth() == 32 && integerType.isSignless()) { |
| 249 | llvm::outs() << "Overriding parsing of TestI32Type encoding...\n" ; |
| 250 | entry = test::TestI32Type::get(reader.getContext()); |
| 251 | } |
| 252 | } |
| 253 | return success(); |
| 254 | }); |
| 255 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
| 256 | } |
| 257 | |
| 258 | // Test3: When writing bytecode, we override the encoding of |
| 259 | // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We |
| 260 | // can natively parse this without the use of a callback, relying on the |
| 261 | // existing builtin reader mechanism. |
| 262 | void runTest3(Operation *op) { |
| 263 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
| 264 | BytecodeDialectInterface *iface = |
| 265 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
| 266 | auto i32Type = IntegerType::get(op->getContext(), 32, |
| 267 | IntegerType::SignednessSemantics::Signless); |
| 268 | BytecodeWriterConfig writeConfig; |
| 269 | writeConfig.attachAttributeCallback( |
| 270 | emitFn: [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName, |
| 271 | DialectBytecodeWriter &writer) -> LogicalResult { |
| 272 | // Emit TestIntegerType using the builtin dialect encoding. |
| 273 | if (auto testParamAttrs = |
| 274 | llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) { |
| 275 | llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n" ; |
| 276 | // Specify that this attribute will need to be written as part of |
| 277 | // the builtin group. This will override the default dialect group |
| 278 | // of the attribute (test). |
| 279 | dialectGroupName = StringLiteral("builtin" ); |
| 280 | auto denseAttr = DenseIntElementsAttr::get( |
| 281 | RankedTensorType::get({2}, i32Type), |
| 282 | {testParamAttrs.getV0(), testParamAttrs.getV1()}); |
| 283 | if (succeeded(iface->writeAttribute(attr: denseAttr, writer))) |
| 284 | return success(); |
| 285 | } |
| 286 | return failure(); |
| 287 | }); |
| 288 | // We natively parse the attribute as a builtin, so no callback needed. |
| 289 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); |
| 290 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
| 291 | } |
| 292 | |
| 293 | // Test4: When writing bytecode, we write standard builtin |
| 294 | // DenseIntElementsAttr. At parsing, we use the encoding of |
| 295 | // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of |
| 296 | // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble |
| 297 | // TestAttrParamsAttr and return it. |
| 298 | void runTest4(Operation *op) { |
| 299 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
| 300 | BytecodeDialectInterface *iface = |
| 301 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
| 302 | auto i32Type = IntegerType::get(op->getContext(), 32, |
| 303 | IntegerType::SignednessSemantics::Signless); |
| 304 | BytecodeWriterConfig writeConfig; |
| 305 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); |
| 306 | parseConfig.getBytecodeReaderConfig().attachAttributeCallback( |
| 307 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
| 308 | Attribute &entry) -> LogicalResult { |
| 309 | // Override only the case where the return type of the builtin reader |
| 310 | // is an i32 and fall through on all the other cases, since we want to |
| 311 | // still use TestDialect normal codepath to parse the other types. |
| 312 | Attribute builtinAttr = iface->readAttribute(reader); |
| 313 | if (auto denseAttr = |
| 314 | llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) { |
| 315 | if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) && |
| 316 | denseAttr.getElementType() == i32Type) { |
| 317 | llvm::outs() |
| 318 | << "Overriding parsing of TestAttrParamsAttr encoding...\n" ; |
| 319 | int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt(); |
| 320 | int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt(); |
| 321 | entry = |
| 322 | test::TestAttrParamsAttr::get(reader.getContext(), v0, v1); |
| 323 | } |
| 324 | } |
| 325 | return success(); |
| 326 | }); |
| 327 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
| 328 | } |
| 329 | |
| 330 | // Test5: When writing bytecode, we want TestDialect to use nothing else than |
| 331 | // the builtin types and attributes and take full control of the encoding, |
| 332 | // returning failure if any type or attribute is not part of builtin. |
| 333 | void runTest5(Operation *op) { |
| 334 | auto *builtin = op->getContext()->getLoadedDialect<mlir::BuiltinDialect>(); |
| 335 | BytecodeDialectInterface *iface = |
| 336 | builtin->getRegisteredInterface<BytecodeDialectInterface>(); |
| 337 | BytecodeWriterConfig writeConfig; |
| 338 | writeConfig.attachAttributeCallback( |
| 339 | emitFn: [&](Attribute attr, std::optional<StringRef> &dialectGroupName, |
| 340 | DialectBytecodeWriter &writer) -> LogicalResult { |
| 341 | return iface->writeAttribute(attr, writer); |
| 342 | }); |
| 343 | writeConfig.attachTypeCallback( |
| 344 | emitFn: [&](Type type, std::optional<StringRef> &dialectGroupName, |
| 345 | DialectBytecodeWriter &writer) -> LogicalResult { |
| 346 | return iface->writeType(type, writer); |
| 347 | }); |
| 348 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); |
| 349 | parseConfig.getBytecodeReaderConfig().attachAttributeCallback( |
| 350 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
| 351 | Attribute &entry) -> LogicalResult { |
| 352 | Attribute builtinAttr = iface->readAttribute(reader); |
| 353 | if (!builtinAttr) |
| 354 | return failure(); |
| 355 | entry = builtinAttr; |
| 356 | return success(); |
| 357 | }); |
| 358 | parseConfig.getBytecodeReaderConfig().attachTypeCallback( |
| 359 | parserFn: [&](DialectBytecodeReader &reader, StringRef dialectName, |
| 360 | Type &entry) -> LogicalResult { |
| 361 | Type builtinType = iface->readType(reader); |
| 362 | if (!builtinType) { |
| 363 | return failure(); |
| 364 | } |
| 365 | entry = builtinType; |
| 366 | return success(); |
| 367 | }); |
| 368 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
| 369 | } |
| 370 | |
| 371 | LogicalResult downgradeToVersion(Operation *op, |
| 372 | const test::TestDialectVersion &version) { |
| 373 | if ((version.major_ == 2) && (version.minor_ == 0)) |
| 374 | return success(); |
| 375 | if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) { |
| 376 | return op->emitError() << "current test dialect version is 2.0, " |
| 377 | "can't downgrade to version: " |
| 378 | << version.major_ << "." << version.minor_; |
| 379 | } |
| 380 | // Prior version 2.0, the old op supported only a single attribute called |
| 381 | // "dimensions". We need to check that the modifier is false, otherwise we |
| 382 | // can't do the downgrade. |
| 383 | auto status = op->walk(callback: [&](test::TestVersionedOpA op) { |
| 384 | auto &prop = op.getProperties(); |
| 385 | if (prop.modifier.getValue()) { |
| 386 | op->emitOpError() << "cannot downgrade to version " << version.major_ |
| 387 | << "." << version.minor_ |
| 388 | << " since the modifier is not compatible" ; |
| 389 | return WalkResult::interrupt(); |
| 390 | } |
| 391 | llvm::outs() << "downgrading op...\n" ; |
| 392 | return WalkResult::advance(); |
| 393 | }); |
| 394 | return failure(IsFailure: status.wasInterrupted()); |
| 395 | } |
| 396 | |
| 397 | // Test6: Downgrade IR to `targetVersion`, write to bytecode. Then, read and |
| 398 | // upgrade IR when back in memory. The module is expected to be unmodified at |
| 399 | // the end of the function. |
| 400 | void runTest6(Operation *op) { |
| 401 | test::TestDialectVersion targetEmissionVersion = targetVersion; |
| 402 | |
| 403 | // Downgrade IR constructs before writing the IR to bytecode. |
| 404 | auto status = downgradeToVersion(op, version: targetEmissionVersion); |
| 405 | assert(succeeded(status) && "expected the downgrade to succeed" ); |
| 406 | (void)status; |
| 407 | |
| 408 | BytecodeWriterConfig writeConfig; |
| 409 | writeConfig.setDialectVersion<test::TestDialect>( |
| 410 | std::make_unique<test::TestDialectVersion>(args&: targetEmissionVersion)); |
| 411 | ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); |
| 412 | doRoundtripWithConfigs(op, writeConfig, parseConfig); |
| 413 | } |
| 414 | |
| 415 | test::TestDialect *testDialect; |
| 416 | }; |
| 417 | } // namespace |
| 418 | |
| 419 | namespace mlir { |
| 420 | void registerTestBytecodeRoundtripPasses() { |
| 421 | PassRegistration<TestBytecodeRoundtripPass>(); |
| 422 | } |
| 423 | } // namespace mlir |
| 424 | |