1 | //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Manages the loading of MLIR objects from IRDL operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/IRDL/IRDLLoading.h" |
14 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
15 | #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" |
16 | #include "mlir/Dialect/IRDL/IRDLVerifiers.h" |
17 | #include "mlir/IR/Attributes.h" |
18 | #include "mlir/IR/BuiltinOps.h" |
19 | #include "mlir/IR/ExtensibleDialect.h" |
20 | #include "mlir/IR/OperationSupport.h" |
21 | #include "mlir/Support/LogicalResult.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/ADT/SmallPtrSet.h" |
24 | #include "llvm/Support/SMLoc.h" |
25 | #include <numeric> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::irdl; |
29 | |
30 | /// Verify that the given list of parameters satisfy the given constraints. |
31 | /// This encodes the logic of the verification method for attributes and types |
32 | /// defined with IRDL. |
33 | static LogicalResult |
34 | irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError, |
35 | ArrayRef<Attribute> params, |
36 | ArrayRef<std::unique_ptr<Constraint>> constraints, |
37 | ArrayRef<size_t> paramConstraints) { |
38 | if (params.size() != paramConstraints.size()) { |
39 | emitError() << "expected " << paramConstraints.size() |
40 | << " type arguments, but had " << params.size(); |
41 | return failure(); |
42 | } |
43 | |
44 | ConstraintVerifier verifier(constraints); |
45 | |
46 | // Check that each parameter satisfies its constraint. |
47 | for (auto [i, param] : enumerate(First&: params)) |
48 | if (failed(result: verifier.verify(emitError, attr: param, variable: paramConstraints[i]))) |
49 | return failure(); |
50 | |
51 | return success(); |
52 | } |
53 | |
54 | /// Get the operand segment sizes from the attribute dictionary. |
55 | LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, |
56 | StringRef attrName, unsigned numElements, |
57 | ArrayRef<Variadicity> variadicities, |
58 | SmallVectorImpl<int> &segmentSizes) { |
59 | // Get the segment sizes attribute, and check that it is of the right type. |
60 | Attribute segmentSizesAttr = op->getAttr(name: attrName); |
61 | if (!segmentSizesAttr) { |
62 | return op->emitError() << "'" << attrName |
63 | << "' attribute is expected but not provided" ; |
64 | } |
65 | |
66 | auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr); |
67 | if (!denseSegmentSizes) { |
68 | return op->emitError() << "'" << attrName |
69 | << "' attribute is expected to be a dense i32 array" ; |
70 | } |
71 | |
72 | if (denseSegmentSizes.size() != (int64_t)variadicities.size()) { |
73 | return op->emitError() << "'" << attrName << "' attribute for specifying " |
74 | << elemName << " segments must have " |
75 | << variadicities.size() << " elements, but got " |
76 | << denseSegmentSizes.size(); |
77 | } |
78 | |
79 | // Check that the segment sizes are corresponding to the given variadicities, |
80 | for (auto [i, segmentSize, variadicity] : |
81 | enumerate(denseSegmentSizes.asArrayRef(), variadicities)) { |
82 | if (segmentSize < 0) |
83 | return op->emitError() |
84 | << "'" << attrName << "' attribute for specifying " << elemName |
85 | << " segments must have non-negative values" ; |
86 | if (variadicity == Variadicity::single && segmentSize != 1) |
87 | return op->emitError() << "element " << i << " in '" << attrName |
88 | << "' attribute must be equal to 1" ; |
89 | |
90 | if (variadicity == Variadicity::optional && segmentSize > 1) |
91 | return op->emitError() << "element " << i << " in '" << attrName |
92 | << "' attribute must be equal to 0 or 1" ; |
93 | |
94 | segmentSizes.push_back(segmentSize); |
95 | } |
96 | |
97 | // Check that the sum of the segment sizes is equal to the number of elements. |
98 | int32_t sum = 0; |
99 | for (int32_t segmentSize : denseSegmentSizes.asArrayRef()) |
100 | sum += segmentSize; |
101 | if (sum != static_cast<int32_t>(numElements)) |
102 | return op->emitError() << "sum of elements in '" << attrName |
103 | << "' attribute must be equal to the number of " |
104 | << elemName << "s" ; |
105 | |
106 | return success(); |
107 | } |
108 | |
109 | /// Compute the segment sizes of the given element (operands, results). |
110 | /// If the operation has more than two non-single elements (optional or |
111 | /// variadic), then get the segment sizes from the attribute dictionary. |
112 | /// Otherwise, compute the segment sizes from the number of elements. |
113 | /// `elemName` should be either `"operand"` or `"result"`. |
114 | LogicalResult getSegmentSizes(Operation *op, StringRef elemName, |
115 | StringRef attrName, unsigned numElements, |
116 | ArrayRef<Variadicity> variadicities, |
117 | SmallVectorImpl<int> &segmentSizes) { |
118 | // If we have more than one non-single variadicity, we need to get the |
119 | // segment sizes from the attribute dictionary. |
120 | int numberNonSingle = count_if( |
121 | variadicities, [](Variadicity v) { return v != Variadicity::single; }); |
122 | if (numberNonSingle > 1) |
123 | return getSegmentSizesFromAttr(op, elemName, attrName, numElements, |
124 | variadicities, segmentSizes); |
125 | |
126 | // If we only have single variadicities, the segments sizes are all 1. |
127 | if (numberNonSingle == 0) { |
128 | if (numElements != variadicities.size()) { |
129 | return op->emitError() << "op expects exactly " << variadicities.size() |
130 | << " " << elemName << "s, but got " << numElements; |
131 | } |
132 | for (size_t i = 0, e = variadicities.size(); i < e; ++i) |
133 | segmentSizes.push_back(Elt: 1); |
134 | return success(); |
135 | } |
136 | |
137 | assert(numberNonSingle == 1); |
138 | |
139 | // There is exactly one non-single element, so we can |
140 | // compute its size and check that it is valid. |
141 | int nonSingleSegmentSize = static_cast<int>(numElements) - |
142 | static_cast<int>(variadicities.size()) + 1; |
143 | |
144 | if (nonSingleSegmentSize < 0) { |
145 | return op->emitError() << "op expects at least " << variadicities.size() - 1 |
146 | << " " << elemName << "s, but got " << numElements; |
147 | } |
148 | |
149 | // Add the segment sizes. |
150 | for (Variadicity variadicity : variadicities) { |
151 | if (variadicity == Variadicity::single) { |
152 | segmentSizes.push_back(1); |
153 | continue; |
154 | } |
155 | |
156 | // If we have an optional element, we should check that it represents |
157 | // zero or one elements. |
158 | if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional) |
159 | return op->emitError() << "op expects at most " << variadicities.size() |
160 | << " " << elemName << "s, but got " << numElements; |
161 | |
162 | segmentSizes.push_back(nonSingleSegmentSize); |
163 | } |
164 | |
165 | return success(); |
166 | } |
167 | |
168 | /// Compute the segment sizes of the given operands. |
169 | /// If the operation has more than two non-single operands (optional or |
170 | /// variadic), then get the segment sizes from the attribute dictionary. |
171 | /// Otherwise, compute the segment sizes from the number of operands. |
172 | LogicalResult getOperandSegmentSizes(Operation *op, |
173 | ArrayRef<Variadicity> variadicities, |
174 | SmallVectorImpl<int> &segmentSizes) { |
175 | return getSegmentSizes(op, "operand" , "operand_segment_sizes" , |
176 | op->getNumOperands(), variadicities, segmentSizes); |
177 | } |
178 | |
179 | /// Compute the segment sizes of the given results. |
180 | /// If the operation has more than two non-single results (optional or |
181 | /// variadic), then get the segment sizes from the attribute dictionary. |
182 | /// Otherwise, compute the segment sizes from the number of results. |
183 | LogicalResult getResultSegmentSizes(Operation *op, |
184 | ArrayRef<Variadicity> variadicities, |
185 | SmallVectorImpl<int> &segmentSizes) { |
186 | return getSegmentSizes(op, "result" , "result_segment_sizes" , |
187 | op->getNumResults(), variadicities, segmentSizes); |
188 | } |
189 | |
190 | /// Verify that the given operation satisfies the given constraints. |
191 | /// This encodes the logic of the verification method for operations defined |
192 | /// with IRDL. |
193 | static LogicalResult irdlOpVerifier( |
194 | Operation *op, ConstraintVerifier &verifier, |
195 | ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity, |
196 | ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity, |
197 | const DenseMap<StringAttr, size_t> &attributeConstrs) { |
198 | // Get the segment sizes for the operands. |
199 | // This will check that the number of operands is correct. |
200 | SmallVector<int> operandSegmentSizes; |
201 | if (failed( |
202 | getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes))) |
203 | return failure(); |
204 | |
205 | // Get the segment sizes for the results. |
206 | // This will check that the number of results is correct. |
207 | SmallVector<int> resultSegmentSizes; |
208 | if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes))) |
209 | return failure(); |
210 | |
211 | auto emitError = [op] { return op->emitError(); }; |
212 | |
213 | /// Сheck that we have all needed attributes passed |
214 | /// and they satisfy the constraints. |
215 | DictionaryAttr actualAttrs = op->getAttrDictionary(); |
216 | |
217 | for (auto [name, constraint] : attributeConstrs) { |
218 | /// First, check if the attribute actually passed. |
219 | std::optional<NamedAttribute> actual = actualAttrs.getNamed(name); |
220 | if (!actual.has_value()) |
221 | return op->emitOpError() |
222 | << "attribute " << name << " is expected but not provided" ; |
223 | |
224 | /// Then, check if the attribute value satisfies the constraint. |
225 | if (failed(verifier.verify(emitError: {emitError}, attr: actual->getValue(), variable: constraint))) |
226 | return failure(); |
227 | } |
228 | |
229 | // Check that all operands satisfy the constraints |
230 | int operandIdx = 0; |
231 | for (auto [defIndex, segmentSize] : enumerate(First&: operandSegmentSizes)) { |
232 | for (int i = 0; i < segmentSize; i++) { |
233 | if (failed(verifier.verify( |
234 | {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]), |
235 | operandConstrs[defIndex]))) |
236 | return failure(); |
237 | ++operandIdx; |
238 | } |
239 | } |
240 | |
241 | // Check that all results satisfy the constraints |
242 | int resultIdx = 0; |
243 | for (auto [defIndex, segmentSize] : enumerate(First&: resultSegmentSizes)) { |
244 | for (int i = 0; i < segmentSize; i++) { |
245 | if (failed(verifier.verify({emitError}, |
246 | TypeAttr::get(op->getResultTypes()[resultIdx]), |
247 | resultConstrs[defIndex]))) |
248 | return failure(); |
249 | ++resultIdx; |
250 | } |
251 | } |
252 | |
253 | return success(); |
254 | } |
255 | |
256 | static LogicalResult irdlRegionVerifier( |
257 | Operation *op, ConstraintVerifier &verifier, |
258 | ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) { |
259 | if (op->getNumRegions() != regionsConstraints.size()) { |
260 | return op->emitOpError() |
261 | << "unexpected number of regions: expected " |
262 | << regionsConstraints.size() << " but got " << op->getNumRegions(); |
263 | } |
264 | |
265 | for (auto [constraint, region] : |
266 | llvm::zip(t&: regionsConstraints, u: op->getRegions())) |
267 | if (failed(result: constraint->verify(region, constraintContext&: verifier))) |
268 | return failure(); |
269 | |
270 | return success(); |
271 | } |
272 | |
273 | /// Define and load an operation represented by a `irdl.operation` |
274 | /// operation. |
275 | static WalkResult loadOperation( |
276 | OperationOp op, ExtensibleDialect *dialect, |
277 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
278 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { |
279 | // Resolve SSA values to verifier constraint slots |
280 | SmallVector<Value> constrToValue; |
281 | SmallVector<Value> regionToValue; |
282 | for (Operation &op : op->getRegion(0).getOps()) { |
283 | if (isa<VerifyConstraintInterface>(op)) { |
284 | if (op.getNumResults() != 1) |
285 | return op.emitError() |
286 | << "IRDL constraint operations must have exactly one result" ; |
287 | constrToValue.push_back(op.getResult(0)); |
288 | } |
289 | if (isa<VerifyRegionInterface>(op)) { |
290 | if (op.getNumResults() != 1) |
291 | return op.emitError() |
292 | << "IRDL constraint operations must have exactly one result" ; |
293 | regionToValue.push_back(op.getResult(0)); |
294 | } |
295 | } |
296 | |
297 | // Build the verifiers for each constraint slot |
298 | SmallVector<std::unique_ptr<Constraint>> constraints; |
299 | for (Value v : constrToValue) { |
300 | VerifyConstraintInterface op = |
301 | cast<VerifyConstraintInterface>(v.getDefiningOp()); |
302 | std::unique_ptr<Constraint> verifier = |
303 | op.getVerifier(constrToValue, types, attrs); |
304 | if (!verifier) |
305 | return WalkResult::interrupt(); |
306 | constraints.push_back(Elt: std::move(verifier)); |
307 | } |
308 | |
309 | // Build region constraints |
310 | SmallVector<std::unique_ptr<RegionConstraint>> regionConstraints; |
311 | for (Value v : regionToValue) { |
312 | VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp()); |
313 | std::unique_ptr<RegionConstraint> verifier = |
314 | op.getVerifier(constrToValue, types, attrs); |
315 | regionConstraints.push_back(Elt: std::move(verifier)); |
316 | } |
317 | |
318 | SmallVector<size_t> operandConstraints; |
319 | SmallVector<Variadicity> operandVariadicity; |
320 | |
321 | // Gather which constraint slots correspond to operand constraints |
322 | auto operandsOp = op.getOp<OperandsOp>(); |
323 | if (operandsOp.has_value()) { |
324 | operandConstraints.reserve(N: operandsOp->getArgs().size()); |
325 | for (Value operand : operandsOp->getArgs()) { |
326 | for (auto [i, constr] : enumerate(constrToValue)) { |
327 | if (constr == operand) { |
328 | operandConstraints.push_back(i); |
329 | break; |
330 | } |
331 | } |
332 | } |
333 | |
334 | // Gather the variadicities of each operand |
335 | for (VariadicityAttr attr : operandsOp->getVariadicity()) |
336 | operandVariadicity.push_back(attr.getValue()); |
337 | } |
338 | |
339 | SmallVector<size_t> resultConstraints; |
340 | SmallVector<Variadicity> resultVariadicity; |
341 | |
342 | // Gather which constraint slots correspond to result constraints |
343 | auto resultsOp = op.getOp<ResultsOp>(); |
344 | if (resultsOp.has_value()) { |
345 | resultConstraints.reserve(N: resultsOp->getArgs().size()); |
346 | for (Value result : resultsOp->getArgs()) { |
347 | for (auto [i, constr] : enumerate(constrToValue)) { |
348 | if (constr == result) { |
349 | resultConstraints.push_back(i); |
350 | break; |
351 | } |
352 | } |
353 | } |
354 | |
355 | // Gather the variadicities of each result |
356 | for (Attribute attr : resultsOp->getVariadicity()) |
357 | resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue()); |
358 | } |
359 | |
360 | // Gather which constraint slots correspond to attributes constraints |
361 | DenseMap<StringAttr, size_t> attributesContraints; |
362 | auto attributesOp = op.getOp<AttributesOp>(); |
363 | if (attributesOp.has_value()) { |
364 | const Operation::operand_range values = attributesOp->getAttributeValues(); |
365 | const ArrayAttr names = attributesOp->getAttributeValueNames(); |
366 | |
367 | for (const auto &[name, value] : llvm::zip(names, values)) { |
368 | for (auto [i, constr] : enumerate(constrToValue)) { |
369 | if (constr == value) { |
370 | attributesContraints[cast<StringAttr>(name)] = i; |
371 | break; |
372 | } |
373 | } |
374 | } |
375 | } |
376 | |
377 | // IRDL does not support defining custom parsers or printers. |
378 | auto parser = [](OpAsmParser &parser, OperationState &result) { |
379 | return failure(); |
380 | }; |
381 | auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) { |
382 | printer.printGenericOp(op); |
383 | }; |
384 | |
385 | auto verifier = |
386 | [constraints{std::move(constraints)}, |
387 | regionConstraints{std::move(regionConstraints)}, |
388 | operandConstraints{std::move(operandConstraints)}, |
389 | operandVariadicity{std::move(operandVariadicity)}, |
390 | resultConstraints{std::move(resultConstraints)}, |
391 | resultVariadicity{std::move(resultVariadicity)}, |
392 | attributesContraints{std::move(attributesContraints)}](Operation *op) { |
393 | ConstraintVerifier verifier(constraints); |
394 | const LogicalResult opVerifierResult = irdlOpVerifier( |
395 | op, verifier, operandConstraints, operandVariadicity, |
396 | resultConstraints, resultVariadicity, attributesContraints); |
397 | const LogicalResult opRegionVerifierResult = |
398 | irdlRegionVerifier(op, verifier, regionsConstraints: regionConstraints); |
399 | return LogicalResult::success(isSuccess: opVerifierResult.succeeded() && |
400 | opRegionVerifierResult.succeeded()); |
401 | }; |
402 | |
403 | // IRDL supports only checking number of blocks and argument contraints |
404 | // It is done in the main verifier to reuse `ConstraintVerifier` context |
405 | auto regionVerifier = [](Operation *op) { return LogicalResult::success(); }; |
406 | |
407 | auto opDef = DynamicOpDefinition::get( |
408 | op.getName(), dialect, std::move(verifier), std::move(regionVerifier), |
409 | std::move(parser), std::move(printer)); |
410 | dialect->registerDynamicOp(type: std::move(opDef)); |
411 | |
412 | return WalkResult::advance(); |
413 | } |
414 | |
415 | /// Get the verifier of a type or attribute definition. |
416 | /// Return nullptr if the definition is invalid. |
417 | static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier( |
418 | Operation *attrOrTypeDef, ExtensibleDialect *dialect, |
419 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
420 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { |
421 | assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) && |
422 | "Expected an attribute or type definition" ); |
423 | |
424 | // Resolve SSA values to verifier constraint slots |
425 | SmallVector<Value> constrToValue; |
426 | for (Operation &op : attrOrTypeDef->getRegion(index: 0).getOps()) { |
427 | if (isa<VerifyConstraintInterface>(op)) { |
428 | assert(op.getNumResults() == 1 && |
429 | "IRDL constraint operations must have exactly one result" ); |
430 | constrToValue.push_back(Elt: op.getResult(idx: 0)); |
431 | } |
432 | } |
433 | |
434 | // Build the verifiers for each constraint slot |
435 | SmallVector<std::unique_ptr<Constraint>> constraints; |
436 | for (Value v : constrToValue) { |
437 | VerifyConstraintInterface op = |
438 | cast<VerifyConstraintInterface>(v.getDefiningOp()); |
439 | std::unique_ptr<Constraint> verifier = |
440 | op.getVerifier(constrToValue, types, attrs); |
441 | if (!verifier) |
442 | return {}; |
443 | constraints.push_back(Elt: std::move(verifier)); |
444 | } |
445 | |
446 | // Get the parameter definitions. |
447 | std::optional<ParametersOp> params; |
448 | if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef)) |
449 | params = attr.getOp<ParametersOp>(); |
450 | else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef)) |
451 | params = type.getOp<ParametersOp>(); |
452 | |
453 | // Gather which constraint slots correspond to parameter constraints |
454 | SmallVector<size_t> paramConstraints; |
455 | if (params.has_value()) { |
456 | paramConstraints.reserve(N: params->getArgs().size()); |
457 | for (Value param : params->getArgs()) { |
458 | for (auto [i, constr] : enumerate(constrToValue)) { |
459 | if (constr == param) { |
460 | paramConstraints.push_back(i); |
461 | break; |
462 | } |
463 | } |
464 | } |
465 | } |
466 | |
467 | auto verifier = [paramConstraints{std::move(paramConstraints)}, |
468 | constraints{std::move(constraints)}]( |
469 | function_ref<InFlightDiagnostic()> emitError, |
470 | ArrayRef<Attribute> params) { |
471 | return irdlAttrOrTypeVerifier(emitError, params, constraints, |
472 | paramConstraints); |
473 | }; |
474 | |
475 | // While the `std::move` is not required, not adding it triggers a bug in |
476 | // clang-10. |
477 | return std::move(verifier); |
478 | } |
479 | |
480 | /// Get the possible bases of a constraint. Return `true` if all bases can |
481 | /// potentially be matched. |
482 | /// A base is a type or an attribute definition. For instance, the base of |
483 | /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`. |
484 | /// This function returns the following information through arguments: |
485 | /// - `paramIds`: the set of type or attribute IDs that are used as bases. |
486 | /// - `paramIrdlOps`: the set of IRDL operations that are used as bases. |
487 | /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is` |
488 | /// constraints. |
489 | static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> ¶mIds, |
490 | SmallPtrSet<Operation *, 4> ¶mIrdlOps, |
491 | SmallPtrSet<TypeID, 4> &isIds) { |
492 | // For `irdl.any_of`, we get the bases from all its arguments. |
493 | if (auto anyOf = dyn_cast<AnyOfOp>(op)) { |
494 | bool hasAny = false; |
495 | for (Value arg : anyOf.getArgs()) |
496 | hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds); |
497 | return hasAny; |
498 | } |
499 | |
500 | // For `irdl.all_of`, we get the bases from the first argument. |
501 | // This is restrictive, but we can relax it later if needed. |
502 | if (auto allOf = dyn_cast<AllOfOp>(op)) |
503 | return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps, |
504 | isIds); |
505 | |
506 | // For `irdl.parametric`, we get directly the base from the operation. |
507 | if (auto params = dyn_cast<ParametricOp>(op)) { |
508 | SymbolRefAttr symRef = params.getBaseType(); |
509 | Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef); |
510 | assert(defOp && "symbol reference should refer to an existing operation" ); |
511 | paramIrdlOps.insert(Ptr: defOp); |
512 | return false; |
513 | } |
514 | |
515 | // For `irdl.is`, we get the base TypeID directly. |
516 | if (auto is = dyn_cast<IsOp>(op)) { |
517 | Attribute expected = is.getExpected(); |
518 | isIds.insert(Ptr: expected.getTypeID()); |
519 | return false; |
520 | } |
521 | |
522 | // For `irdl.any`, we return `false` since we can match any type or attribute |
523 | // base. |
524 | if (auto isA = dyn_cast<AnyOp>(op)) |
525 | return true; |
526 | |
527 | llvm_unreachable("unknown IRDL constraint" ); |
528 | } |
529 | |
530 | /// Check that an any_of is in the subset IRDL can handle. |
531 | /// IRDL uses a greedy algorithm to match constraints. This means that if we |
532 | /// encounter an `any_of` with multiple constraints, we will match the first |
533 | /// constraint that is satisfied. Thus, the order of constraints matter in |
534 | /// `any_of` with our current algorithm. |
535 | /// In order to make the order of constraints irrelevant, we require that |
536 | /// all `any_of` constraint parameters are disjoint. For this, we check that |
537 | /// the base parameters are all disjoints between `parametric` operations, and |
538 | /// that they are disjoint between `parametric` and `is` operations. |
539 | /// This restriction will be relaxed in the future, when we will change our |
540 | /// algorithm to be non-greedy. |
541 | static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) { |
542 | SmallPtrSet<TypeID, 4> paramIds; |
543 | SmallPtrSet<Operation *, 4> paramIrdlOps; |
544 | SmallPtrSet<TypeID, 4> isIds; |
545 | |
546 | for (Value arg : anyOf.getArgs()) { |
547 | Operation *argOp = arg.getDefiningOp(); |
548 | SmallPtrSet<TypeID, 4> argParamIds; |
549 | SmallPtrSet<Operation *, 4> argParamIrdlOps; |
550 | SmallPtrSet<TypeID, 4> argIsIds; |
551 | |
552 | // Get the bases of this argument. If it can match any type or attribute, |
553 | // then our `any_of` should not be allowed. |
554 | if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds)) |
555 | return failure(); |
556 | |
557 | // We check that the base parameters are all disjoints between `parametric` |
558 | // operations, and that they are disjoint between `parametric` and `is` |
559 | // operations. |
560 | for (TypeID id : argParamIds) { |
561 | if (isIds.count(id)) |
562 | return failure(); |
563 | bool inserted = paramIds.insert(id).second; |
564 | if (!inserted) |
565 | return failure(); |
566 | } |
567 | |
568 | // We check that the base parameters are all disjoints with `irdl.is` |
569 | // operations. |
570 | for (TypeID id : isIds) { |
571 | if (paramIds.count(id)) |
572 | return failure(); |
573 | isIds.insert(id); |
574 | } |
575 | |
576 | // We check that all `parametric` operations are disjoint. We do not |
577 | // need to check that they are disjoint with `is` operations, since |
578 | // `is` operations cannot refer to attributes defined with `irdl.parametric` |
579 | // operations. |
580 | for (Operation *op : argParamIrdlOps) { |
581 | bool inserted = paramIrdlOps.insert(op).second; |
582 | if (!inserted) |
583 | return failure(); |
584 | } |
585 | } |
586 | |
587 | return success(); |
588 | } |
589 | |
590 | /// Load all dialects in the given module, without loading any operation, type |
591 | /// or attribute definitions. |
592 | static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) { |
593 | DenseMap<DialectOp, ExtensibleDialect *> dialects; |
594 | op.walk([&](DialectOp dialectOp) { |
595 | MLIRContext *ctx = dialectOp.getContext(); |
596 | StringRef dialectName = dialectOp.getName(); |
597 | |
598 | DynamicDialect *dialect = ctx->getOrLoadDynamicDialect( |
599 | dialectNamespace: dialectName, ctor: [](DynamicDialect *dialect) {}); |
600 | |
601 | dialects.insert({dialectOp, dialect}); |
602 | }); |
603 | return dialects; |
604 | } |
605 | |
606 | /// Preallocate type definitions objects with empty verifiers. |
607 | /// This in particular allocates a TypeID for each type definition. |
608 | static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> |
609 | preallocateTypeDefs(ModuleOp op, |
610 | DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
611 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs; |
612 | op.walk([&](TypeOp typeOp) { |
613 | ExtensibleDialect *dialect = dialects[typeOp.getParentOp()]; |
614 | auto typeDef = DynamicTypeDefinition::get( |
615 | typeOp.getName(), dialect, |
616 | [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
617 | return success(); |
618 | }); |
619 | typeDefs.try_emplace(typeOp, std::move(typeDef)); |
620 | }); |
621 | return typeDefs; |
622 | } |
623 | |
624 | /// Preallocate attribute definitions objects with empty verifiers. |
625 | /// This in particular allocates a TypeID for each attribute definition. |
626 | static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
627 | preallocateAttrDefs(ModuleOp op, |
628 | DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
629 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs; |
630 | op.walk([&](AttributeOp attrOp) { |
631 | ExtensibleDialect *dialect = dialects[attrOp.getParentOp()]; |
632 | auto attrDef = DynamicAttrDefinition::get( |
633 | attrOp.getName(), dialect, |
634 | [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
635 | return success(); |
636 | }); |
637 | attrDefs.try_emplace(attrOp, std::move(attrDef)); |
638 | }); |
639 | return attrDefs; |
640 | } |
641 | |
642 | LogicalResult mlir::irdl::loadDialects(ModuleOp op) { |
643 | // First, check that all any_of constraints are in a correct form. |
644 | // This is to ensure we can do the verification correctly. |
645 | WalkResult anyOfCorrects = op.walk( |
646 | [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); }); |
647 | if (anyOfCorrects.wasInterrupted()) |
648 | return op.emitError("any_of constraints are not in the correct form" ); |
649 | |
650 | // Preallocate all dialects, and type and attribute definitions. |
651 | // In particular, this allocates TypeIDs so type and attributes can have |
652 | // verifiers that refer to each other. |
653 | DenseMap<DialectOp, ExtensibleDialect *> dialects = loadEmptyDialects(op); |
654 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> types = |
655 | preallocateTypeDefs(op, dialects); |
656 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs = |
657 | preallocateAttrDefs(op, dialects); |
658 | |
659 | // Set the verifier for types. |
660 | WalkResult res = op.walk([&](TypeOp typeOp) { |
661 | DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
662 | typeOp, dialects[typeOp.getParentOp()], types, attrs); |
663 | if (!verifier) |
664 | return WalkResult::interrupt(); |
665 | types[typeOp]->setVerifyFn(std::move(verifier)); |
666 | return WalkResult::advance(); |
667 | }); |
668 | if (res.wasInterrupted()) |
669 | return failure(); |
670 | |
671 | // Set the verifier for attributes. |
672 | res = op.walk([&](AttributeOp attrOp) { |
673 | DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
674 | attrOp, dialects[attrOp.getParentOp()], types, attrs); |
675 | if (!verifier) |
676 | return WalkResult::interrupt(); |
677 | attrs[attrOp]->setVerifyFn(std::move(verifier)); |
678 | return WalkResult::advance(); |
679 | }); |
680 | if (res.wasInterrupted()) |
681 | return failure(); |
682 | |
683 | // Define and load all operations. |
684 | res = op.walk([&](OperationOp opOp) { |
685 | return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs); |
686 | }); |
687 | if (res.wasInterrupted()) |
688 | return failure(); |
689 | |
690 | // Load all types in their dialects. |
691 | for (auto &pair : types) { |
692 | ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
693 | dialect->registerDynamicType(type: std::move(pair.second)); |
694 | } |
695 | |
696 | // Load all attributes in their dialects. |
697 | for (auto &pair : attrs) { |
698 | ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
699 | dialect->registerDynamicAttr(attr: std::move(pair.second)); |
700 | } |
701 | |
702 | return success(); |
703 | } |
704 | |