1 | //===- TestOpProperties.cpp - Test all properties-related APIs ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/Attributes.h" |
10 | #include "mlir/IR/OpDefinition.h" |
11 | #include "mlir/IR/OperationSupport.h" |
12 | #include "mlir/Parser/Parser.h" |
13 | #include "gtest/gtest.h" |
14 | #include <optional> |
15 | |
16 | using namespace mlir; |
17 | |
18 | namespace { |
19 | /// Simple structure definining a struct to define "properties" for a given |
20 | /// operation. Default values are honored when creating an operation. |
21 | struct TestProperties { |
22 | int a = -1; |
23 | float b = -1.; |
24 | std::vector<int64_t> array = {-33}; |
25 | /// A shared_ptr to a const object is safe: it is equivalent to a value-based |
26 | /// member. Here the label will be deallocated when the last operation |
27 | /// referring to it is destroyed. However there is no pool-allocation: this is |
28 | /// offloaded to the client. |
29 | std::shared_ptr<const std::string> label; |
30 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestProperties) |
31 | }; |
32 | |
33 | bool operator==(const TestProperties &lhs, TestProperties &rhs) { |
34 | return lhs.a == rhs.a && lhs.b == rhs.b && lhs.array == rhs.array && |
35 | lhs.label == rhs.label; |
36 | } |
37 | |
38 | /// Convert a DictionaryAttr to a TestProperties struct, optionally emit errors |
39 | /// through the provided diagnostic if any. This is used for example during |
40 | /// parsing with the generic format. |
41 | static LogicalResult |
42 | setPropertiesFromAttribute(TestProperties &prop, Attribute attr, |
43 | function_ref<InFlightDiagnostic()> emitError) { |
44 | DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr); |
45 | if (!dict) { |
46 | emitError() << "expected DictionaryAttr to set TestProperties" ; |
47 | return failure(); |
48 | } |
49 | auto aAttr = dict.getAs<IntegerAttr>("a" ); |
50 | if (!aAttr) { |
51 | emitError() << "expected IntegerAttr for key `a`" ; |
52 | return failure(); |
53 | } |
54 | auto bAttr = dict.getAs<FloatAttr>("b" ); |
55 | if (!bAttr || |
56 | &bAttr.getValue().getSemantics() != &llvm::APFloatBase::IEEEsingle()) { |
57 | emitError() << "expected FloatAttr for key `b`" ; |
58 | return failure(); |
59 | } |
60 | |
61 | auto arrayAttr = dict.getAs<DenseI64ArrayAttr>("array" ); |
62 | if (!arrayAttr) { |
63 | emitError() << "expected DenseI64ArrayAttr for key `array`" ; |
64 | return failure(); |
65 | } |
66 | |
67 | auto label = dict.getAs<mlir::StringAttr>("label" ); |
68 | if (!label) { |
69 | emitError() << "expected StringAttr for key `label`" ; |
70 | return failure(); |
71 | } |
72 | |
73 | prop.a = aAttr.getValue().getSExtValue(); |
74 | prop.b = bAttr.getValue().convertToFloat(); |
75 | prop.array.assign(arrayAttr.asArrayRef().begin(), |
76 | arrayAttr.asArrayRef().end()); |
77 | prop.label = std::make_shared<std::string>(label.getValue()); |
78 | return success(); |
79 | } |
80 | |
81 | /// Convert a TestProperties struct to a DictionaryAttr, this is used for |
82 | /// example during printing with the generic format. |
83 | static Attribute getPropertiesAsAttribute(MLIRContext *ctx, |
84 | const TestProperties &prop) { |
85 | SmallVector<NamedAttribute> attrs; |
86 | Builder b{ctx}; |
87 | attrs.push_back(Elt: b.getNamedAttr("a" , b.getI32IntegerAttr(prop.a))); |
88 | attrs.push_back(Elt: b.getNamedAttr("b" , b.getF32FloatAttr(prop.b))); |
89 | attrs.push_back(Elt: b.getNamedAttr("array" , b.getDenseI64ArrayAttr(prop.array))); |
90 | attrs.push_back(Elt: b.getNamedAttr( |
91 | "label" , b.getStringAttr(prop.label ? *prop.label : "<nullptr>" ))); |
92 | return b.getDictionaryAttr(attrs); |
93 | } |
94 | |
95 | inline llvm::hash_code computeHash(const TestProperties &prop) { |
96 | // We hash `b` which is a float using its underlying array of char: |
97 | unsigned char const *p = reinterpret_cast<unsigned char const *>(&prop.b); |
98 | ArrayRef<unsigned char> bBytes{p, sizeof(prop.b)}; |
99 | return llvm::hash_combine( |
100 | args: prop.a, args: llvm::hash_combine_range(first: bBytes.begin(), last: bBytes.end()), |
101 | args: llvm::hash_combine_range(first: prop.array.begin(), last: prop.array.end()), |
102 | args: StringRef(*prop.label)); |
103 | } |
104 | |
105 | /// A custom operation for the purpose of showcasing how to use "properties". |
106 | class OpWithProperties : public Op<OpWithProperties> { |
107 | public: |
108 | // Begin boilerplate |
109 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithProperties) |
110 | using Op::Op; |
111 | static ArrayRef<StringRef> getAttributeNames() { return {}; } |
112 | static StringRef getOperationName() { |
113 | return "test_op_properties.op_with_properties" ; |
114 | } |
115 | // End boilerplate |
116 | |
117 | // This alias is the only definition needed for enabling "properties" for this |
118 | // operation. |
119 | using Properties = TestProperties; |
120 | static std::optional<mlir::Attribute> getInherentAttr(MLIRContext *context, |
121 | const Properties &prop, |
122 | StringRef name) { |
123 | return std::nullopt; |
124 | } |
125 | static void setInherentAttr(Properties &prop, StringRef name, |
126 | mlir::Attribute value) {} |
127 | static void populateInherentAttrs(MLIRContext *context, |
128 | const Properties &prop, |
129 | NamedAttrList &attrs) {} |
130 | static LogicalResult |
131 | verifyInherentAttrs(OperationName opName, NamedAttrList &attrs, |
132 | function_ref<InFlightDiagnostic()> emitError) { |
133 | return success(); |
134 | } |
135 | }; |
136 | |
137 | /// A custom operation for the purpose of showcasing how discardable attributes |
138 | /// are handled in absence of properties. |
139 | class OpWithoutProperties : public Op<OpWithoutProperties> { |
140 | public: |
141 | // Begin boilerplate. |
142 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithoutProperties) |
143 | using Op::Op; |
144 | static ArrayRef<StringRef> getAttributeNames() { |
145 | static StringRef attributeNames[] = {StringRef("inherent_attr" )}; |
146 | return ArrayRef(attributeNames); |
147 | }; |
148 | static StringRef getOperationName() { |
149 | return "test_op_properties.op_without_properties" ; |
150 | } |
151 | // End boilerplate. |
152 | }; |
153 | |
154 | // A trivial supporting dialect to register the above operation. |
155 | class TestOpPropertiesDialect : public Dialect { |
156 | public: |
157 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpPropertiesDialect) |
158 | static constexpr StringLiteral getDialectNamespace() { |
159 | return StringLiteral("test_op_properties" ); |
160 | } |
161 | explicit TestOpPropertiesDialect(MLIRContext *context) |
162 | : Dialect(getDialectNamespace(), context, |
163 | TypeID::get<TestOpPropertiesDialect>()) { |
164 | addOperations<OpWithProperties, OpWithoutProperties>(); |
165 | } |
166 | }; |
167 | |
// Generic-form source for a single `op_with_properties` operation with every
// property (`a`, `b`, `array`, `label`) set explicitly. Used as parser input by
// several tests below.
constexpr StringLiteral mlirSrc = R"mlir(
"test_op_properties.op_with_properties"()
<{a = -42 : i32,
b = -4.200000e+01 : f32,
array = array<i64: 40, 41>,
label = "bar foo"}> : () -> ()
)mlir";
175 | |
176 | TEST(OpPropertiesTest, Properties) { |
177 | MLIRContext context; |
178 | context.getOrLoadDialect<TestOpPropertiesDialect>(); |
179 | ParserConfig config(&context); |
180 | // Parse the operation with some properties. |
181 | OwningOpRef<Operation *> op = parseSourceString(sourceStr: mlirSrc, config); |
182 | ASSERT_TRUE(op.get() != nullptr); |
183 | auto opWithProp = dyn_cast<OpWithProperties>(Val: op.get()); |
184 | ASSERT_TRUE(opWithProp); |
185 | { |
186 | std::string output; |
187 | llvm::raw_string_ostream os(output); |
188 | opWithProp.print(os); |
189 | ASSERT_STREQ("\"test_op_properties.op_with_properties\"() " |
190 | "<{a = -42 : i32, " |
191 | "array = array<i64: 40, 41>, " |
192 | "b = -4.200000e+01 : f32, " |
193 | "label = \"bar foo\"}> : () -> ()\n" , |
194 | os.str().c_str()); |
195 | } |
196 | // Get a mutable reference to the properties for this operation and modify it |
197 | // in place one member at a time. |
198 | TestProperties &prop = opWithProp.getProperties(); |
199 | prop.a = 42; |
200 | { |
201 | std::string output; |
202 | llvm::raw_string_ostream os(output); |
203 | opWithProp.print(os); |
204 | EXPECT_TRUE(StringRef(os.str()).contains("a = 42" )); |
205 | EXPECT_TRUE(StringRef(os.str()).contains("b = -4.200000e+01" )); |
206 | EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41>" )); |
207 | EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\"" )); |
208 | } |
209 | prop.b = 42.; |
210 | { |
211 | std::string output; |
212 | llvm::raw_string_ostream os(output); |
213 | opWithProp.print(os); |
214 | EXPECT_TRUE(StringRef(os.str()).contains("a = 42" )); |
215 | EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01" )); |
216 | EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41>" )); |
217 | EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\"" )); |
218 | } |
219 | prop.array.push_back(x: 42); |
220 | { |
221 | std::string output; |
222 | llvm::raw_string_ostream os(output); |
223 | opWithProp.print(os); |
224 | EXPECT_TRUE(StringRef(os.str()).contains("a = 42" )); |
225 | EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01" )); |
226 | EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41, 42>" )); |
227 | EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\"" )); |
228 | } |
229 | prop.label = std::make_shared<std::string>(args: "foo bar" ); |
230 | { |
231 | std::string output; |
232 | llvm::raw_string_ostream os(output); |
233 | opWithProp.print(os); |
234 | EXPECT_TRUE(StringRef(os.str()).contains("a = 42" )); |
235 | EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01" )); |
236 | EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41, 42>" )); |
237 | EXPECT_TRUE(StringRef(os.str()).contains("label = \"foo bar\"" )); |
238 | } |
239 | } |
240 | |
241 | // Test diagnostic emission when using invalid dictionary. |
242 | TEST(OpPropertiesTest, FailedProperties) { |
243 | MLIRContext context; |
244 | context.getOrLoadDialect<TestOpPropertiesDialect>(); |
245 | std::string diagnosticStr; |
246 | context.getDiagEngine().registerHandler(handler: [&](Diagnostic &diag) { |
247 | diagnosticStr += diag.str(); |
248 | return success(); |
249 | }); |
250 | |
251 | // Parse the operation with some properties. |
252 | ParserConfig config(&context); |
253 | |
254 | // Parse an operation with invalid (incomplete) properties. |
255 | OwningOpRef<Operation *> owningOp = |
256 | parseSourceString(sourceStr: "\"test_op_properties.op_with_properties\"() " |
257 | "<{a = -42 : i32}> : () -> ()\n" , |
258 | config); |
259 | ASSERT_EQ(owningOp.get(), nullptr); |
260 | EXPECT_STREQ( |
261 | "invalid properties {a = -42 : i32} for op " |
262 | "test_op_properties.op_with_properties: expected FloatAttr for key `b`" , |
263 | diagnosticStr.c_str()); |
264 | diagnosticStr.clear(); |
265 | |
266 | owningOp = parseSourceString(sourceStr: mlirSrc, config); |
267 | Operation *op = owningOp.get(); |
268 | ASSERT_TRUE(op != nullptr); |
269 | Location loc = op->getLoc(); |
270 | auto opWithProp = dyn_cast<OpWithProperties>(Val: op); |
271 | ASSERT_TRUE(opWithProp); |
272 | |
273 | OperationState state(loc, op->getName()); |
274 | Builder b{&context}; |
275 | NamedAttrList attrs; |
276 | attrs.push_back(newAttribute: b.getNamedAttr("a" , b.getStringAttr("foo" ))); |
277 | state.propertiesAttr = attrs.getDictionary(&context); |
278 | { |
279 | auto emitError = [&]() { |
280 | return op->emitError(message: "setting properties failed: " ); |
281 | }; |
282 | auto result = state.setProperties(op, emitError); |
283 | EXPECT_TRUE(result.failed()); |
284 | } |
285 | EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`" , |
286 | diagnosticStr.c_str()); |
287 | } |
288 | |
289 | TEST(OpPropertiesTest, DefaultValues) { |
290 | MLIRContext context; |
291 | context.getOrLoadDialect<TestOpPropertiesDialect>(); |
292 | OperationState state(UnknownLoc::get(&context), |
293 | "test_op_properties.op_with_properties" ); |
294 | Operation *op = Operation::create(state); |
295 | ASSERT_TRUE(op != nullptr); |
296 | { |
297 | std::string output; |
298 | llvm::raw_string_ostream os(output); |
299 | op->print(os); |
300 | EXPECT_TRUE(StringRef(os.str()).contains("a = -1" )); |
301 | EXPECT_TRUE(StringRef(os.str()).contains("b = -1" )); |
302 | EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: -33>" )); |
303 | } |
304 | op->erase(); |
305 | } |
306 | |
307 | TEST(OpPropertiesTest, Cloning) { |
308 | MLIRContext context; |
309 | context.getOrLoadDialect<TestOpPropertiesDialect>(); |
310 | ParserConfig config(&context); |
311 | // Parse the operation with some properties. |
312 | OwningOpRef<Operation *> op = parseSourceString(sourceStr: mlirSrc, config); |
313 | ASSERT_TRUE(op.get() != nullptr); |
314 | auto opWithProp = dyn_cast<OpWithProperties>(Val: op.get()); |
315 | ASSERT_TRUE(opWithProp); |
316 | Operation *clone = opWithProp->clone(); |
317 | |
318 | // Check that op and its clone prints equally |
319 | std::string opStr; |
320 | std::string cloneStr; |
321 | { |
322 | llvm::raw_string_ostream os(opStr); |
323 | op.get()->print(os); |
324 | } |
325 | { |
326 | llvm::raw_string_ostream os(cloneStr); |
327 | clone->print(os); |
328 | } |
329 | clone->erase(); |
330 | EXPECT_STREQ(opStr.c_str(), cloneStr.c_str()); |
331 | } |
332 | |
333 | TEST(OpPropertiesTest, Equivalence) { |
334 | MLIRContext context; |
335 | context.getOrLoadDialect<TestOpPropertiesDialect>(); |
336 | ParserConfig config(&context); |
337 | // Parse the operation with some properties. |
338 | OwningOpRef<Operation *> op = parseSourceString(sourceStr: mlirSrc, config); |
339 | ASSERT_TRUE(op.get() != nullptr); |
340 | auto opWithProp = dyn_cast<OpWithProperties>(Val: op.get()); |
341 | ASSERT_TRUE(opWithProp); |
342 | llvm::hash_code reference = OperationEquivalence::computeHash(op: opWithProp); |
343 | TestProperties &prop = opWithProp.getProperties(); |
344 | prop.a = 42; |
345 | EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp)); |
346 | prop.a = -42; |
347 | EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp)); |
348 | prop.b = 42.; |
349 | EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp)); |
350 | prop.b = -42.; |
351 | EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp)); |
352 | prop.array.push_back(x: 42); |
353 | EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp)); |
354 | prop.array.pop_back(); |
355 | EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp)); |
356 | } |
357 | |
358 | TEST(OpPropertiesTest, getOrAddProperties) { |
359 | MLIRContext context; |
360 | context.getOrLoadDialect<TestOpPropertiesDialect>(); |
361 | OperationState state(UnknownLoc::get(&context), |
362 | "test_op_properties.op_with_properties" ); |
363 | // Test `getOrAddProperties` API on OperationState. |
364 | TestProperties &prop = state.getOrAddProperties<TestProperties>(); |
365 | prop.a = 1; |
366 | prop.b = 2; |
367 | prop.array = {3, 4, 5}; |
368 | Operation *op = Operation::create(state); |
369 | ASSERT_TRUE(op != nullptr); |
370 | { |
371 | std::string output; |
372 | llvm::raw_string_ostream os(output); |
373 | op->print(os); |
374 | EXPECT_TRUE(StringRef(os.str()).contains("a = 1" )); |
375 | EXPECT_TRUE(StringRef(os.str()).contains("b = 2" )); |
376 | EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 3, 4, 5>" )); |
377 | } |
378 | op->erase(); |
379 | } |
380 | |
381 | constexpr StringLiteral = R"mlir( |
382 | "test_op_properties.op_without_properties"() |
383 | {inherent_attr = 42, other_attr = 56} : () -> () |
384 | )mlir" ; |
385 | |
386 | TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) { |
387 | MLIRContext context; |
388 | context.getOrLoadDialect<TestOpPropertiesDialect>(); |
389 | ParserConfig config(&context); |
390 | OwningOpRef<Operation *> op = |
391 | parseSourceString(sourceStr: withoutPropertiesAttrsSrc, config); |
392 | ASSERT_EQ(llvm::range_size(op->getDiscardableAttrs()), 1u); |
393 | EXPECT_EQ(op->getDiscardableAttrs().begin()->getName().getValue(), |
394 | "other_attr" ); |
395 | |
396 | EXPECT_EQ(op->getAttrs().size(), 2u); |
397 | EXPECT_TRUE(op->getInherentAttr("inherent_attr" ) != std::nullopt); |
398 | EXPECT_TRUE(op->getDiscardableAttr("other_attr" ) != Attribute()); |
399 | |
400 | std::string output; |
401 | llvm::raw_string_ostream os(output); |
402 | op->print(os); |
403 | EXPECT_TRUE(StringRef(os.str()).contains("inherent_attr = 42" )); |
404 | EXPECT_TRUE(StringRef(os.str()).contains("other_attr = 56" )); |
405 | |
406 | OwningOpRef<Operation *> reparsed = parseSourceString(sourceStr: os.str(), config); |
407 | auto trivialHash = [](Value v) { return hash_value(arg: v); }; |
408 | auto hash = [&](Operation *operation) { |
409 | return OperationEquivalence::computeHash( |
410 | op: operation, hashOperands: trivialHash, hashResults: trivialHash, |
411 | flags: OperationEquivalence::Flags::IgnoreLocations); |
412 | }; |
413 | EXPECT_TRUE(hash(op.get()) == hash(reparsed.get())); |
414 | } |
415 | |
416 | } // namespace |
417 | |