1//===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Manages the loading of MLIR objects from IRDL operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/IRDL/IRDLLoading.h"
14#include "mlir/Dialect/IRDL/IR/IRDL.h"
15#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
16#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/ExtensibleDialect.h"
20#include "mlir/IR/OperationSupport.h"
21#include "mlir/Support/LogicalResult.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/Support/SMLoc.h"
25#include <numeric>
26
27using namespace mlir;
28using namespace mlir::irdl;
29
30/// Verify that the given list of parameters satisfy the given constraints.
31/// This encodes the logic of the verification method for attributes and types
32/// defined with IRDL.
33static LogicalResult
34irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError,
35 ArrayRef<Attribute> params,
36 ArrayRef<std::unique_ptr<Constraint>> constraints,
37 ArrayRef<size_t> paramConstraints) {
38 if (params.size() != paramConstraints.size()) {
39 emitError() << "expected " << paramConstraints.size()
40 << " type arguments, but had " << params.size();
41 return failure();
42 }
43
44 ConstraintVerifier verifier(constraints);
45
46 // Check that each parameter satisfies its constraint.
47 for (auto [i, param] : enumerate(First&: params))
48 if (failed(result: verifier.verify(emitError, attr: param, variable: paramConstraints[i])))
49 return failure();
50
51 return success();
52}
53
54/// Get the operand segment sizes from the attribute dictionary.
55LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName,
56 StringRef attrName, unsigned numElements,
57 ArrayRef<Variadicity> variadicities,
58 SmallVectorImpl<int> &segmentSizes) {
59 // Get the segment sizes attribute, and check that it is of the right type.
60 Attribute segmentSizesAttr = op->getAttr(name: attrName);
61 if (!segmentSizesAttr) {
62 return op->emitError() << "'" << attrName
63 << "' attribute is expected but not provided";
64 }
65
66 auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr);
67 if (!denseSegmentSizes) {
68 return op->emitError() << "'" << attrName
69 << "' attribute is expected to be a dense i32 array";
70 }
71
72 if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
73 return op->emitError() << "'" << attrName << "' attribute for specifying "
74 << elemName << " segments must have "
75 << variadicities.size() << " elements, but got "
76 << denseSegmentSizes.size();
77 }
78
79 // Check that the segment sizes are corresponding to the given variadicities,
80 for (auto [i, segmentSize, variadicity] :
81 enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
82 if (segmentSize < 0)
83 return op->emitError()
84 << "'" << attrName << "' attribute for specifying " << elemName
85 << " segments must have non-negative values";
86 if (variadicity == Variadicity::single && segmentSize != 1)
87 return op->emitError() << "element " << i << " in '" << attrName
88 << "' attribute must be equal to 1";
89
90 if (variadicity == Variadicity::optional && segmentSize > 1)
91 return op->emitError() << "element " << i << " in '" << attrName
92 << "' attribute must be equal to 0 or 1";
93
94 segmentSizes.push_back(segmentSize);
95 }
96
97 // Check that the sum of the segment sizes is equal to the number of elements.
98 int32_t sum = 0;
99 for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
100 sum += segmentSize;
101 if (sum != static_cast<int32_t>(numElements))
102 return op->emitError() << "sum of elements in '" << attrName
103 << "' attribute must be equal to the number of "
104 << elemName << "s";
105
106 return success();
107}
108
109/// Compute the segment sizes of the given element (operands, results).
110/// If the operation has more than two non-single elements (optional or
111/// variadic), then get the segment sizes from the attribute dictionary.
112/// Otherwise, compute the segment sizes from the number of elements.
113/// `elemName` should be either `"operand"` or `"result"`.
114LogicalResult getSegmentSizes(Operation *op, StringRef elemName,
115 StringRef attrName, unsigned numElements,
116 ArrayRef<Variadicity> variadicities,
117 SmallVectorImpl<int> &segmentSizes) {
118 // If we have more than one non-single variadicity, we need to get the
119 // segment sizes from the attribute dictionary.
120 int numberNonSingle = count_if(
121 variadicities, [](Variadicity v) { return v != Variadicity::single; });
122 if (numberNonSingle > 1)
123 return getSegmentSizesFromAttr(op, elemName, attrName, numElements,
124 variadicities, segmentSizes);
125
126 // If we only have single variadicities, the segments sizes are all 1.
127 if (numberNonSingle == 0) {
128 if (numElements != variadicities.size()) {
129 return op->emitError() << "op expects exactly " << variadicities.size()
130 << " " << elemName << "s, but got " << numElements;
131 }
132 for (size_t i = 0, e = variadicities.size(); i < e; ++i)
133 segmentSizes.push_back(Elt: 1);
134 return success();
135 }
136
137 assert(numberNonSingle == 1);
138
139 // There is exactly one non-single element, so we can
140 // compute its size and check that it is valid.
141 int nonSingleSegmentSize = static_cast<int>(numElements) -
142 static_cast<int>(variadicities.size()) + 1;
143
144 if (nonSingleSegmentSize < 0) {
145 return op->emitError() << "op expects at least " << variadicities.size() - 1
146 << " " << elemName << "s, but got " << numElements;
147 }
148
149 // Add the segment sizes.
150 for (Variadicity variadicity : variadicities) {
151 if (variadicity == Variadicity::single) {
152 segmentSizes.push_back(1);
153 continue;
154 }
155
156 // If we have an optional element, we should check that it represents
157 // zero or one elements.
158 if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
159 return op->emitError() << "op expects at most " << variadicities.size()
160 << " " << elemName << "s, but got " << numElements;
161
162 segmentSizes.push_back(nonSingleSegmentSize);
163 }
164
165 return success();
166}
167
168/// Compute the segment sizes of the given operands.
169/// If the operation has more than two non-single operands (optional or
170/// variadic), then get the segment sizes from the attribute dictionary.
171/// Otherwise, compute the segment sizes from the number of operands.
172LogicalResult getOperandSegmentSizes(Operation *op,
173 ArrayRef<Variadicity> variadicities,
174 SmallVectorImpl<int> &segmentSizes) {
175 return getSegmentSizes(op, "operand", "operand_segment_sizes",
176 op->getNumOperands(), variadicities, segmentSizes);
177}
178
179/// Compute the segment sizes of the given results.
180/// If the operation has more than two non-single results (optional or
181/// variadic), then get the segment sizes from the attribute dictionary.
182/// Otherwise, compute the segment sizes from the number of results.
183LogicalResult getResultSegmentSizes(Operation *op,
184 ArrayRef<Variadicity> variadicities,
185 SmallVectorImpl<int> &segmentSizes) {
186 return getSegmentSizes(op, "result", "result_segment_sizes",
187 op->getNumResults(), variadicities, segmentSizes);
188}
189
190/// Verify that the given operation satisfies the given constraints.
191/// This encodes the logic of the verification method for operations defined
192/// with IRDL.
193static LogicalResult irdlOpVerifier(
194 Operation *op, ConstraintVerifier &verifier,
195 ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity,
196 ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity,
197 const DenseMap<StringAttr, size_t> &attributeConstrs) {
198 // Get the segment sizes for the operands.
199 // This will check that the number of operands is correct.
200 SmallVector<int> operandSegmentSizes;
201 if (failed(
202 getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes)))
203 return failure();
204
205 // Get the segment sizes for the results.
206 // This will check that the number of results is correct.
207 SmallVector<int> resultSegmentSizes;
208 if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes)))
209 return failure();
210
211 auto emitError = [op] { return op->emitError(); };
212
213 /// Сheck that we have all needed attributes passed
214 /// and they satisfy the constraints.
215 DictionaryAttr actualAttrs = op->getAttrDictionary();
216
217 for (auto [name, constraint] : attributeConstrs) {
218 /// First, check if the attribute actually passed.
219 std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
220 if (!actual.has_value())
221 return op->emitOpError()
222 << "attribute " << name << " is expected but not provided";
223
224 /// Then, check if the attribute value satisfies the constraint.
225 if (failed(verifier.verify(emitError: {emitError}, attr: actual->getValue(), variable: constraint)))
226 return failure();
227 }
228
229 // Check that all operands satisfy the constraints
230 int operandIdx = 0;
231 for (auto [defIndex, segmentSize] : enumerate(First&: operandSegmentSizes)) {
232 for (int i = 0; i < segmentSize; i++) {
233 if (failed(verifier.verify(
234 {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]),
235 operandConstrs[defIndex])))
236 return failure();
237 ++operandIdx;
238 }
239 }
240
241 // Check that all results satisfy the constraints
242 int resultIdx = 0;
243 for (auto [defIndex, segmentSize] : enumerate(First&: resultSegmentSizes)) {
244 for (int i = 0; i < segmentSize; i++) {
245 if (failed(verifier.verify({emitError},
246 TypeAttr::get(op->getResultTypes()[resultIdx]),
247 resultConstrs[defIndex])))
248 return failure();
249 ++resultIdx;
250 }
251 }
252
253 return success();
254}
255
256static LogicalResult irdlRegionVerifier(
257 Operation *op, ConstraintVerifier &verifier,
258 ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
259 if (op->getNumRegions() != regionsConstraints.size()) {
260 return op->emitOpError()
261 << "unexpected number of regions: expected "
262 << regionsConstraints.size() << " but got " << op->getNumRegions();
263 }
264
265 for (auto [constraint, region] :
266 llvm::zip(t&: regionsConstraints, u: op->getRegions()))
267 if (failed(result: constraint->verify(region, constraintContext&: verifier)))
268 return failure();
269
270 return success();
271}
272
273/// Define and load an operation represented by a `irdl.operation`
274/// operation.
275static WalkResult loadOperation(
276 OperationOp op, ExtensibleDialect *dialect,
277 DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
278 DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
279 // Resolve SSA values to verifier constraint slots
280 SmallVector<Value> constrToValue;
281 SmallVector<Value> regionToValue;
282 for (Operation &op : op->getRegion(0).getOps()) {
283 if (isa<VerifyConstraintInterface>(op)) {
284 if (op.getNumResults() != 1)
285 return op.emitError()
286 << "IRDL constraint operations must have exactly one result";
287 constrToValue.push_back(op.getResult(0));
288 }
289 if (isa<VerifyRegionInterface>(op)) {
290 if (op.getNumResults() != 1)
291 return op.emitError()
292 << "IRDL constraint operations must have exactly one result";
293 regionToValue.push_back(op.getResult(0));
294 }
295 }
296
297 // Build the verifiers for each constraint slot
298 SmallVector<std::unique_ptr<Constraint>> constraints;
299 for (Value v : constrToValue) {
300 VerifyConstraintInterface op =
301 cast<VerifyConstraintInterface>(v.getDefiningOp());
302 std::unique_ptr<Constraint> verifier =
303 op.getVerifier(constrToValue, types, attrs);
304 if (!verifier)
305 return WalkResult::interrupt();
306 constraints.push_back(Elt: std::move(verifier));
307 }
308
309 // Build region constraints
310 SmallVector<std::unique_ptr<RegionConstraint>> regionConstraints;
311 for (Value v : regionToValue) {
312 VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
313 std::unique_ptr<RegionConstraint> verifier =
314 op.getVerifier(constrToValue, types, attrs);
315 regionConstraints.push_back(Elt: std::move(verifier));
316 }
317
318 SmallVector<size_t> operandConstraints;
319 SmallVector<Variadicity> operandVariadicity;
320
321 // Gather which constraint slots correspond to operand constraints
322 auto operandsOp = op.getOp<OperandsOp>();
323 if (operandsOp.has_value()) {
324 operandConstraints.reserve(N: operandsOp->getArgs().size());
325 for (Value operand : operandsOp->getArgs()) {
326 for (auto [i, constr] : enumerate(constrToValue)) {
327 if (constr == operand) {
328 operandConstraints.push_back(i);
329 break;
330 }
331 }
332 }
333
334 // Gather the variadicities of each operand
335 for (VariadicityAttr attr : operandsOp->getVariadicity())
336 operandVariadicity.push_back(attr.getValue());
337 }
338
339 SmallVector<size_t> resultConstraints;
340 SmallVector<Variadicity> resultVariadicity;
341
342 // Gather which constraint slots correspond to result constraints
343 auto resultsOp = op.getOp<ResultsOp>();
344 if (resultsOp.has_value()) {
345 resultConstraints.reserve(N: resultsOp->getArgs().size());
346 for (Value result : resultsOp->getArgs()) {
347 for (auto [i, constr] : enumerate(constrToValue)) {
348 if (constr == result) {
349 resultConstraints.push_back(i);
350 break;
351 }
352 }
353 }
354
355 // Gather the variadicities of each result
356 for (Attribute attr : resultsOp->getVariadicity())
357 resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
358 }
359
360 // Gather which constraint slots correspond to attributes constraints
361 DenseMap<StringAttr, size_t> attributesContraints;
362 auto attributesOp = op.getOp<AttributesOp>();
363 if (attributesOp.has_value()) {
364 const Operation::operand_range values = attributesOp->getAttributeValues();
365 const ArrayAttr names = attributesOp->getAttributeValueNames();
366
367 for (const auto &[name, value] : llvm::zip(names, values)) {
368 for (auto [i, constr] : enumerate(constrToValue)) {
369 if (constr == value) {
370 attributesContraints[cast<StringAttr>(name)] = i;
371 break;
372 }
373 }
374 }
375 }
376
377 // IRDL does not support defining custom parsers or printers.
378 auto parser = [](OpAsmParser &parser, OperationState &result) {
379 return failure();
380 };
381 auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
382 printer.printGenericOp(op);
383 };
384
385 auto verifier =
386 [constraints{std::move(constraints)},
387 regionConstraints{std::move(regionConstraints)},
388 operandConstraints{std::move(operandConstraints)},
389 operandVariadicity{std::move(operandVariadicity)},
390 resultConstraints{std::move(resultConstraints)},
391 resultVariadicity{std::move(resultVariadicity)},
392 attributesContraints{std::move(attributesContraints)}](Operation *op) {
393 ConstraintVerifier verifier(constraints);
394 const LogicalResult opVerifierResult = irdlOpVerifier(
395 op, verifier, operandConstraints, operandVariadicity,
396 resultConstraints, resultVariadicity, attributesContraints);
397 const LogicalResult opRegionVerifierResult =
398 irdlRegionVerifier(op, verifier, regionsConstraints: regionConstraints);
399 return LogicalResult::success(isSuccess: opVerifierResult.succeeded() &&
400 opRegionVerifierResult.succeeded());
401 };
402
403 // IRDL supports only checking number of blocks and argument contraints
404 // It is done in the main verifier to reuse `ConstraintVerifier` context
405 auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
406
407 auto opDef = DynamicOpDefinition::get(
408 op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
409 std::move(parser), std::move(printer));
410 dialect->registerDynamicOp(type: std::move(opDef));
411
412 return WalkResult::advance();
413}
414
415/// Get the verifier of a type or attribute definition.
416/// Return nullptr if the definition is invalid.
417static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
418 Operation *attrOrTypeDef, ExtensibleDialect *dialect,
419 DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
420 DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
421 assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
422 "Expected an attribute or type definition");
423
424 // Resolve SSA values to verifier constraint slots
425 SmallVector<Value> constrToValue;
426 for (Operation &op : attrOrTypeDef->getRegion(index: 0).getOps()) {
427 if (isa<VerifyConstraintInterface>(op)) {
428 assert(op.getNumResults() == 1 &&
429 "IRDL constraint operations must have exactly one result");
430 constrToValue.push_back(Elt: op.getResult(idx: 0));
431 }
432 }
433
434 // Build the verifiers for each constraint slot
435 SmallVector<std::unique_ptr<Constraint>> constraints;
436 for (Value v : constrToValue) {
437 VerifyConstraintInterface op =
438 cast<VerifyConstraintInterface>(v.getDefiningOp());
439 std::unique_ptr<Constraint> verifier =
440 op.getVerifier(constrToValue, types, attrs);
441 if (!verifier)
442 return {};
443 constraints.push_back(Elt: std::move(verifier));
444 }
445
446 // Get the parameter definitions.
447 std::optional<ParametersOp> params;
448 if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
449 params = attr.getOp<ParametersOp>();
450 else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
451 params = type.getOp<ParametersOp>();
452
453 // Gather which constraint slots correspond to parameter constraints
454 SmallVector<size_t> paramConstraints;
455 if (params.has_value()) {
456 paramConstraints.reserve(N: params->getArgs().size());
457 for (Value param : params->getArgs()) {
458 for (auto [i, constr] : enumerate(constrToValue)) {
459 if (constr == param) {
460 paramConstraints.push_back(i);
461 break;
462 }
463 }
464 }
465 }
466
467 auto verifier = [paramConstraints{std::move(paramConstraints)},
468 constraints{std::move(constraints)}](
469 function_ref<InFlightDiagnostic()> emitError,
470 ArrayRef<Attribute> params) {
471 return irdlAttrOrTypeVerifier(emitError, params, constraints,
472 paramConstraints);
473 };
474
475 // While the `std::move` is not required, not adding it triggers a bug in
476 // clang-10.
477 return std::move(verifier);
478}
479
480/// Get the possible bases of a constraint. Return `true` if all bases can
481/// potentially be matched.
482/// A base is a type or an attribute definition. For instance, the base of
483/// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
484/// This function returns the following information through arguments:
485/// - `paramIds`: the set of type or attribute IDs that are used as bases.
486/// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
487/// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
488/// constraints.
489static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
490 SmallPtrSet<Operation *, 4> &paramIrdlOps,
491 SmallPtrSet<TypeID, 4> &isIds) {
492 // For `irdl.any_of`, we get the bases from all its arguments.
493 if (auto anyOf = dyn_cast<AnyOfOp>(op)) {
494 bool hasAny = false;
495 for (Value arg : anyOf.getArgs())
496 hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
497 return hasAny;
498 }
499
500 // For `irdl.all_of`, we get the bases from the first argument.
501 // This is restrictive, but we can relax it later if needed.
502 if (auto allOf = dyn_cast<AllOfOp>(op))
503 return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
504 isIds);
505
506 // For `irdl.parametric`, we get directly the base from the operation.
507 if (auto params = dyn_cast<ParametricOp>(op)) {
508 SymbolRefAttr symRef = params.getBaseType();
509 Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
510 assert(defOp && "symbol reference should refer to an existing operation");
511 paramIrdlOps.insert(Ptr: defOp);
512 return false;
513 }
514
515 // For `irdl.is`, we get the base TypeID directly.
516 if (auto is = dyn_cast<IsOp>(op)) {
517 Attribute expected = is.getExpected();
518 isIds.insert(Ptr: expected.getTypeID());
519 return false;
520 }
521
522 // For `irdl.any`, we return `false` since we can match any type or attribute
523 // base.
524 if (auto isA = dyn_cast<AnyOp>(op))
525 return true;
526
527 llvm_unreachable("unknown IRDL constraint");
528}
529
530/// Check that an any_of is in the subset IRDL can handle.
531/// IRDL uses a greedy algorithm to match constraints. This means that if we
532/// encounter an `any_of` with multiple constraints, we will match the first
533/// constraint that is satisfied. Thus, the order of constraints matter in
534/// `any_of` with our current algorithm.
535/// In order to make the order of constraints irrelevant, we require that
536/// all `any_of` constraint parameters are disjoint. For this, we check that
537/// the base parameters are all disjoints between `parametric` operations, and
538/// that they are disjoint between `parametric` and `is` operations.
539/// This restriction will be relaxed in the future, when we will change our
540/// algorithm to be non-greedy.
541static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) {
542 SmallPtrSet<TypeID, 4> paramIds;
543 SmallPtrSet<Operation *, 4> paramIrdlOps;
544 SmallPtrSet<TypeID, 4> isIds;
545
546 for (Value arg : anyOf.getArgs()) {
547 Operation *argOp = arg.getDefiningOp();
548 SmallPtrSet<TypeID, 4> argParamIds;
549 SmallPtrSet<Operation *, 4> argParamIrdlOps;
550 SmallPtrSet<TypeID, 4> argIsIds;
551
552 // Get the bases of this argument. If it can match any type or attribute,
553 // then our `any_of` should not be allowed.
554 if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
555 return failure();
556
557 // We check that the base parameters are all disjoints between `parametric`
558 // operations, and that they are disjoint between `parametric` and `is`
559 // operations.
560 for (TypeID id : argParamIds) {
561 if (isIds.count(id))
562 return failure();
563 bool inserted = paramIds.insert(id).second;
564 if (!inserted)
565 return failure();
566 }
567
568 // We check that the base parameters are all disjoints with `irdl.is`
569 // operations.
570 for (TypeID id : isIds) {
571 if (paramIds.count(id))
572 return failure();
573 isIds.insert(id);
574 }
575
576 // We check that all `parametric` operations are disjoint. We do not
577 // need to check that they are disjoint with `is` operations, since
578 // `is` operations cannot refer to attributes defined with `irdl.parametric`
579 // operations.
580 for (Operation *op : argParamIrdlOps) {
581 bool inserted = paramIrdlOps.insert(op).second;
582 if (!inserted)
583 return failure();
584 }
585 }
586
587 return success();
588}
589
590/// Load all dialects in the given module, without loading any operation, type
591/// or attribute definitions.
592static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
593 DenseMap<DialectOp, ExtensibleDialect *> dialects;
594 op.walk([&](DialectOp dialectOp) {
595 MLIRContext *ctx = dialectOp.getContext();
596 StringRef dialectName = dialectOp.getName();
597
598 DynamicDialect *dialect = ctx->getOrLoadDynamicDialect(
599 dialectNamespace: dialectName, ctor: [](DynamicDialect *dialect) {});
600
601 dialects.insert({dialectOp, dialect});
602 });
603 return dialects;
604}
605
606/// Preallocate type definitions objects with empty verifiers.
607/// This in particular allocates a TypeID for each type definition.
608static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>>
609preallocateTypeDefs(ModuleOp op,
610 DenseMap<DialectOp, ExtensibleDialect *> dialects) {
611 DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs;
612 op.walk([&](TypeOp typeOp) {
613 ExtensibleDialect *dialect = dialects[typeOp.getParentOp()];
614 auto typeDef = DynamicTypeDefinition::get(
615 typeOp.getName(), dialect,
616 [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) {
617 return success();
618 });
619 typeDefs.try_emplace(typeOp, std::move(typeDef));
620 });
621 return typeDefs;
622}
623
624/// Preallocate attribute definitions objects with empty verifiers.
625/// This in particular allocates a TypeID for each attribute definition.
626static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
627preallocateAttrDefs(ModuleOp op,
628 DenseMap<DialectOp, ExtensibleDialect *> dialects) {
629 DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs;
630 op.walk([&](AttributeOp attrOp) {
631 ExtensibleDialect *dialect = dialects[attrOp.getParentOp()];
632 auto attrDef = DynamicAttrDefinition::get(
633 attrOp.getName(), dialect,
634 [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) {
635 return success();
636 });
637 attrDefs.try_emplace(attrOp, std::move(attrDef));
638 });
639 return attrDefs;
640}
641
642LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
643 // First, check that all any_of constraints are in a correct form.
644 // This is to ensure we can do the verification correctly.
645 WalkResult anyOfCorrects = op.walk(
646 [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); });
647 if (anyOfCorrects.wasInterrupted())
648 return op.emitError("any_of constraints are not in the correct form");
649
650 // Preallocate all dialects, and type and attribute definitions.
651 // In particular, this allocates TypeIDs so type and attributes can have
652 // verifiers that refer to each other.
653 DenseMap<DialectOp, ExtensibleDialect *> dialects = loadEmptyDialects(op);
654 DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> types =
655 preallocateTypeDefs(op, dialects);
656 DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs =
657 preallocateAttrDefs(op, dialects);
658
659 // Set the verifier for types.
660 WalkResult res = op.walk([&](TypeOp typeOp) {
661 DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
662 typeOp, dialects[typeOp.getParentOp()], types, attrs);
663 if (!verifier)
664 return WalkResult::interrupt();
665 types[typeOp]->setVerifyFn(std::move(verifier));
666 return WalkResult::advance();
667 });
668 if (res.wasInterrupted())
669 return failure();
670
671 // Set the verifier for attributes.
672 res = op.walk([&](AttributeOp attrOp) {
673 DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
674 attrOp, dialects[attrOp.getParentOp()], types, attrs);
675 if (!verifier)
676 return WalkResult::interrupt();
677 attrs[attrOp]->setVerifyFn(std::move(verifier));
678 return WalkResult::advance();
679 });
680 if (res.wasInterrupted())
681 return failure();
682
683 // Define and load all operations.
684 res = op.walk([&](OperationOp opOp) {
685 return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
686 });
687 if (res.wasInterrupted())
688 return failure();
689
690 // Load all types in their dialects.
691 for (auto &pair : types) {
692 ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
693 dialect->registerDynamicType(type: std::move(pair.second));
694 }
695
696 // Load all attributes in their dialects.
697 for (auto &pair : attrs) {
698 ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
699 dialect->registerDynamicAttr(attr: std::move(pair.second));
700 }
701
702 return success();
703}
704

// Source: mlir/lib/Dialect/IRDL/IRDLLoading.cpp