1 | //===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Verifiers for objects declared by IRDL. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/IRDL/IRDLVerifiers.h" |
14 | #include "mlir/IR/Attributes.h" |
15 | #include "mlir/IR/Block.h" |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/Diagnostics.h" |
18 | #include "mlir/IR/ExtensibleDialect.h" |
19 | #include "mlir/IR/Location.h" |
20 | #include "mlir/IR/Region.h" |
21 | #include "mlir/IR/Value.h" |
22 | #include "mlir/Support/LogicalResult.h" |
23 | #include "llvm/Support/FormatVariadic.h" |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::irdl; |
27 | |
28 | ConstraintVerifier::ConstraintVerifier( |
29 | ArrayRef<std::unique_ptr<Constraint>> constraints) |
30 | : constraints(constraints), assigned() { |
31 | assigned.resize(N: this->constraints.size()); |
32 | } |
33 | |
34 | LogicalResult |
35 | ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError, |
36 | Attribute attr, unsigned variable) { |
37 | |
38 | assert(variable < constraints.size() && "invalid constraint variable" ); |
39 | |
40 | // If the variable is already assigned, check that the attribute is the same. |
41 | if (assigned[variable].has_value()) { |
42 | if (attr == assigned[variable].value()) { |
43 | return success(); |
44 | } |
45 | if (emitError) |
46 | return emitError() << "expected '" << assigned[variable].value() |
47 | << "' but got '" << attr << "'" ; |
48 | return failure(); |
49 | } |
50 | |
51 | // Otherwise, check the constraint and assign the attribute to the variable. |
52 | LogicalResult result = constraints[variable]->verify(emitError, attr, context&: *this); |
53 | if (succeeded(result)) |
54 | assigned[variable] = attr; |
55 | |
56 | return result; |
57 | } |
58 | |
59 | LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError, |
60 | Attribute attr, |
61 | ConstraintVerifier &context) const { |
62 | if (attr == expectedAttribute) |
63 | return success(); |
64 | |
65 | if (emitError) |
66 | return emitError() << "expected '" << expectedAttribute << "' but got '" |
67 | << attr << "'" ; |
68 | return failure(); |
69 | } |
70 | |
71 | LogicalResult |
72 | BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError, |
73 | Attribute attr, ConstraintVerifier &context) const { |
74 | if (attr.getTypeID() == baseTypeID) |
75 | return success(); |
76 | |
77 | if (emitError) |
78 | return emitError() << "expected base attribute '" << baseName |
79 | << "' but got '" << attr.getAbstractAttribute().getName() |
80 | << "'" ; |
81 | return failure(); |
82 | } |
83 | |
84 | LogicalResult |
85 | BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError, |
86 | Attribute attr, ConstraintVerifier &context) const { |
87 | auto typeAttr = dyn_cast<TypeAttr>(attr); |
88 | if (!typeAttr) { |
89 | if (emitError) |
90 | return emitError() << "expected type, got attribute '" << attr; |
91 | return failure(); |
92 | } |
93 | |
94 | Type type = typeAttr.getValue(); |
95 | if (type.getTypeID() == baseTypeID) |
96 | return success(); |
97 | |
98 | if (emitError) |
99 | return emitError() << "expected base type '" << baseName << "' but got '" |
100 | << type.getAbstractType().getName() << "'" ; |
101 | return failure(); |
102 | } |
103 | |
104 | LogicalResult DynParametricAttrConstraint::verify( |
105 | function_ref<InFlightDiagnostic()> emitError, Attribute attr, |
106 | ConstraintVerifier &context) const { |
107 | |
108 | // Check that the base is the expected one. |
109 | auto dynAttr = dyn_cast<DynamicAttr>(Val&: attr); |
110 | if (!dynAttr || dynAttr.getAttrDef() != attrDef) { |
111 | if (emitError) { |
112 | StringRef dialectName = attrDef->getDialect()->getNamespace(); |
113 | StringRef attrName = attrDef->getName(); |
114 | return emitError() << "expected base attribute '" << attrName << '.' |
115 | << dialectName << "' but got '" << attr << "'" ; |
116 | } |
117 | return failure(); |
118 | } |
119 | |
120 | // Check that the parameters satisfy the constraints. |
121 | ArrayRef<Attribute> params = dynAttr.getParams(); |
122 | if (params.size() != constraints.size()) { |
123 | if (emitError) { |
124 | StringRef dialectName = attrDef->getDialect()->getNamespace(); |
125 | StringRef attrName = attrDef->getName(); |
126 | emitError() << "attribute '" << dialectName << "." << attrName |
127 | << "' expects " << params.size() << " parameters but got " |
128 | << constraints.size(); |
129 | } |
130 | return failure(); |
131 | } |
132 | |
133 | for (size_t i = 0, s = params.size(); i < s; i++) |
134 | if (failed(result: context.verify(emitError, attr: params[i], variable: constraints[i]))) |
135 | return failure(); |
136 | |
137 | return success(); |
138 | } |
139 | |
140 | LogicalResult DynParametricTypeConstraint::verify( |
141 | function_ref<InFlightDiagnostic()> emitError, Attribute attr, |
142 | ConstraintVerifier &context) const { |
143 | // Check that the base is a TypeAttr. |
144 | auto typeAttr = dyn_cast<TypeAttr>(attr); |
145 | if (!typeAttr) { |
146 | if (emitError) |
147 | return emitError() << "expected type, got attribute '" << attr; |
148 | return failure(); |
149 | } |
150 | |
151 | // Check that the type base is the expected one. |
152 | auto dynType = dyn_cast<DynamicType>(typeAttr.getValue()); |
153 | if (!dynType || dynType.getTypeDef() != typeDef) { |
154 | if (emitError) { |
155 | StringRef dialectName = typeDef->getDialect()->getNamespace(); |
156 | StringRef attrName = typeDef->getName(); |
157 | return emitError() << "expected base type '" << dialectName << '.' |
158 | << attrName << "' but got '" << attr << "'" ; |
159 | } |
160 | return failure(); |
161 | } |
162 | |
163 | // Check that the parameters satisfy the constraints. |
164 | ArrayRef<Attribute> params = dynType.getParams(); |
165 | if (params.size() != constraints.size()) { |
166 | if (emitError) { |
167 | StringRef dialectName = typeDef->getDialect()->getNamespace(); |
168 | StringRef attrName = typeDef->getName(); |
169 | emitError() << "attribute '" << dialectName << "." << attrName |
170 | << "' expects " << params.size() << " parameters but got " |
171 | << constraints.size(); |
172 | } |
173 | return failure(); |
174 | } |
175 | |
176 | for (size_t i = 0, s = params.size(); i < s; i++) |
177 | if (failed(result: context.verify(emitError, attr: params[i], variable: constraints[i]))) |
178 | return failure(); |
179 | |
180 | return success(); |
181 | } |
182 | |
183 | LogicalResult |
184 | AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError, |
185 | Attribute attr, ConstraintVerifier &context) const { |
186 | for (unsigned constr : constraints) { |
187 | // We do not pass the `emitError` here, since we want to emit an error |
188 | // only if none of the constraints are satisfied. |
189 | if (succeeded(result: context.verify(emitError: {}, attr, variable: constr))) { |
190 | return success(); |
191 | } |
192 | } |
193 | |
194 | if (emitError) |
195 | return emitError() << "'" << attr << "' does not satisfy the constraint" ; |
196 | return failure(); |
197 | } |
198 | |
199 | LogicalResult |
200 | AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError, |
201 | Attribute attr, ConstraintVerifier &context) const { |
202 | for (unsigned constr : constraints) { |
203 | if (failed(result: context.verify(emitError, attr, variable: constr))) { |
204 | return failure(); |
205 | } |
206 | } |
207 | |
208 | return success(); |
209 | } |
210 | |
211 | LogicalResult |
212 | AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError, |
213 | Attribute attr, |
214 | ConstraintVerifier &context) const { |
215 | return success(); |
216 | } |
217 | |
218 | LogicalResult RegionConstraint::verify(mlir::Region ®ion, |
219 | ConstraintVerifier &constraintContext) { |
220 | const auto emitError = [parentOp = region.getParentOp()](mlir::Location loc) { |
221 | return [loc, parentOp] { |
222 | InFlightDiagnostic diag = mlir::emitError(loc); |
223 | // If we already have been given location of the parent operation, which |
224 | // might happen when the region location is passed, we do not want to |
225 | // produce the note on the same location |
226 | if (loc != parentOp->getLoc()) |
227 | diag.attachNote(noteLoc: parentOp->getLoc()).append(arg: "see the operation" ); |
228 | return diag; |
229 | }; |
230 | }; |
231 | |
232 | if (blockCount.has_value() && *blockCount != region.getBlocks().size()) { |
233 | return emitError(region.getLoc())() |
234 | << "expected region " << region.getRegionNumber() << " to have " |
235 | << *blockCount << " block(s) but got " << region.getBlocks().size(); |
236 | } |
237 | |
238 | if (argumentConstraints.has_value()) { |
239 | auto actualArgs = region.getArguments(); |
240 | if (actualArgs.size() != argumentConstraints->size()) { |
241 | const mlir::Location firstArgLoc = |
242 | actualArgs.empty() ? region.getLoc() : actualArgs.front().getLoc(); |
243 | return emitError(firstArgLoc)() |
244 | << "expected region " << region.getRegionNumber() << " to have " |
245 | << argumentConstraints->size() << " arguments but got " |
246 | << actualArgs.size(); |
247 | } |
248 | |
249 | for (auto [arg, constraint] : llvm::zip(t&: actualArgs, u&: *argumentConstraints)) { |
250 | mlir::Attribute type = TypeAttr::get(arg.getType()); |
251 | if (failed(result: constraintContext.verify(emitError: emitError(arg.getLoc()), attr: type, |
252 | variable: constraint))) { |
253 | return failure(); |
254 | } |
255 | } |
256 | } |
257 | return success(); |
258 | } |
259 | |