1 | //===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===// |
---|---|
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
10 | #include "mlir/Dialect/IRDL/IRDLSymbols.h" |
11 | #include "mlir/IR/Builders.h" |
12 | #include "mlir/IR/BuiltinAttributes.h" |
13 | #include "mlir/IR/Diagnostics.h" |
14 | #include "mlir/IR/DialectImplementation.h" |
15 | #include "mlir/IR/ExtensibleDialect.h" |
16 | #include "mlir/IR/OpDefinition.h" |
17 | #include "mlir/IR/OpImplementation.h" |
18 | #include "mlir/IR/Operation.h" |
19 | #include "mlir/Support/LLVM.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/ADT/SetOperations.h" |
22 | #include "llvm/ADT/SmallString.h" |
23 | #include "llvm/ADT/StringExtras.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include "llvm/IR/Metadata.h" |
26 | #include "llvm/Support/Casting.h" |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::irdl; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // IRDL dialect. |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc" |
36 | |
37 | #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc" |
38 | |
/// Registers the operations, types and attributes of the IRDL dialect. Each
/// class list is supplied by a generated `.cpp.inc` file, expanded here through
/// its GET_*_LIST macro.
void IRDLDialect::initialize() {
  // Operation classes listed in IRDLOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
      >();
  // Type classes listed in IRDLTypesGen.cpp.inc.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
      >();
  // Attribute classes listed in IRDLAttributes.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
      >();
}
53 | |
54 | //===----------------------------------------------------------------------===// |
55 | // Parsing/Printing/Verifying |
56 | //===----------------------------------------------------------------------===// |
57 | |
58 | /// Parse a region, and add a single block if the region is empty. |
59 | /// If no region is parsed, create a new region with a single empty block. |
60 | static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region ®ion) { |
61 | auto regionParseRes = p.parseOptionalRegion(region); |
62 | if (regionParseRes.has_value() && failed(Result: regionParseRes.value())) |
63 | return failure(); |
64 | |
65 | // If the region is empty, add a single empty block. |
66 | if (region.empty()) |
67 | region.push_back(block: new Block()); |
68 | |
69 | return success(); |
70 | } |
71 | |
72 | static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, |
73 | Region ®ion) { |
74 | if (!region.getBlocks().front().empty()) |
75 | p.printRegion(blocks&: region); |
76 | } |
77 | static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, |
78 | const Twine &label) { |
79 | if (in.empty()) |
80 | return loc->emitError(message: "name of ") << label << " is empty"; |
81 | |
82 | bool allowUnderscore = false; |
83 | for (auto &elem : in) { |
84 | if (elem == '_') { |
85 | if (!allowUnderscore) |
86 | return loc->emitError(message: "name of ") |
87 | << label << " should not contain leading or double underscores"; |
88 | } else { |
89 | if (!isalnum(elem)) |
90 | return loc->emitError(message: "name of ") |
91 | << label |
92 | << " must contain only lowercase letters, digits and " |
93 | "underscores"; |
94 | |
95 | if (llvm::isUpper(C: elem)) |
96 | return loc->emitError(message: "name of ") |
97 | << label << " should not contain uppercase letters"; |
98 | } |
99 | |
100 | allowUnderscore = elem != '_'; |
101 | } |
102 | |
103 | return success(); |
104 | } |
105 | |
106 | LogicalResult DialectOp::verify() { |
107 | if (!Dialect::isValidNamespace(getName())) |
108 | return emitOpError("invalid dialect name"); |
109 | if (failed(isValidName(getSymName(), getOperation(), "dialect"))) |
110 | return failure(); |
111 | |
112 | return success(); |
113 | } |
114 | |
115 | LogicalResult OperationOp::verify() { |
116 | return isValidName(getSymName(), getOperation(), "operation"); |
117 | } |
118 | |
119 | LogicalResult TypeOp::verify() { |
120 | auto symName = getSymName(); |
121 | if (symName.front() == '!') |
122 | symName = symName.substr(1); |
123 | return isValidName(symName, getOperation(), "type"); |
124 | } |
125 | |
126 | LogicalResult AttributeOp::verify() { |
127 | auto symName = getSymName(); |
128 | if (symName.front() == '#') |
129 | symName = symName.substr(1); |
130 | return isValidName(symName, getOperation(), "attribute"); |
131 | } |
132 | |
133 | LogicalResult OperationOp::verifyRegions() { |
134 | // Stores pairs of value kinds and the list of names of values of this kind in |
135 | // the operation. |
136 | SmallVector<std::tuple<StringRef, llvm::SmallDenseSet<StringRef>>> valueNames; |
137 | |
138 | auto insertNames = [&](StringRef kind, ArrayAttr names) { |
139 | llvm::SmallDenseSet<StringRef> nameSet; |
140 | nameSet.reserve(names.size()); |
141 | for (auto name : names) |
142 | nameSet.insert(llvm::cast<StringAttr>(name).getValue()); |
143 | valueNames.emplace_back(kind, std::move(nameSet)); |
144 | }; |
145 | |
146 | for (Operation &op : getBody().getOps()) { |
147 | TypeSwitch<Operation *>(&op) |
148 | .Case<OperandsOp>( |
149 | [&](OperandsOp op) { insertNames("operands", op.getNames()); }) |
150 | .Case<ResultsOp>( |
151 | [&](ResultsOp op) { insertNames("results", op.getNames()); }) |
152 | .Case<RegionsOp>( |
153 | [&](RegionsOp op) { insertNames("regions", op.getNames()); }); |
154 | } |
155 | |
156 | // Verify that no two operand, result or region share the same name. |
157 | // The absence of duplicates within each value kind is checked by the |
158 | // associated operation's verifier. |
159 | for (size_t i : llvm::seq(valueNames.size())) { |
160 | for (size_t j : llvm::seq(i + 1, valueNames.size())) { |
161 | auto [lhs, lhsSet] = valueNames[i]; |
162 | auto &[rhs, rhsSet] = valueNames[j]; |
163 | llvm::set_intersect(lhsSet, rhsSet); |
164 | if (!lhsSet.empty()) |
165 | return emitOpError("contains a value named '") |
166 | << *lhsSet.begin() << "' for both its "<< lhs << " and "<< rhs; |
167 | } |
168 | } |
169 | |
170 | return success(); |
171 | } |
172 | |
173 | static LogicalResult verifyNames(Operation *op, StringRef kindName, |
174 | ArrayAttr names, size_t numOperands) { |
175 | if (numOperands != names.size()) |
176 | return op->emitOpError() |
177 | << "the number of "<< kindName |
178 | << "s and their names must be " |
179 | "the same, but got " |
180 | << numOperands << " and "<< names.size() << " respectively"; |
181 | |
182 | DenseMap<StringRef, size_t> nameMap; |
183 | for (auto [i, name] : llvm::enumerate(names)) { |
184 | StringRef nameRef = llvm::cast<StringAttr>(name).getValue(); |
185 | |
186 | if (failed(isValidName(nameRef, op, Twine(kindName) + " #"+ Twine(i)))) |
187 | return failure(); |
188 | |
189 | if (nameMap.contains(nameRef)) |
190 | return op->emitOpError() << "name of "<< kindName << " #"<< i |
191 | << " is a duplicate of the name of "<< kindName |
192 | << " #"<< nameMap[nameRef]; |
193 | nameMap.insert({nameRef, i}); |
194 | } |
195 | |
196 | return success(); |
197 | } |
198 | |
199 | LogicalResult ParametersOp::verify() { |
200 | return verifyNames(*this, "parameter", getNames(), getNumOperands()); |
201 | } |
202 | |
203 | template <typename ValueListOp> |
204 | static LogicalResult verifyOperandsResultsCommon(ValueListOp op, |
205 | StringRef kindName) { |
206 | size_t numVariadicities = op.getVariadicity().size(); |
207 | size_t numOperands = op.getNumOperands(); |
208 | |
209 | if (numOperands != numVariadicities) |
210 | return op.emitOpError() |
211 | << "the number of "<< kindName |
212 | << "s and their variadicities must be " |
213 | "the same, but got " |
214 | << numOperands << " and "<< numVariadicities << " respectively"; |
215 | |
216 | return verifyNames(op, kindName, op.getNames(), numOperands); |
217 | } |
218 | |
219 | LogicalResult OperandsOp::verify() { |
220 | return verifyOperandsResultsCommon(*this, "operand"); |
221 | } |
222 | |
223 | LogicalResult ResultsOp::verify() { |
224 | return verifyOperandsResultsCommon(*this, "result"); |
225 | } |
226 | |
227 | LogicalResult AttributesOp::verify() { |
228 | size_t namesSize = getAttributeValueNames().size(); |
229 | size_t valuesSize = getAttributeValues().size(); |
230 | |
231 | if (namesSize != valuesSize) |
232 | return emitOpError() |
233 | << "the number of attribute names and their constraints must be " |
234 | "the same but got " |
235 | << namesSize << " and "<< valuesSize << " respectively"; |
236 | |
237 | return success(); |
238 | } |
239 | |
240 | LogicalResult BaseOp::verify() { |
241 | std::optional<StringRef> baseName = getBaseName(); |
242 | std::optional<SymbolRefAttr> baseRef = getBaseRef(); |
243 | if (baseName.has_value() == baseRef.has_value()) |
244 | return emitOpError() << "the base type or attribute should be specified by " |
245 | "either a name or a reference"; |
246 | |
247 | if (baseName && |
248 | (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#'))) |
249 | return emitOpError() << "the base type or attribute name should start with " |
250 | "'!' or '#'"; |
251 | |
252 | return success(); |
253 | } |
254 | |
255 | /// Finds whether the provided symbol is an IRDL type or attribute definition. |
256 | /// The source operation must be within a DialectOp. |
257 | static LogicalResult |
258 | checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, |
259 | Operation *source, SymbolRefAttr symbol) { |
260 | Operation *targetOp = |
261 | irdl::lookupSymbolNearDialect(symbolTable, source, symbol); |
262 | |
263 | if (!targetOp) |
264 | return source->emitOpError() << "symbol '"<< symbol << "' not found"; |
265 | |
266 | if (!isa<TypeOp, AttributeOp>(Val: targetOp)) |
267 | return source->emitOpError() << "symbol '"<< symbol |
268 | << "' does not refer to a type or attribute " |
269 | "definition (refers to '" |
270 | << targetOp->getName() << "')"; |
271 | |
272 | return success(); |
273 | } |
274 | |
275 | LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
276 | std::optional<SymbolRefAttr> baseRef = getBaseRef(); |
277 | if (!baseRef) |
278 | return success(); |
279 | |
280 | return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef); |
281 | } |
282 | |
283 | LogicalResult |
284 | ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
285 | std::optional<SymbolRefAttr> baseRef = getBaseType(); |
286 | if (!baseRef) |
287 | return success(); |
288 | |
289 | return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef); |
290 | } |
291 | |
292 | /// Parse a value with its variadicity first. By default, the variadicity is |
293 | /// single. |
294 | /// |
295 | /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value |
296 | static ParseResult |
297 | parseValueWithVariadicity(OpAsmParser &p, |
298 | OpAsmParser::UnresolvedOperand &operand, |
299 | VariadicityAttr &variadicityAttr) { |
300 | MLIRContext *ctx = p.getBuilder().getContext(); |
301 | |
302 | // Parse the variadicity, if present |
303 | if (p.parseOptionalKeyword(keyword: "single").succeeded()) { |
304 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single); |
305 | } else if (p.parseOptionalKeyword(keyword: "optional").succeeded()) { |
306 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional); |
307 | } else if (p.parseOptionalKeyword(keyword: "variadic").succeeded()) { |
308 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic); |
309 | } else { |
310 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single); |
311 | } |
312 | |
313 | // Parse the value |
314 | if (p.parseOperand(result&: operand)) |
315 | return failure(); |
316 | return success(); |
317 | } |
318 | |
319 | static ParseResult parseNamedValueListImpl( |
320 | OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, |
321 | ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) { |
322 | Builder &builder = p.getBuilder(); |
323 | MLIRContext *ctx = builder.getContext(); |
324 | SmallVector<Attribute> valueNames; |
325 | SmallVector<VariadicityAttr> variadicities; |
326 | |
327 | // Parse a single value with its variadicity |
328 | auto parseOne = [&] { |
329 | StringRef name; |
330 | OpAsmParser::UnresolvedOperand operand; |
331 | VariadicityAttr variadicity; |
332 | if (p.parseKeyword(keyword: &name) || p.parseColon()) |
333 | return failure(); |
334 | |
335 | if (variadicityAttr) { |
336 | if (parseValueWithVariadicity(p, operand, variadicity)) |
337 | return failure(); |
338 | variadicities.push_back(variadicity); |
339 | } else { |
340 | if (p.parseOperand(result&: operand)) |
341 | return failure(); |
342 | } |
343 | |
344 | valueNames.push_back(StringAttr::get(ctx, name)); |
345 | operands.push_back(Elt: operand); |
346 | return success(); |
347 | }; |
348 | |
349 | if (p.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: parseOne)) |
350 | return failure(); |
351 | valueNamesAttr = ArrayAttr::get(ctx, valueNames); |
352 | if (variadicityAttr) |
353 | *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities); |
354 | return success(); |
355 | } |
356 | |
357 | /// Parse a list of named values. |
358 | /// |
359 | /// values ::= |
360 | /// `(` (named-value (`,` named-value)*)? `)` |
361 | /// named-value := bare-id `:` ssa-value |
362 | static ParseResult |
363 | parseNamedValueList(OpAsmParser &p, |
364 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, |
365 | ArrayAttr &valueNamesAttr) { |
366 | return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr); |
367 | } |
368 | |
369 | /// Parse a list of named values with their variadicities first. By default, the |
370 | /// variadicity is single. |
371 | /// |
372 | /// values-with-variadicity ::= |
373 | /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)` |
374 | /// value-with-variadicity |
375 | /// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value |
376 | static ParseResult parseNamedValueListWithVariadicity( |
377 | OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, |
378 | ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) { |
379 | return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr); |
380 | } |
381 | |
382 | static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, |
383 | OperandRange operands, |
384 | ArrayAttr valueNamesAttr, |
385 | VariadicityArrayAttr variadicityAttr) { |
386 | p << "("; |
387 | interleaveComma(c: llvm::seq<int>(Begin: 0, End: operands.size()), os&: p, each_fn: [&](int i) { |
388 | p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": "; |
389 | if (variadicityAttr) { |
390 | Variadicity variadicity = variadicityAttr[i].getValue(); |
391 | if (variadicity != Variadicity::single) { |
392 | p << stringifyVariadicity(variadicity) << " "; |
393 | } |
394 | } |
395 | p << operands[i]; |
396 | }); |
397 | p << ")"; |
398 | } |
399 | |
400 | /// Print a list of named values. |
401 | /// |
402 | /// values ::= |
403 | /// `(` (named-value (`,` named-value)*)? `)` |
404 | /// named-value := bare-id `:` ssa-value |
405 | static void printNamedValueList(OpAsmPrinter &p, Operation *op, |
406 | OperandRange operands, |
407 | ArrayAttr valueNamesAttr) { |
408 | printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr); |
409 | } |
410 | |
411 | /// Print a list of named values with their variadicities first. By default, the |
412 | /// variadicity is single. |
413 | /// |
414 | /// values-with-variadicity ::= |
415 | /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)` |
416 | /// value-with-variadicity ::= |
417 | /// bare-id `:` ("single" | "optional" | "variadic")? ssa-value |
418 | static void printNamedValueListWithVariadicity( |
419 | OpAsmPrinter &p, Operation *op, OperandRange operands, |
420 | ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) { |
421 | printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr); |
422 | } |
423 | |
424 | static ParseResult |
425 | parseAttributesOp(OpAsmParser &p, |
426 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, |
427 | ArrayAttr &attrNamesAttr) { |
428 | Builder &builder = p.getBuilder(); |
429 | SmallVector<Attribute> attrNames; |
430 | if (succeeded(Result: p.parseOptionalLBrace())) { |
431 | auto parseOperands = [&]() { |
432 | if (p.parseAttribute(result&: attrNames.emplace_back()) || p.parseEqual() || |
433 | p.parseOperand(result&: attrOperands.emplace_back())) |
434 | return failure(); |
435 | return success(); |
436 | }; |
437 | if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace()) |
438 | return failure(); |
439 | } |
440 | attrNamesAttr = builder.getArrayAttr(attrNames); |
441 | return success(); |
442 | } |
443 | |
444 | static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, |
445 | OperandRange attrArgs, ArrayAttr attrNames) { |
446 | if (attrNames.empty()) |
447 | return; |
448 | p << "{"; |
449 | interleaveComma(llvm::seq<int>(0, attrNames.size()), p, |
450 | [&](int i) { p << attrNames[i] << " = "<< attrArgs[i]; }); |
451 | p << '}'; |
452 | } |
453 | |
454 | LogicalResult RegionOp::verify() { |
455 | if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr()) |
456 | if (int64_t number = numberOfBlocks.getInt(); number <= 0) { |
457 | return emitOpError("the number of blocks is expected to be >= 1 but got ") |
458 | << number; |
459 | } |
460 | return success(); |
461 | } |
462 | |
463 | LogicalResult RegionsOp::verify() { |
464 | return verifyNames(*this, "region", getNames(), getNumOperands()); |
465 | } |
466 | |
467 | #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc" |
468 | |
469 | #define GET_TYPEDEF_CLASSES |
470 | #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc" |
471 | |
472 | #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc" |
473 | |
474 | #define GET_ATTRDEF_CLASSES |
475 | #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc" |
476 | |
477 | #define GET_OP_CLASSES |
478 | #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc" |
479 |
Definitions
- parseSingleBlockRegion
- printSingleBlockRegion
- isValidName
- verifyNames
- verifyOperandsResultsCommon
- checkSymbolIsTypeOrAttribute
- parseValueWithVariadicity
- parseNamedValueListImpl
- parseNamedValueList
- parseNamedValueListWithVariadicity
- printNamedValueListImpl
- printNamedValueList
- printNamedValueListWithVariadicity
- parseAttributesOp
Improve your Profiling and Debugging skills
Find out more