1//===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Interfaces/FoldInterfaces.h"
12#include "mlir/Reducer/ReductionPatternInterface.h"
13#include "mlir/Transforms/InliningUtils.h"
14
15using namespace mlir;
16using namespace test;
17
18//===----------------------------------------------------------------------===//
19// TestDialect Interfaces
20//===----------------------------------------------------------------------===//
21
22namespace {
23
24/// Testing the correctness of some traits.
25static_assert(
26 llvm::is_detected<OpTrait::has_implicit_terminator_t,
27 SingleBlockImplicitTerminatorOp>::value,
28 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
29static_assert(OpTrait::hasSingleBlockImplicitTerminator<
30 SingleBlockImplicitTerminatorOp>::value,
31 "hasSingleBlockImplicitTerminator does not match "
32 "SingleBlockImplicitTerminatorOp");
33
34struct TestResourceBlobManagerInterface
35 : public ResourceBlobManagerDialectInterfaceBase<
36 TestDialectResourceBlobHandle> {
37 using ResourceBlobManagerDialectInterfaceBase<
38 TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
39};
40
/// Tags written by TestBytecodeDialectInterface in front of each encoded
/// entity: `k_attr_params` precedes a TestAttrParamsAttr payload and
/// `k_test_i32` stands for a TestI32Type. The readers compare the first varint
/// they read against these values.
enum test_encoding { k_attr_params = 0, k_test_i32 = 99 };
44
45// Test support for interacting with the Bytecode reader/writer.
46struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
47 using BytecodeDialectInterface::BytecodeDialectInterface;
48 TestBytecodeDialectInterface(Dialect *dialect)
49 : BytecodeDialectInterface(dialect) {}
50
51 LogicalResult writeType(Type type,
52 DialectBytecodeWriter &writer) const final {
53 if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
54 writer.writeVarInt(value: test_encoding::k_test_i32);
55 return success();
56 }
57 return failure();
58 }
59
60 Type readType(DialectBytecodeReader &reader) const final {
61 uint64_t encoding;
62 if (failed(Result: reader.readVarInt(result&: encoding)))
63 return Type();
64 if (encoding == test_encoding::k_test_i32)
65 return TestI32Type::get(getContext());
66 return Type();
67 }
68
69 LogicalResult writeAttribute(Attribute attr,
70 DialectBytecodeWriter &writer) const final {
71 if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
72 writer.writeVarInt(value: test_encoding::k_attr_params);
73 writer.writeVarInt(value: concreteAttr.getV0());
74 writer.writeVarInt(value: concreteAttr.getV1());
75 return success();
76 }
77 return failure();
78 }
79
80 Attribute readAttribute(DialectBytecodeReader &reader) const final {
81 auto versionOr = reader.getDialectVersion<test::TestDialect>();
82 // Assume current version if not available through the reader.
83 const auto version =
84 (succeeded(Result: versionOr))
85 ? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
86 : TestDialectVersion();
87 if (version.major_ < 2)
88 return readAttrOldEncoding(reader);
89 if (version.major_ == 2 && version.minor_ == 0)
90 return readAttrNewEncoding(reader);
91 // Forbid reading future versions by returning nullptr.
92 return Attribute();
93 }
94
95 // Emit a specific version of the dialect.
96 void writeVersion(DialectBytecodeWriter &writer) const final {
97 // Construct the current dialect version.
98 test::TestDialectVersion versionToEmit;
99
100 // Check if a target version to emit was specified on the writer configs.
101 auto versionOr = writer.getDialectVersion<test::TestDialect>();
102 if (succeeded(Result: versionOr))
103 versionToEmit =
104 *reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
105 writer.writeVarInt(value: versionToEmit.major_); // major
106 writer.writeVarInt(value: versionToEmit.minor_); // minor
107 }
108
109 std::unique_ptr<DialectVersion>
110 readVersion(DialectBytecodeReader &reader) const final {
111 uint64_t major_, minor_;
112 if (failed(Result: reader.readVarInt(result&: major_)) || failed(Result: reader.readVarInt(result&: minor_)))
113 return nullptr;
114 auto version = std::make_unique<TestDialectVersion>();
115 version->major_ = major_;
116 version->minor_ = minor_;
117 return version;
118 }
119
120 LogicalResult upgradeFromVersion(Operation *topLevelOp,
121 const DialectVersion &version_) const final {
122 const auto &version = static_cast<const TestDialectVersion &>(version_);
123 if ((version.major_ == 2) && (version.minor_ == 0))
124 return success();
125 if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
126 return topLevelOp->emitError()
127 << "current test dialect version is 2.0, can't parse version: "
128 << version.major_ << "." << version.minor_;
129 }
130 // Prior version 2.0, the old op supported only a single attribute called
131 // "dimensions". We can perform the upgrade.
132 topLevelOp->walk(callback: [](TestVersionedOpA op) {
133 // Prior version 2.0, `readProperties` did not process the modifier
134 // attribute. Handle that according to the version here.
135 auto &prop = op.getProperties();
136 prop.modifier = BoolAttr::get(context: op->getContext(), value: false);
137 });
138 return success();
139 }
140
141private:
142 Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
143 uint64_t encoding;
144 if (failed(Result: reader.readVarInt(result&: encoding)) ||
145 encoding != test_encoding::k_attr_params)
146 return Attribute();
147 // The new encoding has v0 first, v1 second.
148 uint64_t v0, v1;
149 if (failed(Result: reader.readVarInt(result&: v0)) || failed(Result: reader.readVarInt(result&: v1)))
150 return Attribute();
151 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
152 static_cast<int>(v1));
153 }
154
155 Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
156 uint64_t encoding;
157 if (failed(Result: reader.readVarInt(result&: encoding)) ||
158 encoding != test_encoding::k_attr_params)
159 return Attribute();
160 // The old encoding has v1 first, v0 second.
161 uint64_t v0, v1;
162 if (failed(Result: reader.readVarInt(result&: v1)) || failed(Result: reader.readVarInt(result&: v0)))
163 return Attribute();
164 return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
165 static_cast<int>(v1));
166 }
167};
168
169// Test support for interacting with the AsmPrinter.
170struct TestOpAsmInterface : public OpAsmDialectInterface {
171 using OpAsmDialectInterface::OpAsmDialectInterface;
172 TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
173 : OpAsmDialectInterface(dialect), blobManager(mgr) {}
174
175 //===------------------------------------------------------------------===//
176 // Aliases
177 //===------------------------------------------------------------------===//
178
179 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
180 StringAttr strAttr = dyn_cast<StringAttr>(attr);
181 if (!strAttr)
182 return AliasResult::NoAlias;
183
184 // Check the contents of the string attribute to see what the test alias
185 // should be named.
186 std::optional<StringRef> aliasName =
187 StringSwitch<std::optional<StringRef>>(strAttr.getValue())
188 .Case(S: "alias_test:dot_in_name", Value: StringRef("test.alias"))
189 .Case(S: "alias_test:trailing_digit", Value: StringRef("test_alias0"))
190 .Case(S: "alias_test:trailing_digit_conflict_a",
191 Value: StringRef("test_alias_conflict0_1_1_1"))
192 .Case(S: "alias_test:trailing_digit_conflict_b",
193 Value: StringRef("test_alias_conflict0"))
194 .Case(S: "alias_test:trailing_digit_conflict_c",
195 Value: StringRef("test_alias_conflict0"))
196 .Case(S: "alias_test:trailing_digit_conflict_d",
197 Value: StringRef("test_alias_conflict0_"))
198 .Case(S: "alias_test:trailing_digit_conflict_e",
199 Value: StringRef("test_alias_conflict0_1"))
200 .Case(S: "alias_test:trailing_digit_conflict_f",
201 Value: StringRef("test_alias_conflict0_1"))
202 .Case(S: "alias_test:trailing_digit_conflict_g",
203 Value: StringRef("test_alias_conflict0_1_"))
204 .Case(S: "alias_test:trailing_digit_conflict_h",
205 Value: StringRef("test_alias_conflict0_1_1"))
206 .Case(S: "alias_test:prefixed_digit", Value: StringRef("0_test_alias"))
207 .Case(S: "alias_test:prefixed_symbol", Value: StringRef("%test"))
208 .Case(S: "alias_test:tensor_encoding", Value: StringRef("test_encoding"))
209 .Default(Value: std::nullopt);
210 if (!aliasName)
211 return AliasResult::NoAlias;
212
213 os << *aliasName;
214 return AliasResult::FinalAlias;
215 }
216
217 AliasResult getAlias(Type type, raw_ostream &os) const final {
218 if (auto tupleType = dyn_cast<TupleType>(type)) {
219 if (tupleType.size() > 0 &&
220 llvm::all_of(tupleType.getTypes(), [](Type elemType) {
221 return isa<SimpleAType>(elemType);
222 })) {
223 os << "test_tuple";
224 return AliasResult::FinalAlias;
225 }
226 }
227 if (auto intType = dyn_cast<TestIntegerType>(type)) {
228 if (intType.getSignedness() ==
229 TestIntegerType::SignednessSemantics::Unsigned &&
230 intType.getWidth() == 8) {
231 os << "test_ui8";
232 return AliasResult::FinalAlias;
233 }
234 }
235 if (auto recType = dyn_cast<TestRecursiveType>(type)) {
236 if (recType.getName() == "type_to_alias") {
237 // We only make alias for a specific recursive type.
238 os << "testrec";
239 return AliasResult::FinalAlias;
240 }
241 }
242 if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
243 os << recAliasType.getName();
244 return AliasResult::FinalAlias;
245 }
246 return AliasResult::NoAlias;
247 }
248
249 //===------------------------------------------------------------------===//
250 // Resources
251 //===------------------------------------------------------------------===//
252
253 std::string
254 getResourceKey(const AsmDialectResourceHandle &handle) const override {
255 return cast<TestDialectResourceBlobHandle>(Val: handle).getKey().str();
256 }
257
258 FailureOr<AsmDialectResourceHandle>
259 declareResource(StringRef key) const final {
260 return blobManager.insert(name: key);
261 }
262
263 LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
264 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
265 if (failed(Result: blob))
266 return failure();
267
268 // Update the blob for this entry.
269 blobManager.update(name: entry.getKey(), newBlob: std::move(*blob));
270 return success();
271 }
272
273 void
274 buildResources(Operation *op,
275 const SetVector<AsmDialectResourceHandle> &referencedResources,
276 AsmResourceBuilder &provider) const final {
277 blobManager.buildResources(provider, referencedResources: referencedResources.getArrayRef());
278 }
279
280private:
281 /// The blob manager for the dialect.
282 TestResourceBlobManagerInterface &blobManager;
283};
284
285struct TestDialectFoldInterface : public DialectFoldInterface {
286 using DialectFoldInterface::DialectFoldInterface;
287
288 /// Registered hook to check if the given region, which is attached to an
289 /// operation that is *not* isolated from above, should be used when
290 /// materializing constants.
291 bool shouldMaterializeInto(Region *region) const final {
292 // If this is a one region operation, then insert into it.
293 return isa<OneRegionOp>(region->getParentOp());
294 }
295};
296
297/// This class defines the interface for handling inlining with standard
298/// operations.
299struct TestInlinerInterface : public DialectInlinerInterface {
300 using DialectInlinerInterface::DialectInlinerInterface;
301
302 //===--------------------------------------------------------------------===//
303 // Analysis Hooks
304 //===--------------------------------------------------------------------===//
305
306 bool isLegalToInline(Operation *call, Operation *callable,
307 bool wouldBeCloned) const final {
308 // Don't allow inlining calls that are marked `noinline`.
309 return !call->hasAttr(name: "noinline");
310 }
311 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
312 // Inlining into test dialect regions is legal.
313 return true;
314 }
315 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
316 return true;
317 }
318
319 bool shouldAnalyzeRecursively(Operation *op) const final {
320 // Analyze recursively if this is not a functional region operation, it
321 // froms a separate functional scope.
322 return !isa<FunctionalRegionOp>(op);
323 }
324
325 //===--------------------------------------------------------------------===//
326 // Transformation Hooks
327 //===--------------------------------------------------------------------===//
328
329 /// Handle the given inlined terminator by replacing it with a new operation
330 /// as necessary.
331 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
332 // Only handle "test.return" here.
333 auto returnOp = dyn_cast<TestReturnOp>(op);
334 if (!returnOp)
335 return;
336
337 // Replace the values directly with the return operands.
338 assert(returnOp.getNumOperands() == valuesToRepl.size());
339 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
340 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
341 }
342
343 /// Attempt to materialize a conversion for a type mismatch between a call
344 /// from this dialect, and a callable region. This method should generate an
345 /// operation that takes 'input' as the only operand, and produces a single
346 /// result of 'resultType'. If a conversion can not be generated, nullptr
347 /// should be returned.
348 Operation *materializeCallConversion(OpBuilder &builder, Value input,
349 Type resultType,
350 Location conversionLoc) const final {
351 // Only allow conversion for i16/i32 types.
352 if (!(resultType.isSignlessInteger(width: 16) ||
353 resultType.isSignlessInteger(width: 32)) ||
354 !(input.getType().isSignlessInteger(width: 16) ||
355 input.getType().isSignlessInteger(width: 32)))
356 return nullptr;
357 return builder.create<TestCastOp>(conversionLoc, resultType, input);
358 }
359
360 Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
361 Value argument,
362 DictionaryAttr argumentAttrs) const final {
363 if (!argumentAttrs.contains("test.handle_argument"))
364 return argument;
365 return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(),
366 argument);
367 }
368
369 Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
370 Value result, DictionaryAttr resultAttrs) const final {
371 if (!resultAttrs.contains("test.handle_result"))
372 return result;
373 return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(),
374 result);
375 }
376
377 void processInlinedCallBlocks(
378 Operation *call,
379 iterator_range<Region::iterator> inlinedBlocks) const final {
380 if (!isa<ConversionCallOp>(call))
381 return;
382
383 // Set attributed on all ops in the inlined blocks.
384 for (Block &block : inlinedBlocks) {
385 block.walk(callback: [&](Operation *op) {
386 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
387 });
388 }
389 }
390};
391
392struct TestReductionPatternInterface : public DialectReductionPatternInterface {
393public:
394 TestReductionPatternInterface(Dialect *dialect)
395 : DialectReductionPatternInterface(dialect) {}
396
397 void populateReductionPatterns(RewritePatternSet &patterns) const final {
398 populateTestReductionPatterns(patterns);
399 }
400};
401
402} // namespace
403
404void TestDialect::registerInterfaces() {
405 auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
406 addInterface<TestOpAsmInterface>(blobInterface);
407
408 addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
409 TestReductionPatternInterface, TestBytecodeDialectInterface>();
410}
411

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp