1//===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Interfaces/FoldInterfaces.h"
12#include "mlir/Reducer/ReductionPatternInterface.h"
13#include "mlir/Transforms/InliningUtils.h"
14
15using namespace mlir;
16using namespace test;
17
18//===----------------------------------------------------------------------===//
19// TestDialect Interfaces
20//===----------------------------------------------------------------------===//
21
22namespace {
23
24/// Testing the correctness of some traits.
25static_assert(
26 llvm::is_detected<OpTrait::has_implicit_terminator_t,
27 SingleBlockImplicitTerminatorOp>::value,
28 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
29static_assert(OpTrait::hasSingleBlockImplicitTerminator<
30 SingleBlockImplicitTerminatorOp>::value,
31 "hasSingleBlockImplicitTerminator does not match "
32 "SingleBlockImplicitTerminatorOp");
33
34struct TestResourceBlobManagerInterface
35 : public ResourceBlobManagerDialectInterfaceBase<
36 TestDialectResourceBlobHandle> {
37 using ResourceBlobManagerDialectInterfaceBase<
38 TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
39};
40
/// Tags identifying which test type/attribute follows in the bytecode stream.
/// The tags are written and compared as varints read into `uint64_t`, so the
/// underlying type is fixed to match (no signed/unsigned mixing).
/// No nested anonymous namespace is needed: this is already file-local.
enum test_encoding : uint64_t { k_attr_params = 0, k_test_i32 = 99 };
44
45// Test support for interacting with the Bytecode reader/writer.
46struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
47 using BytecodeDialectInterface::BytecodeDialectInterface;
48 TestBytecodeDialectInterface(Dialect *dialect)
49 : BytecodeDialectInterface(dialect) {}
50
51 LogicalResult writeType(Type type,
52 DialectBytecodeWriter &writer) const final {
53 if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
54 writer.writeVarInt(value: test_encoding::k_test_i32);
55 return success();
56 }
57 return failure();
58 }
59
60 Type readType(DialectBytecodeReader &reader) const final {
61 uint64_t encoding;
62 if (failed(result: reader.readVarInt(result&: encoding)))
63 return Type();
64 if (encoding == test_encoding::k_test_i32)
65 return TestI32Type::get(getContext());
66 return Type();
67 }
68
69 LogicalResult writeAttribute(Attribute attr,
70 DialectBytecodeWriter &writer) const final {
71 if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
72 writer.writeVarInt(value: test_encoding::k_attr_params);
73 writer.writeVarInt(value: concreteAttr.getV0());
74 writer.writeVarInt(value: concreteAttr.getV1());
75 return success();
76 }
77 return failure();
78 }
79
80 Attribute readAttribute(DialectBytecodeReader &reader) const final {
81 auto versionOr = reader.getDialectVersion<test::TestDialect>();
82 // Assume current version if not available through the reader.
83 const auto version =
84 (succeeded(result: versionOr))
85 ? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
86 : TestDialectVersion();
87 if (version.major_ < 2)
88 return readAttrOldEncoding(reader);
89 if (version.major_ == 2 && version.minor_ == 0)
90 return readAttrNewEncoding(reader);
91 // Forbid reading future versions by returning nullptr.
92 return Attribute();
93 }
94
95 // Emit a specific version of the dialect.
96 void writeVersion(DialectBytecodeWriter &writer) const final {
97 // Construct the current dialect version.
98 test::TestDialectVersion versionToEmit;
99
100 // Check if a target version to emit was specified on the writer configs.
101 auto versionOr = writer.getDialectVersion<test::TestDialect>();
102 if (succeeded(result: versionOr))
103 versionToEmit =
104 *reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
105 writer.writeVarInt(value: versionToEmit.major_); // major
106 writer.writeVarInt(value: versionToEmit.minor_); // minor
107 }
108
109 std::unique_ptr<DialectVersion>
110 readVersion(DialectBytecodeReader &reader) const final {
111 uint64_t major_, minor_;
112 if (failed(result: reader.readVarInt(result&: major_)) || failed(result: reader.readVarInt(result&: minor_)))
113 return nullptr;
114 auto version = std::make_unique<TestDialectVersion>();
115 version->major_ = major_;
116 version->minor_ = minor_;
117 return version;
118 }
119
120 LogicalResult upgradeFromVersion(Operation *topLevelOp,
121 const DialectVersion &version_) const final {
122 const auto &version = static_cast<const TestDialectVersion &>(version_);
123 if ((version.major_ == 2) && (version.minor_ == 0))
124 return success();
125 if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
126 return topLevelOp->emitError()
127 << "current test dialect version is 2.0, can't parse version: "
128 << version.major_ << "." << version.minor_;
129 }
130 // Prior version 2.0, the old op supported only a single attribute called
131 // "dimensions". We can perform the upgrade.
132 topLevelOp->walk(callback: [](TestVersionedOpA op) {
133 // Prior version 2.0, `readProperties` did not process the modifier
134 // attribute. Handle that according to the version here.
135 auto &prop = op.getProperties();
136 prop.modifier = BoolAttr::get(context: op->getContext(), value: false);
137 });
138 return success();
139 }
140
141private:
142 Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
143 uint64_t encoding;
144 if (failed(result: reader.readVarInt(result&: encoding)) ||
145 encoding != test_encoding::k_attr_params)
146 return Attribute();
147 // The new encoding has v0 first, v1 second.
148 uint64_t v0, v1;
149 if (failed(result: reader.readVarInt(result&: v0)) || failed(result: reader.readVarInt(result&: v1)))
150 return Attribute();
151 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
152 static_cast<int>(v1));
153 }
154
155 Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
156 uint64_t encoding;
157 if (failed(result: reader.readVarInt(result&: encoding)) ||
158 encoding != test_encoding::k_attr_params)
159 return Attribute();
160 // The old encoding has v1 first, v0 second.
161 uint64_t v0, v1;
162 if (failed(result: reader.readVarInt(result&: v1)) || failed(result: reader.readVarInt(result&: v0)))
163 return Attribute();
164 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
165 static_cast<int>(v1));
166 }
167};
168
169// Test support for interacting with the AsmPrinter.
170struct TestOpAsmInterface : public OpAsmDialectInterface {
171 using OpAsmDialectInterface::OpAsmDialectInterface;
172 TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
173 : OpAsmDialectInterface(dialect), blobManager(mgr) {}
174
175 //===------------------------------------------------------------------===//
176 // Aliases
177 //===------------------------------------------------------------------===//
178
179 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
180 StringAttr strAttr = dyn_cast<StringAttr>(attr);
181 if (!strAttr)
182 return AliasResult::NoAlias;
183
184 // Check the contents of the string attribute to see what the test alias
185 // should be named.
186 std::optional<StringRef> aliasName =
187 StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188 .Case(S: "alias_test:dot_in_name", Value: StringRef("test.alias"))
189 .Case(S: "alias_test:trailing_digit", Value: StringRef("test_alias0"))
190 .Case(S: "alias_test:prefixed_digit", Value: StringRef("0_test_alias"))
191 .Case(S: "alias_test:sanitize_conflict_a",
192 Value: StringRef("test_alias_conflict0"))
193 .Case(S: "alias_test:sanitize_conflict_b",
194 Value: StringRef("test_alias_conflict0_"))
195 .Case(S: "alias_test:tensor_encoding", Value: StringRef("test_encoding"))
196 .Default(Value: std::nullopt);
197 if (!aliasName)
198 return AliasResult::NoAlias;
199
200 os << *aliasName;
201 return AliasResult::FinalAlias;
202 }
203
204 AliasResult getAlias(Type type, raw_ostream &os) const final {
205 if (auto tupleType = dyn_cast<TupleType>(type)) {
206 if (tupleType.size() > 0 &&
207 llvm::all_of(tupleType.getTypes(), [](Type elemType) {
208 return isa<SimpleAType>(elemType);
209 })) {
210 os << "test_tuple";
211 return AliasResult::FinalAlias;
212 }
213 }
214 if (auto intType = dyn_cast<TestIntegerType>(type)) {
215 if (intType.getSignedness() ==
216 TestIntegerType::SignednessSemantics::Unsigned &&
217 intType.getWidth() == 8) {
218 os << "test_ui8";
219 return AliasResult::FinalAlias;
220 }
221 }
222 if (auto recType = dyn_cast<TestRecursiveType>(type)) {
223 if (recType.getName() == "type_to_alias") {
224 // We only make alias for a specific recursive type.
225 os << "testrec";
226 return AliasResult::FinalAlias;
227 }
228 }
229 if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
230 os << recAliasType.getName();
231 return AliasResult::FinalAlias;
232 }
233 return AliasResult::NoAlias;
234 }
235
236 //===------------------------------------------------------------------===//
237 // Resources
238 //===------------------------------------------------------------------===//
239
240 std::string
241 getResourceKey(const AsmDialectResourceHandle &handle) const override {
242 return cast<TestDialectResourceBlobHandle>(Val: handle).getKey().str();
243 }
244
245 FailureOr<AsmDialectResourceHandle>
246 declareResource(StringRef key) const final {
247 return blobManager.insert(name: key);
248 }
249
250 LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
251 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
252 if (failed(result: blob))
253 return failure();
254
255 // Update the blob for this entry.
256 blobManager.update(name: entry.getKey(), newBlob: std::move(*blob));
257 return success();
258 }
259
260 void
261 buildResources(Operation *op,
262 const SetVector<AsmDialectResourceHandle> &referencedResources,
263 AsmResourceBuilder &provider) const final {
264 blobManager.buildResources(provider, referencedResources: referencedResources.getArrayRef());
265 }
266
267private:
268 /// The blob manager for the dialect.
269 TestResourceBlobManagerInterface &blobManager;
270};
271
272struct TestDialectFoldInterface : public DialectFoldInterface {
273 using DialectFoldInterface::DialectFoldInterface;
274
275 /// Registered hook to check if the given region, which is attached to an
276 /// operation that is *not* isolated from above, should be used when
277 /// materializing constants.
278 bool shouldMaterializeInto(Region *region) const final {
279 // If this is a one region operation, then insert into it.
280 return isa<OneRegionOp>(region->getParentOp());
281 }
282};
283
284/// This class defines the interface for handling inlining with standard
285/// operations.
286struct TestInlinerInterface : public DialectInlinerInterface {
287 using DialectInlinerInterface::DialectInlinerInterface;
288
289 //===--------------------------------------------------------------------===//
290 // Analysis Hooks
291 //===--------------------------------------------------------------------===//
292
293 bool isLegalToInline(Operation *call, Operation *callable,
294 bool wouldBeCloned) const final {
295 // Don't allow inlining calls that are marked `noinline`.
296 return !call->hasAttr(name: "noinline");
297 }
298 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
299 // Inlining into test dialect regions is legal.
300 return true;
301 }
302 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
303 return true;
304 }
305
306 bool shouldAnalyzeRecursively(Operation *op) const final {
307 // Analyze recursively if this is not a functional region operation, it
308 // froms a separate functional scope.
309 return !isa<FunctionalRegionOp>(op);
310 }
311
312 //===--------------------------------------------------------------------===//
313 // Transformation Hooks
314 //===--------------------------------------------------------------------===//
315
316 /// Handle the given inlined terminator by replacing it with a new operation
317 /// as necessary.
318 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
319 // Only handle "test.return" here.
320 auto returnOp = dyn_cast<TestReturnOp>(op);
321 if (!returnOp)
322 return;
323
324 // Replace the values directly with the return operands.
325 assert(returnOp.getNumOperands() == valuesToRepl.size());
326 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
327 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
328 }
329
330 /// Attempt to materialize a conversion for a type mismatch between a call
331 /// from this dialect, and a callable region. This method should generate an
332 /// operation that takes 'input' as the only operand, and produces a single
333 /// result of 'resultType'. If a conversion can not be generated, nullptr
334 /// should be returned.
335 Operation *materializeCallConversion(OpBuilder &builder, Value input,
336 Type resultType,
337 Location conversionLoc) const final {
338 // Only allow conversion for i16/i32 types.
339 if (!(resultType.isSignlessInteger(width: 16) ||
340 resultType.isSignlessInteger(width: 32)) ||
341 !(input.getType().isSignlessInteger(width: 16) ||
342 input.getType().isSignlessInteger(width: 32)))
343 return nullptr;
344 return builder.create<TestCastOp>(conversionLoc, resultType, input);
345 }
346
347 Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
348 Value argument,
349 DictionaryAttr argumentAttrs) const final {
350 if (!argumentAttrs.contains("test.handle_argument"))
351 return argument;
352 return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(),
353 argument);
354 }
355
356 Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
357 Value result, DictionaryAttr resultAttrs) const final {
358 if (!resultAttrs.contains("test.handle_result"))
359 return result;
360 return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(),
361 result);
362 }
363
364 void processInlinedCallBlocks(
365 Operation *call,
366 iterator_range<Region::iterator> inlinedBlocks) const final {
367 if (!isa<ConversionCallOp>(call))
368 return;
369
370 // Set attributed on all ops in the inlined blocks.
371 for (Block &block : inlinedBlocks) {
372 block.walk(callback: [&](Operation *op) {
373 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
374 });
375 }
376 }
377};
378
379struct TestReductionPatternInterface : public DialectReductionPatternInterface {
380public:
381 TestReductionPatternInterface(Dialect *dialect)
382 : DialectReductionPatternInterface(dialect) {}
383
384 void populateReductionPatterns(RewritePatternSet &patterns) const final {
385 populateTestReductionPatterns(patterns);
386 }
387};
388
389} // namespace
390
391void TestDialect::registerInterfaces() {
392 auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
393 addInterface<TestOpAsmInterface>(blobInterface);
394
395 addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
396 TestReductionPatternInterface, TestBytecodeDialectInterface>();
397}
398

source code of mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp