1//===- TestOpProperties.cpp - Test all properties-related APIs ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Attributes.h"
10#include "mlir/IR/OpDefinition.h"
11#include "mlir/IR/OperationSupport.h"
12#include "mlir/Parser/Parser.h"
13#include "gtest/gtest.h"
14#include <optional>
15
16using namespace mlir;
17
18namespace {
19/// Simple structure definining a struct to define "properties" for a given
20/// operation. Default values are honored when creating an operation.
21struct TestProperties {
22 int a = -1;
23 float b = -1.;
24 std::vector<int64_t> array = {-33};
25 /// A shared_ptr to a const object is safe: it is equivalent to a value-based
26 /// member. Here the label will be deallocated when the last operation
27 /// referring to it is destroyed. However there is no pool-allocation: this is
28 /// offloaded to the client.
29 std::shared_ptr<const std::string> label;
30 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestProperties)
31};
32
33bool operator==(const TestProperties &lhs, TestProperties &rhs) {
34 return lhs.a == rhs.a && lhs.b == rhs.b && lhs.array == rhs.array &&
35 lhs.label == rhs.label;
36}
37
38/// Convert a DictionaryAttr to a TestProperties struct, optionally emit errors
39/// through the provided diagnostic if any. This is used for example during
40/// parsing with the generic format.
41static LogicalResult
42setPropertiesFromAttribute(TestProperties &prop, Attribute attr,
43 function_ref<InFlightDiagnostic()> emitError) {
44 DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
45 if (!dict) {
46 emitError() << "expected DictionaryAttr to set TestProperties";
47 return failure();
48 }
49 auto aAttr = dict.getAs<IntegerAttr>("a");
50 if (!aAttr) {
51 emitError() << "expected IntegerAttr for key `a`";
52 return failure();
53 }
54 auto bAttr = dict.getAs<FloatAttr>("b");
55 if (!bAttr ||
56 &bAttr.getValue().getSemantics() != &llvm::APFloatBase::IEEEsingle()) {
57 emitError() << "expected FloatAttr for key `b`";
58 return failure();
59 }
60
61 auto arrayAttr = dict.getAs<DenseI64ArrayAttr>("array");
62 if (!arrayAttr) {
63 emitError() << "expected DenseI64ArrayAttr for key `array`";
64 return failure();
65 }
66
67 auto label = dict.getAs<mlir::StringAttr>("label");
68 if (!label) {
69 emitError() << "expected StringAttr for key `label`";
70 return failure();
71 }
72
73 prop.a = aAttr.getValue().getSExtValue();
74 prop.b = bAttr.getValue().convertToFloat();
75 prop.array.assign(arrayAttr.asArrayRef().begin(),
76 arrayAttr.asArrayRef().end());
77 prop.label = std::make_shared<std::string>(label.getValue());
78 return success();
79}
80
81/// Convert a TestProperties struct to a DictionaryAttr, this is used for
82/// example during printing with the generic format.
83static Attribute getPropertiesAsAttribute(MLIRContext *ctx,
84 const TestProperties &prop) {
85 SmallVector<NamedAttribute> attrs;
86 Builder b{ctx};
87 attrs.push_back(Elt: b.getNamedAttr("a", b.getI32IntegerAttr(prop.a)));
88 attrs.push_back(Elt: b.getNamedAttr("b", b.getF32FloatAttr(prop.b)));
89 attrs.push_back(Elt: b.getNamedAttr("array", b.getDenseI64ArrayAttr(prop.array)));
90 attrs.push_back(Elt: b.getNamedAttr(
91 "label", b.getStringAttr(prop.label ? *prop.label : "<nullptr>")));
92 return b.getDictionaryAttr(attrs);
93}
94
95inline llvm::hash_code computeHash(const TestProperties &prop) {
96 // We hash `b` which is a float using its underlying array of char:
97 unsigned char const *p = reinterpret_cast<unsigned char const *>(&prop.b);
98 ArrayRef<unsigned char> bBytes{p, sizeof(prop.b)};
99 return llvm::hash_combine(
100 args: prop.a, args: llvm::hash_combine_range(first: bBytes.begin(), last: bBytes.end()),
101 args: llvm::hash_combine_range(first: prop.array.begin(), last: prop.array.end()),
102 args: StringRef(*prop.label));
103}
104
105/// A custom operation for the purpose of showcasing how to use "properties".
106class OpWithProperties : public Op<OpWithProperties> {
107public:
108 // Begin boilerplate
109 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithProperties)
110 using Op::Op;
111 static ArrayRef<StringRef> getAttributeNames() { return {}; }
112 static StringRef getOperationName() {
113 return "test_op_properties.op_with_properties";
114 }
115 // End boilerplate
116
117 // This alias is the only definition needed for enabling "properties" for this
118 // operation.
119 using Properties = TestProperties;
120 static std::optional<mlir::Attribute> getInherentAttr(MLIRContext *context,
121 const Properties &prop,
122 StringRef name) {
123 return std::nullopt;
124 }
125 static void setInherentAttr(Properties &prop, StringRef name,
126 mlir::Attribute value) {}
127 static void populateInherentAttrs(MLIRContext *context,
128 const Properties &prop,
129 NamedAttrList &attrs) {}
130 static LogicalResult
131 verifyInherentAttrs(OperationName opName, NamedAttrList &attrs,
132 function_ref<InFlightDiagnostic()> emitError) {
133 return success();
134 }
135};
136
137/// A custom operation for the purpose of showcasing how discardable attributes
138/// are handled in absence of properties.
139class OpWithoutProperties : public Op<OpWithoutProperties> {
140public:
141 // Begin boilerplate.
142 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithoutProperties)
143 using Op::Op;
144 static ArrayRef<StringRef> getAttributeNames() {
145 static StringRef attributeNames[] = {StringRef("inherent_attr")};
146 return ArrayRef(attributeNames);
147 };
148 static StringRef getOperationName() {
149 return "test_op_properties.op_without_properties";
150 }
151 // End boilerplate.
152};
153
154// A trivial supporting dialect to register the above operation.
155class TestOpPropertiesDialect : public Dialect {
156public:
157 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpPropertiesDialect)
158 static constexpr StringLiteral getDialectNamespace() {
159 return StringLiteral("test_op_properties");
160 }
161 explicit TestOpPropertiesDialect(MLIRContext *context)
162 : Dialect(getDialectNamespace(), context,
163 TypeID::get<TestOpPropertiesDialect>()) {
164 addOperations<OpWithProperties, OpWithoutProperties>();
165 }
166};
167
// Generic-form source of a single `test_op_properties.op_with_properties`
// operation whose `<{...}>` properties dictionary sets every member of
// TestProperties (`a`, `b`, `array`, `label`). Shared by several tests below.
constexpr StringLiteral mlirSrc = R"mlir(
 "test_op_properties.op_with_properties"()
 <{a = -42 : i32,
 b = -4.200000e+01 : f32,
 array = array<i64: 40, 41>,
 label = "bar foo"}> : () -> ()
)mlir";
175
176TEST(OpPropertiesTest, Properties) {
177 MLIRContext context;
178 context.getOrLoadDialect<TestOpPropertiesDialect>();
179 ParserConfig config(&context);
180 // Parse the operation with some properties.
181 OwningOpRef<Operation *> op = parseSourceString(sourceStr: mlirSrc, config);
182 ASSERT_TRUE(op.get() != nullptr);
183 auto opWithProp = dyn_cast<OpWithProperties>(Val: op.get());
184 ASSERT_TRUE(opWithProp);
185 {
186 std::string output;
187 llvm::raw_string_ostream os(output);
188 opWithProp.print(os);
189 ASSERT_STREQ("\"test_op_properties.op_with_properties\"() "
190 "<{a = -42 : i32, "
191 "array = array<i64: 40, 41>, "
192 "b = -4.200000e+01 : f32, "
193 "label = \"bar foo\"}> : () -> ()\n",
194 os.str().c_str());
195 }
196 // Get a mutable reference to the properties for this operation and modify it
197 // in place one member at a time.
198 TestProperties &prop = opWithProp.getProperties();
199 prop.a = 42;
200 {
201 std::string output;
202 llvm::raw_string_ostream os(output);
203 opWithProp.print(os);
204 EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
205 EXPECT_TRUE(StringRef(os.str()).contains("b = -4.200000e+01"));
206 EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41>"));
207 EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\""));
208 }
209 prop.b = 42.;
210 {
211 std::string output;
212 llvm::raw_string_ostream os(output);
213 opWithProp.print(os);
214 EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
215 EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01"));
216 EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41>"));
217 EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\""));
218 }
219 prop.array.push_back(x: 42);
220 {
221 std::string output;
222 llvm::raw_string_ostream os(output);
223 opWithProp.print(os);
224 EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
225 EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01"));
226 EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41, 42>"));
227 EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\""));
228 }
229 prop.label = std::make_shared<std::string>(args: "foo bar");
230 {
231 std::string output;
232 llvm::raw_string_ostream os(output);
233 opWithProp.print(os);
234 EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
235 EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01"));
236 EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41, 42>"));
237 EXPECT_TRUE(StringRef(os.str()).contains("label = \"foo bar\""));
238 }
239}
240
241// Test diagnostic emission when using invalid dictionary.
242TEST(OpPropertiesTest, FailedProperties) {
243 MLIRContext context;
244 context.getOrLoadDialect<TestOpPropertiesDialect>();
245 std::string diagnosticStr;
246 context.getDiagEngine().registerHandler(handler: [&](Diagnostic &diag) {
247 diagnosticStr += diag.str();
248 return success();
249 });
250
251 // Parse the operation with some properties.
252 ParserConfig config(&context);
253
254 // Parse an operation with invalid (incomplete) properties.
255 OwningOpRef<Operation *> owningOp =
256 parseSourceString(sourceStr: "\"test_op_properties.op_with_properties\"() "
257 "<{a = -42 : i32}> : () -> ()\n",
258 config);
259 ASSERT_EQ(owningOp.get(), nullptr);
260 EXPECT_STREQ(
261 "invalid properties {a = -42 : i32} for op "
262 "test_op_properties.op_with_properties: expected FloatAttr for key `b`",
263 diagnosticStr.c_str());
264 diagnosticStr.clear();
265
266 owningOp = parseSourceString(sourceStr: mlirSrc, config);
267 Operation *op = owningOp.get();
268 ASSERT_TRUE(op != nullptr);
269 Location loc = op->getLoc();
270 auto opWithProp = dyn_cast<OpWithProperties>(Val: op);
271 ASSERT_TRUE(opWithProp);
272
273 OperationState state(loc, op->getName());
274 Builder b{&context};
275 NamedAttrList attrs;
276 attrs.push_back(newAttribute: b.getNamedAttr("a", b.getStringAttr("foo")));
277 state.propertiesAttr = attrs.getDictionary(&context);
278 {
279 auto emitError = [&]() {
280 return op->emitError(message: "setting properties failed: ");
281 };
282 auto result = state.setProperties(op, emitError);
283 EXPECT_TRUE(result.failed());
284 }
285 EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`",
286 diagnosticStr.c_str());
287}
288
289TEST(OpPropertiesTest, DefaultValues) {
290 MLIRContext context;
291 context.getOrLoadDialect<TestOpPropertiesDialect>();
292 OperationState state(UnknownLoc::get(&context),
293 "test_op_properties.op_with_properties");
294 Operation *op = Operation::create(state);
295 ASSERT_TRUE(op != nullptr);
296 {
297 std::string output;
298 llvm::raw_string_ostream os(output);
299 op->print(os);
300 EXPECT_TRUE(StringRef(os.str()).contains("a = -1"));
301 EXPECT_TRUE(StringRef(os.str()).contains("b = -1"));
302 EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: -33>"));
303 }
304 op->erase();
305}
306
307TEST(OpPropertiesTest, Cloning) {
308 MLIRContext context;
309 context.getOrLoadDialect<TestOpPropertiesDialect>();
310 ParserConfig config(&context);
311 // Parse the operation with some properties.
312 OwningOpRef<Operation *> op = parseSourceString(sourceStr: mlirSrc, config);
313 ASSERT_TRUE(op.get() != nullptr);
314 auto opWithProp = dyn_cast<OpWithProperties>(Val: op.get());
315 ASSERT_TRUE(opWithProp);
316 Operation *clone = opWithProp->clone();
317
318 // Check that op and its clone prints equally
319 std::string opStr;
320 std::string cloneStr;
321 {
322 llvm::raw_string_ostream os(opStr);
323 op.get()->print(os);
324 }
325 {
326 llvm::raw_string_ostream os(cloneStr);
327 clone->print(os);
328 }
329 clone->erase();
330 EXPECT_STREQ(opStr.c_str(), cloneStr.c_str());
331}
332
333TEST(OpPropertiesTest, Equivalence) {
334 MLIRContext context;
335 context.getOrLoadDialect<TestOpPropertiesDialect>();
336 ParserConfig config(&context);
337 // Parse the operation with some properties.
338 OwningOpRef<Operation *> op = parseSourceString(sourceStr: mlirSrc, config);
339 ASSERT_TRUE(op.get() != nullptr);
340 auto opWithProp = dyn_cast<OpWithProperties>(Val: op.get());
341 ASSERT_TRUE(opWithProp);
342 llvm::hash_code reference = OperationEquivalence::computeHash(op: opWithProp);
343 TestProperties &prop = opWithProp.getProperties();
344 prop.a = 42;
345 EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
346 prop.a = -42;
347 EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
348 prop.b = 42.;
349 EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
350 prop.b = -42.;
351 EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
352 prop.array.push_back(x: 42);
353 EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
354 prop.array.pop_back();
355 EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
356}
357
358TEST(OpPropertiesTest, getOrAddProperties) {
359 MLIRContext context;
360 context.getOrLoadDialect<TestOpPropertiesDialect>();
361 OperationState state(UnknownLoc::get(&context),
362 "test_op_properties.op_with_properties");
363 // Test `getOrAddProperties` API on OperationState.
364 TestProperties &prop = state.getOrAddProperties<TestProperties>();
365 prop.a = 1;
366 prop.b = 2;
367 prop.array = {3, 4, 5};
368 Operation *op = Operation::create(state);
369 ASSERT_TRUE(op != nullptr);
370 {
371 std::string output;
372 llvm::raw_string_ostream os(output);
373 op->print(os);
374 EXPECT_TRUE(StringRef(os.str()).contains("a = 1"));
375 EXPECT_TRUE(StringRef(os.str()).contains("b = 2"));
376 EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 3, 4, 5>"));
377 }
378 op->erase();
379}
380
// Generic-form source of an `op_without_properties` operation carrying one
// inherent attribute (`inherent_attr`, listed by
// OpWithoutProperties::getAttributeNames()) and one attribute that is not
// listed there (`other_attr`).
constexpr StringLiteral withoutPropertiesAttrsSrc = R"mlir(
 "test_op_properties.op_without_properties"()
 {inherent_attr = 42, other_attr = 56} : () -> ()
)mlir";
385
386TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) {
387 MLIRContext context;
388 context.getOrLoadDialect<TestOpPropertiesDialect>();
389 ParserConfig config(&context);
390 OwningOpRef<Operation *> op =
391 parseSourceString(sourceStr: withoutPropertiesAttrsSrc, config);
392 ASSERT_EQ(llvm::range_size(op->getDiscardableAttrs()), 1u);
393 EXPECT_EQ(op->getDiscardableAttrs().begin()->getName().getValue(),
394 "other_attr");
395
396 EXPECT_EQ(op->getAttrs().size(), 2u);
397 EXPECT_TRUE(op->getInherentAttr("inherent_attr") != std::nullopt);
398 EXPECT_TRUE(op->getDiscardableAttr("other_attr") != Attribute());
399
400 std::string output;
401 llvm::raw_string_ostream os(output);
402 op->print(os);
403 EXPECT_TRUE(StringRef(os.str()).contains("inherent_attr = 42"));
404 EXPECT_TRUE(StringRef(os.str()).contains("other_attr = 56"));
405
406 OwningOpRef<Operation *> reparsed = parseSourceString(sourceStr: os.str(), config);
407 auto trivialHash = [](Value v) { return hash_value(arg: v); };
408 auto hash = [&](Operation *operation) {
409 return OperationEquivalence::computeHash(
410 op: operation, hashOperands: trivialHash, hashResults: trivialHash,
411 flags: OperationEquivalence::Flags::IgnoreLocations);
412 };
413 EXPECT_TRUE(hash(op.get()) == hash(reparsed.get()));
414}
415
416} // namespace
417

source code of mlir/unittests/IR/OpPropertiesTest.cpp