1 | //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Manages the loading of MLIR objects from IRDL operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/IRDL/IRDLLoading.h" |
14 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
15 | #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" |
16 | #include "mlir/Dialect/IRDL/IRDLSymbols.h" |
17 | #include "mlir/Dialect/IRDL/IRDLVerifiers.h" |
18 | #include "mlir/IR/Attributes.h" |
19 | #include "mlir/IR/BuiltinOps.h" |
20 | #include "mlir/IR/ExtensibleDialect.h" |
21 | #include "mlir/IR/OperationSupport.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/ADT/SmallPtrSet.h" |
24 | #include "llvm/Support/SMLoc.h" |
25 | #include <numeric> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::irdl; |
29 | |
30 | /// Verify that the given list of parameters satisfy the given constraints. |
31 | /// This encodes the logic of the verification method for attributes and types |
32 | /// defined with IRDL. |
33 | static LogicalResult |
34 | irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError, |
35 | ArrayRef<Attribute> params, |
36 | ArrayRef<std::unique_ptr<Constraint>> constraints, |
37 | ArrayRef<size_t> paramConstraints) { |
38 | if (params.size() != paramConstraints.size()) { |
39 | emitError() << "expected " << paramConstraints.size() |
40 | << " type arguments, but had " << params.size(); |
41 | return failure(); |
42 | } |
43 | |
44 | ConstraintVerifier verifier(constraints); |
45 | |
46 | // Check that each parameter satisfies its constraint. |
47 | for (auto [i, param] : enumerate(First&: params)) |
48 | if (failed(Result: verifier.verify(emitError, attr: param, variable: paramConstraints[i]))) |
49 | return failure(); |
50 | |
51 | return success(); |
52 | } |
53 | |
54 | /// Get the operand segment sizes from the attribute dictionary. |
55 | LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, |
56 | StringRef attrName, unsigned numElements, |
57 | ArrayRef<Variadicity> variadicities, |
58 | SmallVectorImpl<int> &segmentSizes) { |
59 | // Get the segment sizes attribute, and check that it is of the right type. |
60 | Attribute segmentSizesAttr = op->getAttr(name: attrName); |
61 | if (!segmentSizesAttr) { |
62 | return op->emitError() << "'" << attrName |
63 | << "' attribute is expected but not provided" ; |
64 | } |
65 | |
66 | auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr); |
67 | if (!denseSegmentSizes) { |
68 | return op->emitError() << "'" << attrName |
69 | << "' attribute is expected to be a dense i32 array" ; |
70 | } |
71 | |
72 | if (denseSegmentSizes.size() != (int64_t)variadicities.size()) { |
73 | return op->emitError() << "'" << attrName << "' attribute for specifying " |
74 | << elemName << " segments must have " |
75 | << variadicities.size() << " elements, but got " |
76 | << denseSegmentSizes.size(); |
77 | } |
78 | |
79 | // Check that the segment sizes are corresponding to the given variadicities, |
80 | for (auto [i, segmentSize, variadicity] : |
81 | enumerate(denseSegmentSizes.asArrayRef(), variadicities)) { |
82 | if (segmentSize < 0) |
83 | return op->emitError() |
84 | << "'" << attrName << "' attribute for specifying " << elemName |
85 | << " segments must have non-negative values" ; |
86 | if (variadicity == Variadicity::single && segmentSize != 1) |
87 | return op->emitError() << "element " << i << " in '" << attrName |
88 | << "' attribute must be equal to 1" ; |
89 | |
90 | if (variadicity == Variadicity::optional && segmentSize > 1) |
91 | return op->emitError() << "element " << i << " in '" << attrName |
92 | << "' attribute must be equal to 0 or 1" ; |
93 | |
94 | segmentSizes.push_back(segmentSize); |
95 | } |
96 | |
97 | // Check that the sum of the segment sizes is equal to the number of elements. |
98 | int32_t sum = 0; |
99 | for (int32_t segmentSize : denseSegmentSizes.asArrayRef()) |
100 | sum += segmentSize; |
101 | if (sum != static_cast<int32_t>(numElements)) |
102 | return op->emitError() << "sum of elements in '" << attrName |
103 | << "' attribute must be equal to the number of " |
104 | << elemName << "s" ; |
105 | |
106 | return success(); |
107 | } |
108 | |
109 | /// Compute the segment sizes of the given element (operands, results). |
110 | /// If the operation has more than two non-single elements (optional or |
111 | /// variadic), then get the segment sizes from the attribute dictionary. |
112 | /// Otherwise, compute the segment sizes from the number of elements. |
113 | /// `elemName` should be either `"operand"` or `"result"`. |
114 | LogicalResult getSegmentSizes(Operation *op, StringRef elemName, |
115 | StringRef attrName, unsigned numElements, |
116 | ArrayRef<Variadicity> variadicities, |
117 | SmallVectorImpl<int> &segmentSizes) { |
118 | // If we have more than one non-single variadicity, we need to get the |
119 | // segment sizes from the attribute dictionary. |
120 | int numberNonSingle = count_if( |
121 | variadicities, [](Variadicity v) { return v != Variadicity::single; }); |
122 | if (numberNonSingle > 1) |
123 | return getSegmentSizesFromAttr(op, elemName, attrName, numElements, |
124 | variadicities, segmentSizes); |
125 | |
126 | // If we only have single variadicities, the segments sizes are all 1. |
127 | if (numberNonSingle == 0) { |
128 | if (numElements != variadicities.size()) { |
129 | return op->emitError() << "op expects exactly " << variadicities.size() |
130 | << " " << elemName << "s, but got " << numElements; |
131 | } |
132 | for (size_t i = 0, e = variadicities.size(); i < e; ++i) |
133 | segmentSizes.push_back(Elt: 1); |
134 | return success(); |
135 | } |
136 | |
137 | assert(numberNonSingle == 1); |
138 | |
139 | // There is exactly one non-single element, so we can |
140 | // compute its size and check that it is valid. |
141 | int nonSingleSegmentSize = static_cast<int>(numElements) - |
142 | static_cast<int>(variadicities.size()) + 1; |
143 | |
144 | if (nonSingleSegmentSize < 0) { |
145 | return op->emitError() << "op expects at least " << variadicities.size() - 1 |
146 | << " " << elemName << "s, but got " << numElements; |
147 | } |
148 | |
149 | // Add the segment sizes. |
150 | for (Variadicity variadicity : variadicities) { |
151 | if (variadicity == Variadicity::single) { |
152 | segmentSizes.push_back(1); |
153 | continue; |
154 | } |
155 | |
156 | // If we have an optional element, we should check that it represents |
157 | // zero or one elements. |
158 | if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional) |
159 | return op->emitError() << "op expects at most " << variadicities.size() |
160 | << " " << elemName << "s, but got " << numElements; |
161 | |
162 | segmentSizes.push_back(nonSingleSegmentSize); |
163 | } |
164 | |
165 | return success(); |
166 | } |
167 | |
168 | /// Compute the segment sizes of the given operands. |
169 | /// If the operation has more than two non-single operands (optional or |
170 | /// variadic), then get the segment sizes from the attribute dictionary. |
171 | /// Otherwise, compute the segment sizes from the number of operands. |
172 | LogicalResult getOperandSegmentSizes(Operation *op, |
173 | ArrayRef<Variadicity> variadicities, |
174 | SmallVectorImpl<int> &segmentSizes) { |
175 | return getSegmentSizes(op, "operand" , "operand_segment_sizes" , |
176 | op->getNumOperands(), variadicities, segmentSizes); |
177 | } |
178 | |
179 | /// Compute the segment sizes of the given results. |
180 | /// If the operation has more than two non-single results (optional or |
181 | /// variadic), then get the segment sizes from the attribute dictionary. |
182 | /// Otherwise, compute the segment sizes from the number of results. |
183 | LogicalResult getResultSegmentSizes(Operation *op, |
184 | ArrayRef<Variadicity> variadicities, |
185 | SmallVectorImpl<int> &segmentSizes) { |
186 | return getSegmentSizes(op, "result" , "result_segment_sizes" , |
187 | op->getNumResults(), variadicities, segmentSizes); |
188 | } |
189 | |
190 | /// Verify that the given operation satisfies the given constraints. |
191 | /// This encodes the logic of the verification method for operations defined |
192 | /// with IRDL. |
193 | static LogicalResult irdlOpVerifier( |
194 | Operation *op, ConstraintVerifier &verifier, |
195 | ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity, |
196 | ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity, |
197 | const DenseMap<StringAttr, size_t> &attributeConstrs) { |
198 | // Get the segment sizes for the operands. |
199 | // This will check that the number of operands is correct. |
200 | SmallVector<int> operandSegmentSizes; |
201 | if (failed( |
202 | getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes))) |
203 | return failure(); |
204 | |
205 | // Get the segment sizes for the results. |
206 | // This will check that the number of results is correct. |
207 | SmallVector<int> resultSegmentSizes; |
208 | if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes))) |
209 | return failure(); |
210 | |
211 | auto emitError = [op] { return op->emitError(); }; |
212 | |
213 | /// Сheck that we have all needed attributes passed |
214 | /// and they satisfy the constraints. |
215 | DictionaryAttr actualAttrs = op->getAttrDictionary(); |
216 | |
217 | for (auto [name, constraint] : attributeConstrs) { |
218 | /// First, check if the attribute actually passed. |
219 | std::optional<NamedAttribute> actual = actualAttrs.getNamed(name); |
220 | if (!actual.has_value()) |
221 | return op->emitOpError() |
222 | << "attribute " << name << " is expected but not provided" ; |
223 | |
224 | /// Then, check if the attribute value satisfies the constraint. |
225 | if (failed(verifier.verify(emitError: {emitError}, attr: actual->getValue(), variable: constraint))) |
226 | return failure(); |
227 | } |
228 | |
229 | // Check that all operands satisfy the constraints |
230 | int operandIdx = 0; |
231 | for (auto [defIndex, segmentSize] : enumerate(First&: operandSegmentSizes)) { |
232 | for (int i = 0; i < segmentSize; i++) { |
233 | if (failed(verifier.verify( |
234 | {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]), |
235 | operandConstrs[defIndex]))) |
236 | return failure(); |
237 | ++operandIdx; |
238 | } |
239 | } |
240 | |
241 | // Check that all results satisfy the constraints |
242 | int resultIdx = 0; |
243 | for (auto [defIndex, segmentSize] : enumerate(First&: resultSegmentSizes)) { |
244 | for (int i = 0; i < segmentSize; i++) { |
245 | if (failed(verifier.verify({emitError}, |
246 | TypeAttr::get(op->getResultTypes()[resultIdx]), |
247 | resultConstrs[defIndex]))) |
248 | return failure(); |
249 | ++resultIdx; |
250 | } |
251 | } |
252 | |
253 | return success(); |
254 | } |
255 | |
256 | static LogicalResult irdlRegionVerifier( |
257 | Operation *op, ConstraintVerifier &verifier, |
258 | ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) { |
259 | if (op->getNumRegions() != regionsConstraints.size()) { |
260 | return op->emitOpError() |
261 | << "unexpected number of regions: expected " |
262 | << regionsConstraints.size() << " but got " << op->getNumRegions(); |
263 | } |
264 | |
265 | for (auto [constraint, region] : |
266 | llvm::zip(t&: regionsConstraints, u: op->getRegions())) |
267 | if (failed(Result: constraint->verify(region, constraintContext&: verifier))) |
268 | return failure(); |
269 | |
270 | return success(); |
271 | } |
272 | |
273 | llvm::unique_function<LogicalResult(Operation *) const> |
274 | mlir::irdl::createVerifier( |
275 | OperationOp op, |
276 | const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
277 | const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
278 | &attrs) { |
279 | // Resolve SSA values to verifier constraint slots |
280 | SmallVector<Value> constrToValue; |
281 | SmallVector<Value> regionToValue; |
282 | for (Operation &op : op->getRegion(0).getOps()) { |
283 | if (isa<VerifyConstraintInterface>(op)) { |
284 | if (op.getNumResults() != 1) { |
285 | op.emitError() |
286 | << "IRDL constraint operations must have exactly one result" ; |
287 | return nullptr; |
288 | } |
289 | constrToValue.push_back(op.getResult(0)); |
290 | } |
291 | if (isa<VerifyRegionInterface>(op)) { |
292 | if (op.getNumResults() != 1) { |
293 | op.emitError() |
294 | << "IRDL constraint operations must have exactly one result" ; |
295 | return nullptr; |
296 | } |
297 | regionToValue.push_back(op.getResult(0)); |
298 | } |
299 | } |
300 | |
301 | // Build the verifiers for each constraint slot |
302 | SmallVector<std::unique_ptr<Constraint>> constraints; |
303 | for (Value v : constrToValue) { |
304 | VerifyConstraintInterface op = |
305 | cast<VerifyConstraintInterface>(v.getDefiningOp()); |
306 | std::unique_ptr<Constraint> verifier = |
307 | op.getVerifier(constrToValue, types, attrs); |
308 | if (!verifier) |
309 | return nullptr; |
310 | constraints.push_back(Elt: std::move(verifier)); |
311 | } |
312 | |
313 | // Build region constraints |
314 | SmallVector<std::unique_ptr<RegionConstraint>> regionConstraints; |
315 | for (Value v : regionToValue) { |
316 | VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp()); |
317 | std::unique_ptr<RegionConstraint> verifier = |
318 | op.getVerifier(constrToValue, types, attrs); |
319 | regionConstraints.push_back(Elt: std::move(verifier)); |
320 | } |
321 | |
322 | SmallVector<size_t> operandConstraints; |
323 | SmallVector<Variadicity> operandVariadicity; |
324 | |
325 | // Gather which constraint slots correspond to operand constraints |
326 | auto operandsOp = op.getOp<OperandsOp>(); |
327 | if (operandsOp.has_value()) { |
328 | operandConstraints.reserve(N: operandsOp->getArgs().size()); |
329 | for (Value operand : operandsOp->getArgs()) { |
330 | for (auto [i, constr] : enumerate(constrToValue)) { |
331 | if (constr == operand) { |
332 | operandConstraints.push_back(i); |
333 | break; |
334 | } |
335 | } |
336 | } |
337 | |
338 | // Gather the variadicities of each operand |
339 | for (VariadicityAttr attr : operandsOp->getVariadicity()) |
340 | operandVariadicity.push_back(attr.getValue()); |
341 | } |
342 | |
343 | SmallVector<size_t> resultConstraints; |
344 | SmallVector<Variadicity> resultVariadicity; |
345 | |
346 | // Gather which constraint slots correspond to result constraints |
347 | auto resultsOp = op.getOp<ResultsOp>(); |
348 | if (resultsOp.has_value()) { |
349 | resultConstraints.reserve(N: resultsOp->getArgs().size()); |
350 | for (Value result : resultsOp->getArgs()) { |
351 | for (auto [i, constr] : enumerate(constrToValue)) { |
352 | if (constr == result) { |
353 | resultConstraints.push_back(i); |
354 | break; |
355 | } |
356 | } |
357 | } |
358 | |
359 | // Gather the variadicities of each result |
360 | for (Attribute attr : resultsOp->getVariadicity()) |
361 | resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue()); |
362 | } |
363 | |
364 | // Gather which constraint slots correspond to attributes constraints |
365 | DenseMap<StringAttr, size_t> attributeConstraints; |
366 | auto attributesOp = op.getOp<AttributesOp>(); |
367 | if (attributesOp.has_value()) { |
368 | const Operation::operand_range values = attributesOp->getAttributeValues(); |
369 | const ArrayAttr names = attributesOp->getAttributeValueNames(); |
370 | |
371 | for (const auto &[name, value] : llvm::zip(names, values)) { |
372 | for (auto [i, constr] : enumerate(constrToValue)) { |
373 | if (constr == value) { |
374 | attributeConstraints[cast<StringAttr>(name)] = i; |
375 | break; |
376 | } |
377 | } |
378 | } |
379 | } |
380 | |
381 | return |
382 | [constraints{std::move(constraints)}, |
383 | regionConstraints{std::move(regionConstraints)}, |
384 | operandConstraints{std::move(operandConstraints)}, |
385 | operandVariadicity{std::move(operandVariadicity)}, |
386 | resultConstraints{std::move(resultConstraints)}, |
387 | resultVariadicity{std::move(resultVariadicity)}, |
388 | attributeConstraints{std::move(attributeConstraints)}](Operation *op) { |
389 | ConstraintVerifier verifier(constraints); |
390 | const LogicalResult opVerifierResult = irdlOpVerifier( |
391 | op, verifier, operandConstraints, operandVariadicity, |
392 | resultConstraints, resultVariadicity, attributeConstraints); |
393 | const LogicalResult opRegionVerifierResult = |
394 | irdlRegionVerifier(op, verifier, regionsConstraints: regionConstraints); |
395 | return LogicalResult::success(IsSuccess: opVerifierResult.succeeded() && |
396 | opRegionVerifierResult.succeeded()); |
397 | }; |
398 | } |
399 | |
400 | /// Define and load an operation represented by a `irdl.operation` |
401 | /// operation. |
402 | static WalkResult loadOperation( |
403 | OperationOp op, ExtensibleDialect *dialect, |
404 | const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
405 | const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
406 | &attrs) { |
407 | |
408 | // IRDL does not support defining custom parsers or printers. |
409 | auto parser = [](OpAsmParser &parser, OperationState &result) { |
410 | return failure(); |
411 | }; |
412 | auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) { |
413 | printer.printGenericOp(op); |
414 | }; |
415 | |
416 | auto verifier = createVerifier(op, types, attrs); |
417 | if (!verifier) |
418 | return WalkResult::interrupt(); |
419 | |
420 | // IRDL supports only checking number of blocks and argument constraints |
421 | // It is done in the main verifier to reuse `ConstraintVerifier` context |
422 | auto regionVerifier = [](Operation *op) { return LogicalResult::success(); }; |
423 | |
424 | auto opDef = DynamicOpDefinition::get( |
425 | op.getName(), dialect, std::move(verifier), std::move(regionVerifier), |
426 | std::move(parser), std::move(printer)); |
427 | dialect->registerDynamicOp(type: std::move(opDef)); |
428 | |
429 | return WalkResult::advance(); |
430 | } |
431 | |
432 | /// Get the verifier of a type or attribute definition. |
433 | /// Return nullptr if the definition is invalid. |
434 | static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier( |
435 | Operation *attrOrTypeDef, ExtensibleDialect *dialect, |
436 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
437 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { |
438 | assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) && |
439 | "Expected an attribute or type definition" ); |
440 | |
441 | // Resolve SSA values to verifier constraint slots |
442 | SmallVector<Value> constrToValue; |
443 | for (Operation &op : attrOrTypeDef->getRegion(index: 0).getOps()) { |
444 | if (isa<VerifyConstraintInterface>(op)) { |
445 | assert(op.getNumResults() == 1 && |
446 | "IRDL constraint operations must have exactly one result" ); |
447 | constrToValue.push_back(Elt: op.getResult(idx: 0)); |
448 | } |
449 | } |
450 | |
451 | // Build the verifiers for each constraint slot |
452 | SmallVector<std::unique_ptr<Constraint>> constraints; |
453 | for (Value v : constrToValue) { |
454 | VerifyConstraintInterface op = |
455 | cast<VerifyConstraintInterface>(v.getDefiningOp()); |
456 | std::unique_ptr<Constraint> verifier = |
457 | op.getVerifier(constrToValue, types, attrs); |
458 | if (!verifier) |
459 | return {}; |
460 | constraints.push_back(Elt: std::move(verifier)); |
461 | } |
462 | |
463 | // Get the parameter definitions. |
464 | std::optional<ParametersOp> params; |
465 | if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef)) |
466 | params = attr.getOp<ParametersOp>(); |
467 | else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef)) |
468 | params = type.getOp<ParametersOp>(); |
469 | |
470 | // Gather which constraint slots correspond to parameter constraints |
471 | SmallVector<size_t> paramConstraints; |
472 | if (params.has_value()) { |
473 | paramConstraints.reserve(N: params->getArgs().size()); |
474 | for (Value param : params->getArgs()) { |
475 | for (auto [i, constr] : enumerate(constrToValue)) { |
476 | if (constr == param) { |
477 | paramConstraints.push_back(i); |
478 | break; |
479 | } |
480 | } |
481 | } |
482 | } |
483 | |
484 | auto verifier = [paramConstraints{std::move(paramConstraints)}, |
485 | constraints{std::move(constraints)}]( |
486 | function_ref<InFlightDiagnostic()> emitError, |
487 | ArrayRef<Attribute> params) { |
488 | return irdlAttrOrTypeVerifier(emitError, params, constraints, |
489 | paramConstraints); |
490 | }; |
491 | |
492 | // While the `std::move` is not required, not adding it triggers a bug in |
493 | // clang-10. |
494 | return std::move(verifier); |
495 | } |
496 | |
497 | /// Get the possible bases of a constraint. Return `true` if all bases can |
498 | /// potentially be matched. |
499 | /// A base is a type or an attribute definition. For instance, the base of |
500 | /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`. |
501 | /// This function returns the following information through arguments: |
502 | /// - `paramIds`: the set of type or attribute IDs that are used as bases. |
503 | /// - `paramIrdlOps`: the set of IRDL operations that are used as bases. |
504 | /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is` |
505 | /// constraints. |
506 | static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> ¶mIds, |
507 | SmallPtrSet<Operation *, 4> ¶mIrdlOps, |
508 | SmallPtrSet<TypeID, 4> &isIds) { |
509 | // For `irdl.any_of`, we get the bases from all its arguments. |
510 | if (auto anyOf = dyn_cast<AnyOfOp>(op)) { |
511 | bool hasAny = false; |
512 | for (Value arg : anyOf.getArgs()) |
513 | hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds); |
514 | return hasAny; |
515 | } |
516 | |
517 | // For `irdl.all_of`, we get the bases from the first argument. |
518 | // This is restrictive, but we can relax it later if needed. |
519 | if (auto allOf = dyn_cast<AllOfOp>(op)) |
520 | return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps, |
521 | isIds); |
522 | |
523 | // For `irdl.parametric`, we get directly the base from the operation. |
524 | if (auto params = dyn_cast<ParametricOp>(op)) { |
525 | SymbolRefAttr symRef = params.getBaseType(); |
526 | Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef); |
527 | assert(defOp && "symbol reference should refer to an existing operation" ); |
528 | paramIrdlOps.insert(Ptr: defOp); |
529 | return false; |
530 | } |
531 | |
532 | // For `irdl.is`, we get the base TypeID directly. |
533 | if (auto is = dyn_cast<IsOp>(op)) { |
534 | Attribute expected = is.getExpected(); |
535 | isIds.insert(Ptr: expected.getTypeID()); |
536 | return false; |
537 | } |
538 | |
539 | // For `irdl.any`, we return `false` since we can match any type or attribute |
540 | // base. |
541 | if (auto isA = dyn_cast<AnyOp>(op)) |
542 | return true; |
543 | |
544 | llvm_unreachable("unknown IRDL constraint" ); |
545 | } |
546 | |
547 | /// Check that an any_of is in the subset IRDL can handle. |
548 | /// IRDL uses a greedy algorithm to match constraints. This means that if we |
549 | /// encounter an `any_of` with multiple constraints, we will match the first |
550 | /// constraint that is satisfied. Thus, the order of constraints matter in |
551 | /// `any_of` with our current algorithm. |
552 | /// In order to make the order of constraints irrelevant, we require that |
553 | /// all `any_of` constraint parameters are disjoint. For this, we check that |
554 | /// the base parameters are all disjoints between `parametric` operations, and |
555 | /// that they are disjoint between `parametric` and `is` operations. |
556 | /// This restriction will be relaxed in the future, when we will change our |
557 | /// algorithm to be non-greedy. |
558 | static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) { |
559 | SmallPtrSet<TypeID, 4> paramIds; |
560 | SmallPtrSet<Operation *, 4> paramIrdlOps; |
561 | SmallPtrSet<TypeID, 4> isIds; |
562 | |
563 | for (Value arg : anyOf.getArgs()) { |
564 | Operation *argOp = arg.getDefiningOp(); |
565 | SmallPtrSet<TypeID, 4> argParamIds; |
566 | SmallPtrSet<Operation *, 4> argParamIrdlOps; |
567 | SmallPtrSet<TypeID, 4> argIsIds; |
568 | |
569 | // Get the bases of this argument. If it can match any type or attribute, |
570 | // then our `any_of` should not be allowed. |
571 | if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds)) |
572 | return failure(); |
573 | |
574 | // We check that the base parameters are all disjoints between `parametric` |
575 | // operations, and that they are disjoint between `parametric` and `is` |
576 | // operations. |
577 | for (TypeID id : argParamIds) { |
578 | if (isIds.count(id)) |
579 | return failure(); |
580 | bool inserted = paramIds.insert(id).second; |
581 | if (!inserted) |
582 | return failure(); |
583 | } |
584 | |
585 | // We check that the base parameters are all disjoints with `irdl.is` |
586 | // operations. |
587 | for (TypeID id : isIds) { |
588 | if (paramIds.count(id)) |
589 | return failure(); |
590 | isIds.insert(id); |
591 | } |
592 | |
593 | // We check that all `parametric` operations are disjoint. We do not |
594 | // need to check that they are disjoint with `is` operations, since |
595 | // `is` operations cannot refer to attributes defined with `irdl.parametric` |
596 | // operations. |
597 | for (Operation *op : argParamIrdlOps) { |
598 | bool inserted = paramIrdlOps.insert(op).second; |
599 | if (!inserted) |
600 | return failure(); |
601 | } |
602 | } |
603 | |
604 | return success(); |
605 | } |
606 | |
607 | /// Load all dialects in the given module, without loading any operation, type |
608 | /// or attribute definitions. |
609 | static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) { |
610 | DenseMap<DialectOp, ExtensibleDialect *> dialects; |
611 | op.walk([&](DialectOp dialectOp) { |
612 | MLIRContext *ctx = dialectOp.getContext(); |
613 | StringRef dialectName = dialectOp.getName(); |
614 | |
615 | DynamicDialect *dialect = ctx->getOrLoadDynamicDialect( |
616 | dialectNamespace: dialectName, ctor: [](DynamicDialect *dialect) {}); |
617 | |
618 | dialects.insert({dialectOp, dialect}); |
619 | }); |
620 | return dialects; |
621 | } |
622 | |
623 | /// Preallocate type definitions objects with empty verifiers. |
624 | /// This in particular allocates a TypeID for each type definition. |
625 | static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> |
626 | preallocateTypeDefs(ModuleOp op, |
627 | DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
628 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs; |
629 | op.walk([&](TypeOp typeOp) { |
630 | ExtensibleDialect *dialect = dialects[typeOp.getParentOp()]; |
631 | auto typeDef = DynamicTypeDefinition::get( |
632 | typeOp.getName(), dialect, |
633 | [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
634 | return success(); |
635 | }); |
636 | typeDefs.try_emplace(typeOp, std::move(typeDef)); |
637 | }); |
638 | return typeDefs; |
639 | } |
640 | |
641 | /// Preallocate attribute definitions objects with empty verifiers. |
642 | /// This in particular allocates a TypeID for each attribute definition. |
643 | static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
644 | preallocateAttrDefs(ModuleOp op, |
645 | DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
646 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs; |
647 | op.walk([&](AttributeOp attrOp) { |
648 | ExtensibleDialect *dialect = dialects[attrOp.getParentOp()]; |
649 | auto attrDef = DynamicAttrDefinition::get( |
650 | attrOp.getName(), dialect, |
651 | [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
652 | return success(); |
653 | }); |
654 | attrDefs.try_emplace(attrOp, std::move(attrDef)); |
655 | }); |
656 | return attrDefs; |
657 | } |
658 | |
659 | LogicalResult mlir::irdl::loadDialects(ModuleOp op) { |
660 | // First, check that all any_of constraints are in a correct form. |
661 | // This is to ensure we can do the verification correctly. |
662 | WalkResult anyOfCorrects = op.walk( |
663 | [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); }); |
664 | if (anyOfCorrects.wasInterrupted()) |
665 | return op.emitError("any_of constraints are not in the correct form" ); |
666 | |
667 | // Preallocate all dialects, and type and attribute definitions. |
668 | // In particular, this allocates TypeIDs so type and attributes can have |
669 | // verifiers that refer to each other. |
670 | DenseMap<DialectOp, ExtensibleDialect *> dialects = loadEmptyDialects(op); |
671 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> types = |
672 | preallocateTypeDefs(op, dialects); |
673 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs = |
674 | preallocateAttrDefs(op, dialects); |
675 | |
676 | // Set the verifier for types. |
677 | WalkResult res = op.walk([&](TypeOp typeOp) { |
678 | DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
679 | typeOp, dialects[typeOp.getParentOp()], types, attrs); |
680 | if (!verifier) |
681 | return WalkResult::interrupt(); |
682 | types[typeOp]->setVerifyFn(std::move(verifier)); |
683 | return WalkResult::advance(); |
684 | }); |
685 | if (res.wasInterrupted()) |
686 | return failure(); |
687 | |
688 | // Set the verifier for attributes. |
689 | res = op.walk([&](AttributeOp attrOp) { |
690 | DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
691 | attrOp, dialects[attrOp.getParentOp()], types, attrs); |
692 | if (!verifier) |
693 | return WalkResult::interrupt(); |
694 | attrs[attrOp]->setVerifyFn(std::move(verifier)); |
695 | return WalkResult::advance(); |
696 | }); |
697 | if (res.wasInterrupted()) |
698 | return failure(); |
699 | |
700 | // Define and load all operations. |
701 | res = op.walk([&](OperationOp opOp) { |
702 | return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs); |
703 | }); |
704 | if (res.wasInterrupted()) |
705 | return failure(); |
706 | |
707 | // Load all types in their dialects. |
708 | for (auto &pair : types) { |
709 | ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
710 | dialect->registerDynamicType(type: std::move(pair.second)); |
711 | } |
712 | |
713 | // Load all attributes in their dialects. |
714 | for (auto &pair : attrs) { |
715 | ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
716 | dialect->registerDynamicAttr(attr: std::move(pair.second)); |
717 | } |
718 | |
719 | return success(); |
720 | } |
721 | |