| 1 | //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===// |
| 2 | // |
| 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Manages the loading of MLIR objects from IRDL operations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/IRDL/IRDLLoading.h" |
| 14 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
| 15 | #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" |
| 16 | #include "mlir/Dialect/IRDL/IRDLSymbols.h" |
| 17 | #include "mlir/Dialect/IRDL/IRDLVerifiers.h" |
| 18 | #include "mlir/IR/Attributes.h" |
| 19 | #include "mlir/IR/BuiltinOps.h" |
| 20 | #include "mlir/IR/ExtensibleDialect.h" |
| 21 | #include "mlir/IR/OperationSupport.h" |
| 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | #include "llvm/ADT/SmallPtrSet.h" |
| 24 | #include "llvm/Support/SMLoc.h" |
| 25 | #include <numeric> |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::irdl; |
| 29 | |
| 30 | /// Verify that the given list of parameters satisfy the given constraints. |
| 31 | /// This encodes the logic of the verification method for attributes and types |
| 32 | /// defined with IRDL. |
| 33 | static LogicalResult |
| 34 | irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError, |
| 35 | ArrayRef<Attribute> params, |
| 36 | ArrayRef<std::unique_ptr<Constraint>> constraints, |
| 37 | ArrayRef<size_t> paramConstraints) { |
| 38 | if (params.size() != paramConstraints.size()) { |
| 39 | emitError() << "expected " << paramConstraints.size() |
| 40 | << " type arguments, but had " << params.size(); |
| 41 | return failure(); |
| 42 | } |
| 43 | |
| 44 | ConstraintVerifier verifier(constraints); |
| 45 | |
| 46 | // Check that each parameter satisfies its constraint. |
| 47 | for (auto [i, param] : enumerate(First&: params)) |
| 48 | if (failed(Result: verifier.verify(emitError, attr: param, variable: paramConstraints[i]))) |
| 49 | return failure(); |
| 50 | |
| 51 | return success(); |
| 52 | } |
| 53 | |
| 54 | /// Get the operand segment sizes from the attribute dictionary. |
| 55 | LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, |
| 56 | StringRef attrName, unsigned numElements, |
| 57 | ArrayRef<Variadicity> variadicities, |
| 58 | SmallVectorImpl<int> &segmentSizes) { |
| 59 | // Get the segment sizes attribute, and check that it is of the right type. |
| 60 | Attribute segmentSizesAttr = op->getAttr(name: attrName); |
| 61 | if (!segmentSizesAttr) { |
| 62 | return op->emitError() << "'" << attrName |
| 63 | << "' attribute is expected but not provided" ; |
| 64 | } |
| 65 | |
| 66 | auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr); |
| 67 | if (!denseSegmentSizes) { |
| 68 | return op->emitError() << "'" << attrName |
| 69 | << "' attribute is expected to be a dense i32 array" ; |
| 70 | } |
| 71 | |
| 72 | if (denseSegmentSizes.size() != (int64_t)variadicities.size()) { |
| 73 | return op->emitError() << "'" << attrName << "' attribute for specifying " |
| 74 | << elemName << " segments must have " |
| 75 | << variadicities.size() << " elements, but got " |
| 76 | << denseSegmentSizes.size(); |
| 77 | } |
| 78 | |
| 79 | // Check that the segment sizes are corresponding to the given variadicities, |
| 80 | for (auto [i, segmentSize, variadicity] : |
| 81 | enumerate(denseSegmentSizes.asArrayRef(), variadicities)) { |
| 82 | if (segmentSize < 0) |
| 83 | return op->emitError() |
| 84 | << "'" << attrName << "' attribute for specifying " << elemName |
| 85 | << " segments must have non-negative values" ; |
| 86 | if (variadicity == Variadicity::single && segmentSize != 1) |
| 87 | return op->emitError() << "element " << i << " in '" << attrName |
| 88 | << "' attribute must be equal to 1" ; |
| 89 | |
| 90 | if (variadicity == Variadicity::optional && segmentSize > 1) |
| 91 | return op->emitError() << "element " << i << " in '" << attrName |
| 92 | << "' attribute must be equal to 0 or 1" ; |
| 93 | |
| 94 | segmentSizes.push_back(segmentSize); |
| 95 | } |
| 96 | |
| 97 | // Check that the sum of the segment sizes is equal to the number of elements. |
| 98 | int32_t sum = 0; |
| 99 | for (int32_t segmentSize : denseSegmentSizes.asArrayRef()) |
| 100 | sum += segmentSize; |
| 101 | if (sum != static_cast<int32_t>(numElements)) |
| 102 | return op->emitError() << "sum of elements in '" << attrName |
| 103 | << "' attribute must be equal to the number of " |
| 104 | << elemName << "s" ; |
| 105 | |
| 106 | return success(); |
| 107 | } |
| 108 | |
| 109 | /// Compute the segment sizes of the given element (operands, results). |
| 110 | /// If the operation has more than two non-single elements (optional or |
| 111 | /// variadic), then get the segment sizes from the attribute dictionary. |
| 112 | /// Otherwise, compute the segment sizes from the number of elements. |
| 113 | /// `elemName` should be either `"operand"` or `"result"`. |
| 114 | LogicalResult getSegmentSizes(Operation *op, StringRef elemName, |
| 115 | StringRef attrName, unsigned numElements, |
| 116 | ArrayRef<Variadicity> variadicities, |
| 117 | SmallVectorImpl<int> &segmentSizes) { |
| 118 | // If we have more than one non-single variadicity, we need to get the |
| 119 | // segment sizes from the attribute dictionary. |
| 120 | int numberNonSingle = count_if( |
| 121 | variadicities, [](Variadicity v) { return v != Variadicity::single; }); |
| 122 | if (numberNonSingle > 1) |
| 123 | return getSegmentSizesFromAttr(op, elemName, attrName, numElements, |
| 124 | variadicities, segmentSizes); |
| 125 | |
| 126 | // If we only have single variadicities, the segments sizes are all 1. |
| 127 | if (numberNonSingle == 0) { |
| 128 | if (numElements != variadicities.size()) { |
| 129 | return op->emitError() << "op expects exactly " << variadicities.size() |
| 130 | << " " << elemName << "s, but got " << numElements; |
| 131 | } |
| 132 | for (size_t i = 0, e = variadicities.size(); i < e; ++i) |
| 133 | segmentSizes.push_back(Elt: 1); |
| 134 | return success(); |
| 135 | } |
| 136 | |
| 137 | assert(numberNonSingle == 1); |
| 138 | |
| 139 | // There is exactly one non-single element, so we can |
| 140 | // compute its size and check that it is valid. |
| 141 | int nonSingleSegmentSize = static_cast<int>(numElements) - |
| 142 | static_cast<int>(variadicities.size()) + 1; |
| 143 | |
| 144 | if (nonSingleSegmentSize < 0) { |
| 145 | return op->emitError() << "op expects at least " << variadicities.size() - 1 |
| 146 | << " " << elemName << "s, but got " << numElements; |
| 147 | } |
| 148 | |
| 149 | // Add the segment sizes. |
| 150 | for (Variadicity variadicity : variadicities) { |
| 151 | if (variadicity == Variadicity::single) { |
| 152 | segmentSizes.push_back(1); |
| 153 | continue; |
| 154 | } |
| 155 | |
| 156 | // If we have an optional element, we should check that it represents |
| 157 | // zero or one elements. |
| 158 | if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional) |
| 159 | return op->emitError() << "op expects at most " << variadicities.size() |
| 160 | << " " << elemName << "s, but got " << numElements; |
| 161 | |
| 162 | segmentSizes.push_back(nonSingleSegmentSize); |
| 163 | } |
| 164 | |
| 165 | return success(); |
| 166 | } |
| 167 | |
| 168 | /// Compute the segment sizes of the given operands. |
| 169 | /// If the operation has more than two non-single operands (optional or |
| 170 | /// variadic), then get the segment sizes from the attribute dictionary. |
| 171 | /// Otherwise, compute the segment sizes from the number of operands. |
| 172 | LogicalResult getOperandSegmentSizes(Operation *op, |
| 173 | ArrayRef<Variadicity> variadicities, |
| 174 | SmallVectorImpl<int> &segmentSizes) { |
| 175 | return getSegmentSizes(op, "operand" , "operand_segment_sizes" , |
| 176 | op->getNumOperands(), variadicities, segmentSizes); |
| 177 | } |
| 178 | |
| 179 | /// Compute the segment sizes of the given results. |
| 180 | /// If the operation has more than two non-single results (optional or |
| 181 | /// variadic), then get the segment sizes from the attribute dictionary. |
| 182 | /// Otherwise, compute the segment sizes from the number of results. |
| 183 | LogicalResult getResultSegmentSizes(Operation *op, |
| 184 | ArrayRef<Variadicity> variadicities, |
| 185 | SmallVectorImpl<int> &segmentSizes) { |
| 186 | return getSegmentSizes(op, "result" , "result_segment_sizes" , |
| 187 | op->getNumResults(), variadicities, segmentSizes); |
| 188 | } |
| 189 | |
| 190 | /// Verify that the given operation satisfies the given constraints. |
| 191 | /// This encodes the logic of the verification method for operations defined |
| 192 | /// with IRDL. |
| 193 | static LogicalResult irdlOpVerifier( |
| 194 | Operation *op, ConstraintVerifier &verifier, |
| 195 | ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity, |
| 196 | ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity, |
| 197 | const DenseMap<StringAttr, size_t> &attributeConstrs) { |
| 198 | // Get the segment sizes for the operands. |
| 199 | // This will check that the number of operands is correct. |
| 200 | SmallVector<int> operandSegmentSizes; |
| 201 | if (failed( |
| 202 | getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes))) |
| 203 | return failure(); |
| 204 | |
| 205 | // Get the segment sizes for the results. |
| 206 | // This will check that the number of results is correct. |
| 207 | SmallVector<int> resultSegmentSizes; |
| 208 | if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes))) |
| 209 | return failure(); |
| 210 | |
| 211 | auto emitError = [op] { return op->emitError(); }; |
| 212 | |
| 213 | /// Сheck that we have all needed attributes passed |
| 214 | /// and they satisfy the constraints. |
| 215 | DictionaryAttr actualAttrs = op->getAttrDictionary(); |
| 216 | |
| 217 | for (auto [name, constraint] : attributeConstrs) { |
| 218 | /// First, check if the attribute actually passed. |
| 219 | std::optional<NamedAttribute> actual = actualAttrs.getNamed(name); |
| 220 | if (!actual.has_value()) |
| 221 | return op->emitOpError() |
| 222 | << "attribute " << name << " is expected but not provided" ; |
| 223 | |
| 224 | /// Then, check if the attribute value satisfies the constraint. |
| 225 | if (failed(verifier.verify(emitError: {emitError}, attr: actual->getValue(), variable: constraint))) |
| 226 | return failure(); |
| 227 | } |
| 228 | |
| 229 | // Check that all operands satisfy the constraints |
| 230 | int operandIdx = 0; |
| 231 | for (auto [defIndex, segmentSize] : enumerate(First&: operandSegmentSizes)) { |
| 232 | for (int i = 0; i < segmentSize; i++) { |
| 233 | if (failed(verifier.verify( |
| 234 | {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]), |
| 235 | operandConstrs[defIndex]))) |
| 236 | return failure(); |
| 237 | ++operandIdx; |
| 238 | } |
| 239 | } |
| 240 | |
| 241 | // Check that all results satisfy the constraints |
| 242 | int resultIdx = 0; |
| 243 | for (auto [defIndex, segmentSize] : enumerate(First&: resultSegmentSizes)) { |
| 244 | for (int i = 0; i < segmentSize; i++) { |
| 245 | if (failed(verifier.verify({emitError}, |
| 246 | TypeAttr::get(op->getResultTypes()[resultIdx]), |
| 247 | resultConstrs[defIndex]))) |
| 248 | return failure(); |
| 249 | ++resultIdx; |
| 250 | } |
| 251 | } |
| 252 | |
| 253 | return success(); |
| 254 | } |
| 255 | |
| 256 | static LogicalResult irdlRegionVerifier( |
| 257 | Operation *op, ConstraintVerifier &verifier, |
| 258 | ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) { |
| 259 | if (op->getNumRegions() != regionsConstraints.size()) { |
| 260 | return op->emitOpError() |
| 261 | << "unexpected number of regions: expected " |
| 262 | << regionsConstraints.size() << " but got " << op->getNumRegions(); |
| 263 | } |
| 264 | |
| 265 | for (auto [constraint, region] : |
| 266 | llvm::zip(t&: regionsConstraints, u: op->getRegions())) |
| 267 | if (failed(Result: constraint->verify(region, constraintContext&: verifier))) |
| 268 | return failure(); |
| 269 | |
| 270 | return success(); |
| 271 | } |
| 272 | |
| 273 | llvm::unique_function<LogicalResult(Operation *) const> |
| 274 | mlir::irdl::createVerifier( |
| 275 | OperationOp op, |
| 276 | const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
| 277 | const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
| 278 | &attrs) { |
| 279 | // Resolve SSA values to verifier constraint slots |
| 280 | SmallVector<Value> constrToValue; |
| 281 | SmallVector<Value> regionToValue; |
| 282 | for (Operation &op : op->getRegion(0).getOps()) { |
| 283 | if (isa<VerifyConstraintInterface>(op)) { |
| 284 | if (op.getNumResults() != 1) { |
| 285 | op.emitError() |
| 286 | << "IRDL constraint operations must have exactly one result" ; |
| 287 | return nullptr; |
| 288 | } |
| 289 | constrToValue.push_back(op.getResult(0)); |
| 290 | } |
| 291 | if (isa<VerifyRegionInterface>(op)) { |
| 292 | if (op.getNumResults() != 1) { |
| 293 | op.emitError() |
| 294 | << "IRDL constraint operations must have exactly one result" ; |
| 295 | return nullptr; |
| 296 | } |
| 297 | regionToValue.push_back(op.getResult(0)); |
| 298 | } |
| 299 | } |
| 300 | |
| 301 | // Build the verifiers for each constraint slot |
| 302 | SmallVector<std::unique_ptr<Constraint>> constraints; |
| 303 | for (Value v : constrToValue) { |
| 304 | VerifyConstraintInterface op = |
| 305 | cast<VerifyConstraintInterface>(v.getDefiningOp()); |
| 306 | std::unique_ptr<Constraint> verifier = |
| 307 | op.getVerifier(constrToValue, types, attrs); |
| 308 | if (!verifier) |
| 309 | return nullptr; |
| 310 | constraints.push_back(Elt: std::move(verifier)); |
| 311 | } |
| 312 | |
| 313 | // Build region constraints |
| 314 | SmallVector<std::unique_ptr<RegionConstraint>> regionConstraints; |
| 315 | for (Value v : regionToValue) { |
| 316 | VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp()); |
| 317 | std::unique_ptr<RegionConstraint> verifier = |
| 318 | op.getVerifier(constrToValue, types, attrs); |
| 319 | regionConstraints.push_back(Elt: std::move(verifier)); |
| 320 | } |
| 321 | |
| 322 | SmallVector<size_t> operandConstraints; |
| 323 | SmallVector<Variadicity> operandVariadicity; |
| 324 | |
| 325 | // Gather which constraint slots correspond to operand constraints |
| 326 | auto operandsOp = op.getOp<OperandsOp>(); |
| 327 | if (operandsOp.has_value()) { |
| 328 | operandConstraints.reserve(N: operandsOp->getArgs().size()); |
| 329 | for (Value operand : operandsOp->getArgs()) { |
| 330 | for (auto [i, constr] : enumerate(constrToValue)) { |
| 331 | if (constr == operand) { |
| 332 | operandConstraints.push_back(i); |
| 333 | break; |
| 334 | } |
| 335 | } |
| 336 | } |
| 337 | |
| 338 | // Gather the variadicities of each operand |
| 339 | for (VariadicityAttr attr : operandsOp->getVariadicity()) |
| 340 | operandVariadicity.push_back(attr.getValue()); |
| 341 | } |
| 342 | |
| 343 | SmallVector<size_t> resultConstraints; |
| 344 | SmallVector<Variadicity> resultVariadicity; |
| 345 | |
| 346 | // Gather which constraint slots correspond to result constraints |
| 347 | auto resultsOp = op.getOp<ResultsOp>(); |
| 348 | if (resultsOp.has_value()) { |
| 349 | resultConstraints.reserve(N: resultsOp->getArgs().size()); |
| 350 | for (Value result : resultsOp->getArgs()) { |
| 351 | for (auto [i, constr] : enumerate(constrToValue)) { |
| 352 | if (constr == result) { |
| 353 | resultConstraints.push_back(i); |
| 354 | break; |
| 355 | } |
| 356 | } |
| 357 | } |
| 358 | |
| 359 | // Gather the variadicities of each result |
| 360 | for (Attribute attr : resultsOp->getVariadicity()) |
| 361 | resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue()); |
| 362 | } |
| 363 | |
| 364 | // Gather which constraint slots correspond to attributes constraints |
| 365 | DenseMap<StringAttr, size_t> attributeConstraints; |
| 366 | auto attributesOp = op.getOp<AttributesOp>(); |
| 367 | if (attributesOp.has_value()) { |
| 368 | const Operation::operand_range values = attributesOp->getAttributeValues(); |
| 369 | const ArrayAttr names = attributesOp->getAttributeValueNames(); |
| 370 | |
| 371 | for (const auto &[name, value] : llvm::zip(names, values)) { |
| 372 | for (auto [i, constr] : enumerate(constrToValue)) { |
| 373 | if (constr == value) { |
| 374 | attributeConstraints[cast<StringAttr>(name)] = i; |
| 375 | break; |
| 376 | } |
| 377 | } |
| 378 | } |
| 379 | } |
| 380 | |
| 381 | return |
| 382 | [constraints{std::move(constraints)}, |
| 383 | regionConstraints{std::move(regionConstraints)}, |
| 384 | operandConstraints{std::move(operandConstraints)}, |
| 385 | operandVariadicity{std::move(operandVariadicity)}, |
| 386 | resultConstraints{std::move(resultConstraints)}, |
| 387 | resultVariadicity{std::move(resultVariadicity)}, |
| 388 | attributeConstraints{std::move(attributeConstraints)}](Operation *op) { |
| 389 | ConstraintVerifier verifier(constraints); |
| 390 | const LogicalResult opVerifierResult = irdlOpVerifier( |
| 391 | op, verifier, operandConstraints, operandVariadicity, |
| 392 | resultConstraints, resultVariadicity, attributeConstraints); |
| 393 | const LogicalResult opRegionVerifierResult = |
| 394 | irdlRegionVerifier(op, verifier, regionsConstraints: regionConstraints); |
| 395 | return LogicalResult::success(IsSuccess: opVerifierResult.succeeded() && |
| 396 | opRegionVerifierResult.succeeded()); |
| 397 | }; |
| 398 | } |
| 399 | |
| 400 | /// Define and load an operation represented by a `irdl.operation` |
| 401 | /// operation. |
| 402 | static WalkResult loadOperation( |
| 403 | OperationOp op, ExtensibleDialect *dialect, |
| 404 | const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
| 405 | const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
| 406 | &attrs) { |
| 407 | |
| 408 | // IRDL does not support defining custom parsers or printers. |
| 409 | auto parser = [](OpAsmParser &parser, OperationState &result) { |
| 410 | return failure(); |
| 411 | }; |
| 412 | auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) { |
| 413 | printer.printGenericOp(op); |
| 414 | }; |
| 415 | |
| 416 | auto verifier = createVerifier(op, types, attrs); |
| 417 | if (!verifier) |
| 418 | return WalkResult::interrupt(); |
| 419 | |
| 420 | // IRDL supports only checking number of blocks and argument constraints |
| 421 | // It is done in the main verifier to reuse `ConstraintVerifier` context |
| 422 | auto regionVerifier = [](Operation *op) { return LogicalResult::success(); }; |
| 423 | |
| 424 | auto opDef = DynamicOpDefinition::get( |
| 425 | op.getName(), dialect, std::move(verifier), std::move(regionVerifier), |
| 426 | std::move(parser), std::move(printer)); |
| 427 | dialect->registerDynamicOp(type: std::move(opDef)); |
| 428 | |
| 429 | return WalkResult::advance(); |
| 430 | } |
| 431 | |
| 432 | /// Get the verifier of a type or attribute definition. |
| 433 | /// Return nullptr if the definition is invalid. |
| 434 | static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier( |
| 435 | Operation *attrOrTypeDef, ExtensibleDialect *dialect, |
| 436 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, |
| 437 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { |
| 438 | assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) && |
| 439 | "Expected an attribute or type definition" ); |
| 440 | |
| 441 | // Resolve SSA values to verifier constraint slots |
| 442 | SmallVector<Value> constrToValue; |
| 443 | for (Operation &op : attrOrTypeDef->getRegion(index: 0).getOps()) { |
| 444 | if (isa<VerifyConstraintInterface>(op)) { |
| 445 | assert(op.getNumResults() == 1 && |
| 446 | "IRDL constraint operations must have exactly one result" ); |
| 447 | constrToValue.push_back(Elt: op.getResult(idx: 0)); |
| 448 | } |
| 449 | } |
| 450 | |
| 451 | // Build the verifiers for each constraint slot |
| 452 | SmallVector<std::unique_ptr<Constraint>> constraints; |
| 453 | for (Value v : constrToValue) { |
| 454 | VerifyConstraintInterface op = |
| 455 | cast<VerifyConstraintInterface>(v.getDefiningOp()); |
| 456 | std::unique_ptr<Constraint> verifier = |
| 457 | op.getVerifier(constrToValue, types, attrs); |
| 458 | if (!verifier) |
| 459 | return {}; |
| 460 | constraints.push_back(Elt: std::move(verifier)); |
| 461 | } |
| 462 | |
| 463 | // Get the parameter definitions. |
| 464 | std::optional<ParametersOp> params; |
| 465 | if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef)) |
| 466 | params = attr.getOp<ParametersOp>(); |
| 467 | else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef)) |
| 468 | params = type.getOp<ParametersOp>(); |
| 469 | |
| 470 | // Gather which constraint slots correspond to parameter constraints |
| 471 | SmallVector<size_t> paramConstraints; |
| 472 | if (params.has_value()) { |
| 473 | paramConstraints.reserve(N: params->getArgs().size()); |
| 474 | for (Value param : params->getArgs()) { |
| 475 | for (auto [i, constr] : enumerate(constrToValue)) { |
| 476 | if (constr == param) { |
| 477 | paramConstraints.push_back(i); |
| 478 | break; |
| 479 | } |
| 480 | } |
| 481 | } |
| 482 | } |
| 483 | |
| 484 | auto verifier = [paramConstraints{std::move(paramConstraints)}, |
| 485 | constraints{std::move(constraints)}]( |
| 486 | function_ref<InFlightDiagnostic()> emitError, |
| 487 | ArrayRef<Attribute> params) { |
| 488 | return irdlAttrOrTypeVerifier(emitError, params, constraints, |
| 489 | paramConstraints); |
| 490 | }; |
| 491 | |
| 492 | // While the `std::move` is not required, not adding it triggers a bug in |
| 493 | // clang-10. |
| 494 | return std::move(verifier); |
| 495 | } |
| 496 | |
| 497 | /// Get the possible bases of a constraint. Return `true` if all bases can |
| 498 | /// potentially be matched. |
| 499 | /// A base is a type or an attribute definition. For instance, the base of |
| 500 | /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`. |
| 501 | /// This function returns the following information through arguments: |
| 502 | /// - `paramIds`: the set of type or attribute IDs that are used as bases. |
| 503 | /// - `paramIrdlOps`: the set of IRDL operations that are used as bases. |
| 504 | /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is` |
| 505 | /// constraints. |
| 506 | static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> ¶mIds, |
| 507 | SmallPtrSet<Operation *, 4> ¶mIrdlOps, |
| 508 | SmallPtrSet<TypeID, 4> &isIds) { |
| 509 | // For `irdl.any_of`, we get the bases from all its arguments. |
| 510 | if (auto anyOf = dyn_cast<AnyOfOp>(op)) { |
| 511 | bool hasAny = false; |
| 512 | for (Value arg : anyOf.getArgs()) |
| 513 | hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds); |
| 514 | return hasAny; |
| 515 | } |
| 516 | |
| 517 | // For `irdl.all_of`, we get the bases from the first argument. |
| 518 | // This is restrictive, but we can relax it later if needed. |
| 519 | if (auto allOf = dyn_cast<AllOfOp>(op)) |
| 520 | return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps, |
| 521 | isIds); |
| 522 | |
| 523 | // For `irdl.parametric`, we get directly the base from the operation. |
| 524 | if (auto params = dyn_cast<ParametricOp>(op)) { |
| 525 | SymbolRefAttr symRef = params.getBaseType(); |
| 526 | Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef); |
| 527 | assert(defOp && "symbol reference should refer to an existing operation" ); |
| 528 | paramIrdlOps.insert(Ptr: defOp); |
| 529 | return false; |
| 530 | } |
| 531 | |
| 532 | // For `irdl.is`, we get the base TypeID directly. |
| 533 | if (auto is = dyn_cast<IsOp>(op)) { |
| 534 | Attribute expected = is.getExpected(); |
| 535 | isIds.insert(Ptr: expected.getTypeID()); |
| 536 | return false; |
| 537 | } |
| 538 | |
| 539 | // For `irdl.any`, we return `false` since we can match any type or attribute |
| 540 | // base. |
| 541 | if (auto isA = dyn_cast<AnyOp>(op)) |
| 542 | return true; |
| 543 | |
| 544 | llvm_unreachable("unknown IRDL constraint" ); |
| 545 | } |
| 546 | |
| 547 | /// Check that an any_of is in the subset IRDL can handle. |
| 548 | /// IRDL uses a greedy algorithm to match constraints. This means that if we |
| 549 | /// encounter an `any_of` with multiple constraints, we will match the first |
| 550 | /// constraint that is satisfied. Thus, the order of constraints matter in |
| 551 | /// `any_of` with our current algorithm. |
| 552 | /// In order to make the order of constraints irrelevant, we require that |
| 553 | /// all `any_of` constraint parameters are disjoint. For this, we check that |
| 554 | /// the base parameters are all disjoints between `parametric` operations, and |
| 555 | /// that they are disjoint between `parametric` and `is` operations. |
| 556 | /// This restriction will be relaxed in the future, when we will change our |
| 557 | /// algorithm to be non-greedy. |
| 558 | static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) { |
| 559 | SmallPtrSet<TypeID, 4> paramIds; |
| 560 | SmallPtrSet<Operation *, 4> paramIrdlOps; |
| 561 | SmallPtrSet<TypeID, 4> isIds; |
| 562 | |
| 563 | for (Value arg : anyOf.getArgs()) { |
| 564 | Operation *argOp = arg.getDefiningOp(); |
| 565 | SmallPtrSet<TypeID, 4> argParamIds; |
| 566 | SmallPtrSet<Operation *, 4> argParamIrdlOps; |
| 567 | SmallPtrSet<TypeID, 4> argIsIds; |
| 568 | |
| 569 | // Get the bases of this argument. If it can match any type or attribute, |
| 570 | // then our `any_of` should not be allowed. |
| 571 | if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds)) |
| 572 | return failure(); |
| 573 | |
| 574 | // We check that the base parameters are all disjoints between `parametric` |
| 575 | // operations, and that they are disjoint between `parametric` and `is` |
| 576 | // operations. |
| 577 | for (TypeID id : argParamIds) { |
| 578 | if (isIds.count(id)) |
| 579 | return failure(); |
| 580 | bool inserted = paramIds.insert(id).second; |
| 581 | if (!inserted) |
| 582 | return failure(); |
| 583 | } |
| 584 | |
| 585 | // We check that the base parameters are all disjoints with `irdl.is` |
| 586 | // operations. |
| 587 | for (TypeID id : isIds) { |
| 588 | if (paramIds.count(id)) |
| 589 | return failure(); |
| 590 | isIds.insert(id); |
| 591 | } |
| 592 | |
| 593 | // We check that all `parametric` operations are disjoint. We do not |
| 594 | // need to check that they are disjoint with `is` operations, since |
| 595 | // `is` operations cannot refer to attributes defined with `irdl.parametric` |
| 596 | // operations. |
| 597 | for (Operation *op : argParamIrdlOps) { |
| 598 | bool inserted = paramIrdlOps.insert(op).second; |
| 599 | if (!inserted) |
| 600 | return failure(); |
| 601 | } |
| 602 | } |
| 603 | |
| 604 | return success(); |
| 605 | } |
| 606 | |
| 607 | /// Load all dialects in the given module, without loading any operation, type |
| 608 | /// or attribute definitions. |
| 609 | static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) { |
| 610 | DenseMap<DialectOp, ExtensibleDialect *> dialects; |
| 611 | op.walk([&](DialectOp dialectOp) { |
| 612 | MLIRContext *ctx = dialectOp.getContext(); |
| 613 | StringRef dialectName = dialectOp.getName(); |
| 614 | |
| 615 | DynamicDialect *dialect = ctx->getOrLoadDynamicDialect( |
| 616 | dialectNamespace: dialectName, ctor: [](DynamicDialect *dialect) {}); |
| 617 | |
| 618 | dialects.insert({dialectOp, dialect}); |
| 619 | }); |
| 620 | return dialects; |
| 621 | } |
| 622 | |
| 623 | /// Preallocate type definitions objects with empty verifiers. |
| 624 | /// This in particular allocates a TypeID for each type definition. |
| 625 | static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> |
| 626 | preallocateTypeDefs(ModuleOp op, |
| 627 | DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
| 628 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs; |
| 629 | op.walk([&](TypeOp typeOp) { |
| 630 | ExtensibleDialect *dialect = dialects[typeOp.getParentOp()]; |
| 631 | auto typeDef = DynamicTypeDefinition::get( |
| 632 | typeOp.getName(), dialect, |
| 633 | [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
| 634 | return success(); |
| 635 | }); |
| 636 | typeDefs.try_emplace(typeOp, std::move(typeDef)); |
| 637 | }); |
| 638 | return typeDefs; |
| 639 | } |
| 640 | |
| 641 | /// Preallocate attribute definitions objects with empty verifiers. |
| 642 | /// This in particular allocates a TypeID for each attribute definition. |
| 643 | static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> |
| 644 | preallocateAttrDefs(ModuleOp op, |
| 645 | DenseMap<DialectOp, ExtensibleDialect *> dialects) { |
| 646 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs; |
| 647 | op.walk([&](AttributeOp attrOp) { |
| 648 | ExtensibleDialect *dialect = dialects[attrOp.getParentOp()]; |
| 649 | auto attrDef = DynamicAttrDefinition::get( |
| 650 | attrOp.getName(), dialect, |
| 651 | [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) { |
| 652 | return success(); |
| 653 | }); |
| 654 | attrDefs.try_emplace(attrOp, std::move(attrDef)); |
| 655 | }); |
| 656 | return attrDefs; |
| 657 | } |
| 658 | |
| 659 | LogicalResult mlir::irdl::loadDialects(ModuleOp op) { |
| 660 | // First, check that all any_of constraints are in a correct form. |
| 661 | // This is to ensure we can do the verification correctly. |
| 662 | WalkResult anyOfCorrects = op.walk( |
| 663 | [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); }); |
| 664 | if (anyOfCorrects.wasInterrupted()) |
| 665 | return op.emitError("any_of constraints are not in the correct form" ); |
| 666 | |
| 667 | // Preallocate all dialects, and type and attribute definitions. |
| 668 | // In particular, this allocates TypeIDs so type and attributes can have |
| 669 | // verifiers that refer to each other. |
| 670 | DenseMap<DialectOp, ExtensibleDialect *> dialects = loadEmptyDialects(op); |
| 671 | DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> types = |
| 672 | preallocateTypeDefs(op, dialects); |
| 673 | DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs = |
| 674 | preallocateAttrDefs(op, dialects); |
| 675 | |
| 676 | // Set the verifier for types. |
| 677 | WalkResult res = op.walk([&](TypeOp typeOp) { |
| 678 | DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
| 679 | typeOp, dialects[typeOp.getParentOp()], types, attrs); |
| 680 | if (!verifier) |
| 681 | return WalkResult::interrupt(); |
| 682 | types[typeOp]->setVerifyFn(std::move(verifier)); |
| 683 | return WalkResult::advance(); |
| 684 | }); |
| 685 | if (res.wasInterrupted()) |
| 686 | return failure(); |
| 687 | |
| 688 | // Set the verifier for attributes. |
| 689 | res = op.walk([&](AttributeOp attrOp) { |
| 690 | DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier( |
| 691 | attrOp, dialects[attrOp.getParentOp()], types, attrs); |
| 692 | if (!verifier) |
| 693 | return WalkResult::interrupt(); |
| 694 | attrs[attrOp]->setVerifyFn(std::move(verifier)); |
| 695 | return WalkResult::advance(); |
| 696 | }); |
| 697 | if (res.wasInterrupted()) |
| 698 | return failure(); |
| 699 | |
| 700 | // Define and load all operations. |
| 701 | res = op.walk([&](OperationOp opOp) { |
| 702 | return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs); |
| 703 | }); |
| 704 | if (res.wasInterrupted()) |
| 705 | return failure(); |
| 706 | |
| 707 | // Load all types in their dialects. |
| 708 | for (auto &pair : types) { |
| 709 | ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
| 710 | dialect->registerDynamicType(type: std::move(pair.second)); |
| 711 | } |
| 712 | |
| 713 | // Load all attributes in their dialects. |
| 714 | for (auto &pair : attrs) { |
| 715 | ExtensibleDialect *dialect = dialects[pair.first.getParentOp()]; |
| 716 | dialect->registerDynamicAttr(attr: std::move(pair.second)); |
| 717 | } |
| 718 | |
| 719 | return success(); |
| 720 | } |
| 721 | |