//===- IRDLVerifiers.cpp - IRDL verifiers -----------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Verifiers for objects declared by IRDL.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
14#include "mlir/IR/Attributes.h"
15#include "mlir/IR/Block.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/Diagnostics.h"
18#include "mlir/IR/ExtensibleDialect.h"
19#include "mlir/IR/Location.h"
20#include "mlir/IR/Region.h"
21#include "mlir/IR/Value.h"
22#include "mlir/Support/LogicalResult.h"
23#include "llvm/Support/FormatVariadic.h"
24
25using namespace mlir;
26using namespace mlir::irdl;
27
28ConstraintVerifier::ConstraintVerifier(
29 ArrayRef<std::unique_ptr<Constraint>> constraints)
30 : constraints(constraints), assigned() {
31 assigned.resize(N: this->constraints.size());
32}
33
34LogicalResult
35ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
36 Attribute attr, unsigned variable) {
37
38 assert(variable < constraints.size() && "invalid constraint variable");
39
40 // If the variable is already assigned, check that the attribute is the same.
41 if (assigned[variable].has_value()) {
42 if (attr == assigned[variable].value()) {
43 return success();
44 }
45 if (emitError)
46 return emitError() << "expected '" << assigned[variable].value()
47 << "' but got '" << attr << "'";
48 return failure();
49 }
50
51 // Otherwise, check the constraint and assign the attribute to the variable.
52 LogicalResult result = constraints[variable]->verify(emitError, attr, context&: *this);
53 if (succeeded(result))
54 assigned[variable] = attr;
55
56 return result;
57}
58
59LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
60 Attribute attr,
61 ConstraintVerifier &context) const {
62 if (attr == expectedAttribute)
63 return success();
64
65 if (emitError)
66 return emitError() << "expected '" << expectedAttribute << "' but got '"
67 << attr << "'";
68 return failure();
69}
70
71LogicalResult
72BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
73 Attribute attr, ConstraintVerifier &context) const {
74 if (attr.getTypeID() == baseTypeID)
75 return success();
76
77 if (emitError)
78 return emitError() << "expected base attribute '" << baseName
79 << "' but got '" << attr.getAbstractAttribute().getName()
80 << "'";
81 return failure();
82}
83
84LogicalResult
85BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
86 Attribute attr, ConstraintVerifier &context) const {
87 auto typeAttr = dyn_cast<TypeAttr>(attr);
88 if (!typeAttr) {
89 if (emitError)
90 return emitError() << "expected type, got attribute '" << attr;
91 return failure();
92 }
93
94 Type type = typeAttr.getValue();
95 if (type.getTypeID() == baseTypeID)
96 return success();
97
98 if (emitError)
99 return emitError() << "expected base type '" << baseName << "' but got '"
100 << type.getAbstractType().getName() << "'";
101 return failure();
102}
103
104LogicalResult DynParametricAttrConstraint::verify(
105 function_ref<InFlightDiagnostic()> emitError, Attribute attr,
106 ConstraintVerifier &context) const {
107
108 // Check that the base is the expected one.
109 auto dynAttr = dyn_cast<DynamicAttr>(Val&: attr);
110 if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
111 if (emitError) {
112 StringRef dialectName = attrDef->getDialect()->getNamespace();
113 StringRef attrName = attrDef->getName();
114 return emitError() << "expected base attribute '" << attrName << '.'
115 << dialectName << "' but got '" << attr << "'";
116 }
117 return failure();
118 }
119
120 // Check that the parameters satisfy the constraints.
121 ArrayRef<Attribute> params = dynAttr.getParams();
122 if (params.size() != constraints.size()) {
123 if (emitError) {
124 StringRef dialectName = attrDef->getDialect()->getNamespace();
125 StringRef attrName = attrDef->getName();
126 emitError() << "attribute '" << dialectName << "." << attrName
127 << "' expects " << params.size() << " parameters but got "
128 << constraints.size();
129 }
130 return failure();
131 }
132
133 for (size_t i = 0, s = params.size(); i < s; i++)
134 if (failed(result: context.verify(emitError, attr: params[i], variable: constraints[i])))
135 return failure();
136
137 return success();
138}
139
140LogicalResult DynParametricTypeConstraint::verify(
141 function_ref<InFlightDiagnostic()> emitError, Attribute attr,
142 ConstraintVerifier &context) const {
143 // Check that the base is a TypeAttr.
144 auto typeAttr = dyn_cast<TypeAttr>(attr);
145 if (!typeAttr) {
146 if (emitError)
147 return emitError() << "expected type, got attribute '" << attr;
148 return failure();
149 }
150
151 // Check that the type base is the expected one.
152 auto dynType = dyn_cast<DynamicType>(typeAttr.getValue());
153 if (!dynType || dynType.getTypeDef() != typeDef) {
154 if (emitError) {
155 StringRef dialectName = typeDef->getDialect()->getNamespace();
156 StringRef attrName = typeDef->getName();
157 return emitError() << "expected base type '" << dialectName << '.'
158 << attrName << "' but got '" << attr << "'";
159 }
160 return failure();
161 }
162
163 // Check that the parameters satisfy the constraints.
164 ArrayRef<Attribute> params = dynType.getParams();
165 if (params.size() != constraints.size()) {
166 if (emitError) {
167 StringRef dialectName = typeDef->getDialect()->getNamespace();
168 StringRef attrName = typeDef->getName();
169 emitError() << "attribute '" << dialectName << "." << attrName
170 << "' expects " << params.size() << " parameters but got "
171 << constraints.size();
172 }
173 return failure();
174 }
175
176 for (size_t i = 0, s = params.size(); i < s; i++)
177 if (failed(result: context.verify(emitError, attr: params[i], variable: constraints[i])))
178 return failure();
179
180 return success();
181}
182
183LogicalResult
184AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
185 Attribute attr, ConstraintVerifier &context) const {
186 for (unsigned constr : constraints) {
187 // We do not pass the `emitError` here, since we want to emit an error
188 // only if none of the constraints are satisfied.
189 if (succeeded(result: context.verify(emitError: {}, attr, variable: constr))) {
190 return success();
191 }
192 }
193
194 if (emitError)
195 return emitError() << "'" << attr << "' does not satisfy the constraint";
196 return failure();
197}
198
199LogicalResult
200AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
201 Attribute attr, ConstraintVerifier &context) const {
202 for (unsigned constr : constraints) {
203 if (failed(result: context.verify(emitError, attr, variable: constr))) {
204 return failure();
205 }
206 }
207
208 return success();
209}
210
211LogicalResult
212AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
213 Attribute attr,
214 ConstraintVerifier &context) const {
215 return success();
216}
217
218LogicalResult RegionConstraint::verify(mlir::Region &region,
219 ConstraintVerifier &constraintContext) {
220 const auto emitError = [parentOp = region.getParentOp()](mlir::Location loc) {
221 return [loc, parentOp] {
222 InFlightDiagnostic diag = mlir::emitError(loc);
223 // If we already have been given location of the parent operation, which
224 // might happen when the region location is passed, we do not want to
225 // produce the note on the same location
226 if (loc != parentOp->getLoc())
227 diag.attachNote(noteLoc: parentOp->getLoc()).append(arg: "see the operation");
228 return diag;
229 };
230 };
231
232 if (blockCount.has_value() && *blockCount != region.getBlocks().size()) {
233 return emitError(region.getLoc())()
234 << "expected region " << region.getRegionNumber() << " to have "
235 << *blockCount << " block(s) but got " << region.getBlocks().size();
236 }
237
238 if (argumentConstraints.has_value()) {
239 auto actualArgs = region.getArguments();
240 if (actualArgs.size() != argumentConstraints->size()) {
241 const mlir::Location firstArgLoc =
242 actualArgs.empty() ? region.getLoc() : actualArgs.front().getLoc();
243 return emitError(firstArgLoc)()
244 << "expected region " << region.getRegionNumber() << " to have "
245 << argumentConstraints->size() << " arguments but got "
246 << actualArgs.size();
247 }
248
249 for (auto [arg, constraint] : llvm::zip(t&: actualArgs, u&: *argumentConstraints)) {
250 mlir::Attribute type = TypeAttr::get(arg.getType());
251 if (failed(result: constraintContext.verify(emitError: emitError(arg.getLoc()), attr: type,
252 variable: constraint))) {
253 return failure();
254 }
255 }
256 }
257 return success();
258}
259

source code of mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp