1//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/IRDL/IR/IRDL.h"
10#include "mlir/Dialect/IRDL/IRDLSymbols.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/ExtensibleDialect.h"
16#include "mlir/IR/OpDefinition.h"
17#include "mlir/IR/OpImplementation.h"
18#include "mlir/IR/Operation.h"
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SetOperations.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/Casting.h"
25
26using namespace mlir;
27using namespace mlir::irdl;
28
29//===----------------------------------------------------------------------===//
30// IRDL dialect.
31//===----------------------------------------------------------------------===//
32
33#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
34
35#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
36
/// Registers this dialect's operations, types and attributes. Each list is
/// expanded from a generated `.inc` file, selecting the appropriate section
/// via the GET_OP_LIST / GET_TYPEDEF_LIST / GET_ATTRDEF_LIST macros.
void IRDLDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
      >();
}
51
52//===----------------------------------------------------------------------===//
53// Parsing/Printing/Verifying
54//===----------------------------------------------------------------------===//
55
56/// Parse a region, and add a single block if the region is empty.
57/// If no region is parsed, create a new region with a single empty block.
58static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
59 auto regionParseRes = p.parseOptionalRegion(region);
60 if (regionParseRes.has_value() && failed(Result: regionParseRes.value()))
61 return failure();
62
63 // If the region is empty, add a single empty block.
64 if (region.empty())
65 region.push_back(block: new Block());
66
67 return success();
68}
69
70static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op,
71 Region &region) {
72 if (!region.getBlocks().front().empty())
73 p.printRegion(blocks&: region);
74}
75static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc,
76 const Twine &label) {
77 if (in.empty())
78 return loc->emitError(message: "name of ") << label << " is empty";
79
80 bool allowUnderscore = false;
81 for (auto &elem : in) {
82 if (elem == '_') {
83 if (!allowUnderscore)
84 return loc->emitError(message: "name of ")
85 << label << " should not contain leading or double underscores";
86 } else {
87 if (!isalnum(elem))
88 return loc->emitError(message: "name of ")
89 << label
90 << " must contain only lowercase letters, digits and "
91 "underscores";
92
93 if (llvm::isUpper(C: elem))
94 return loc->emitError(message: "name of ")
95 << label << " should not contain uppercase letters";
96 }
97
98 allowUnderscore = elem != '_';
99 }
100
101 return success();
102}
103
104LogicalResult DialectOp::verify() {
105 if (!Dialect::isValidNamespace(str: getName()))
106 return emitOpError(message: "invalid dialect name");
107 if (failed(Result: isValidName(in: getSymName(), loc: getOperation(), label: "dialect")))
108 return failure();
109
110 return success();
111}
112
113LogicalResult OperationOp::verify() {
114 return isValidName(in: getSymName(), loc: getOperation(), label: "operation");
115}
116
117LogicalResult TypeOp::verify() {
118 auto symName = getSymName();
119 if (symName.front() == '!')
120 symName = symName.substr(Start: 1);
121 return isValidName(in: symName, loc: getOperation(), label: "type");
122}
123
124LogicalResult AttributeOp::verify() {
125 auto symName = getSymName();
126 if (symName.front() == '#')
127 symName = symName.substr(Start: 1);
128 return isValidName(in: symName, loc: getOperation(), label: "attribute");
129}
130
131LogicalResult OperationOp::verifyRegions() {
132 // Stores pairs of value kinds and the list of names of values of this kind in
133 // the operation.
134 SmallVector<std::tuple<StringRef, llvm::SmallDenseSet<StringRef>>> valueNames;
135
136 auto insertNames = [&](StringRef kind, ArrayAttr names) {
137 llvm::SmallDenseSet<StringRef> nameSet;
138 nameSet.reserve(Size: names.size());
139 for (auto name : names)
140 nameSet.insert(V: llvm::cast<StringAttr>(Val&: name).getValue());
141 valueNames.emplace_back(Args&: kind, Args: std::move(nameSet));
142 };
143
144 for (Operation &op : getBody().getOps()) {
145 TypeSwitch<Operation *>(&op)
146 .Case<OperandsOp>(
147 caseFn: [&](OperandsOp op) { insertNames("operands", op.getNames()); })
148 .Case<ResultsOp>(
149 caseFn: [&](ResultsOp op) { insertNames("results", op.getNames()); })
150 .Case<RegionsOp>(
151 caseFn: [&](RegionsOp op) { insertNames("regions", op.getNames()); });
152 }
153
154 // Verify that no two operand, result or region share the same name.
155 // The absence of duplicates within each value kind is checked by the
156 // associated operation's verifier.
157 for (size_t i : llvm::seq(Size: valueNames.size())) {
158 for (size_t j : llvm::seq(Begin: i + 1, End: valueNames.size())) {
159 auto [lhs, lhsSet] = valueNames[i];
160 auto &[rhs, rhsSet] = valueNames[j];
161 llvm::set_intersect(S1&: lhsSet, S2: rhsSet);
162 if (!lhsSet.empty())
163 return emitOpError(message: "contains a value named '")
164 << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
165 }
166 }
167
168 return success();
169}
170
171static LogicalResult verifyNames(Operation *op, StringRef kindName,
172 ArrayAttr names, size_t numOperands) {
173 if (numOperands != names.size())
174 return op->emitOpError()
175 << "the number of " << kindName
176 << "s and their names must be "
177 "the same, but got "
178 << numOperands << " and " << names.size() << " respectively";
179
180 DenseMap<StringRef, size_t> nameMap;
181 for (auto [i, name] : llvm::enumerate(First&: names)) {
182 StringRef nameRef = llvm::cast<StringAttr>(Val: name).getValue();
183
184 if (failed(Result: isValidName(in: nameRef, loc: op, label: Twine(kindName) + " #" + Twine(i))))
185 return failure();
186
187 if (nameMap.contains(Val: nameRef))
188 return op->emitOpError() << "name of " << kindName << " #" << i
189 << " is a duplicate of the name of " << kindName
190 << " #" << nameMap[nameRef];
191 nameMap.insert(KV: {nameRef, i});
192 }
193
194 return success();
195}
196
197LogicalResult ParametersOp::verify() {
198 return verifyNames(op: *this, kindName: "parameter", names: getNames(), numOperands: getNumOperands());
199}
200
201template <typename ValueListOp>
202static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
203 StringRef kindName) {
204 size_t numVariadicities = op.getVariadicity().size();
205 size_t numOperands = op.getNumOperands();
206
207 if (numOperands != numVariadicities)
208 return op.emitOpError()
209 << "the number of " << kindName
210 << "s and their variadicities must be "
211 "the same, but got "
212 << numOperands << " and " << numVariadicities << " respectively";
213
214 return verifyNames(op, kindName, op.getNames(), numOperands);
215}
216
217LogicalResult OperandsOp::verify() {
218 return verifyOperandsResultsCommon(op: *this, kindName: "operand");
219}
220
221LogicalResult ResultsOp::verify() {
222 return verifyOperandsResultsCommon(op: *this, kindName: "result");
223}
224
225LogicalResult AttributesOp::verify() {
226 size_t namesSize = getAttributeValueNames().size();
227 size_t valuesSize = getAttributeValues().size();
228
229 if (namesSize != valuesSize)
230 return emitOpError()
231 << "the number of attribute names and their constraints must be "
232 "the same but got "
233 << namesSize << " and " << valuesSize << " respectively";
234
235 return success();
236}
237
238LogicalResult BaseOp::verify() {
239 std::optional<StringRef> baseName = getBaseName();
240 std::optional<SymbolRefAttr> baseRef = getBaseRef();
241 if (baseName.has_value() == baseRef.has_value())
242 return emitOpError() << "the base type or attribute should be specified by "
243 "either a name or a reference";
244
245 if (baseName &&
246 (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
247 return emitOpError() << "the base type or attribute name should start with "
248 "'!' or '#'";
249
250 return success();
251}
252
253/// Finds whether the provided symbol is an IRDL type or attribute definition.
254/// The source operation must be within a DialectOp.
255static LogicalResult
256checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
257 Operation *source, SymbolRefAttr symbol) {
258 Operation *targetOp =
259 irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
260
261 if (!targetOp)
262 return source->emitOpError() << "symbol '" << symbol << "' not found";
263
264 if (!isa<TypeOp, AttributeOp>(Val: targetOp))
265 return source->emitOpError() << "symbol '" << symbol
266 << "' does not refer to a type or attribute "
267 "definition (refers to '"
268 << targetOp->getName() << "')";
269
270 return success();
271}
272
273LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274 std::optional<SymbolRefAttr> baseRef = getBaseRef();
275 if (!baseRef)
276 return success();
277
278 return checkSymbolIsTypeOrAttribute(symbolTable, source: *this, symbol: *baseRef);
279}
280
281LogicalResult
282ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
283 std::optional<SymbolRefAttr> baseRef = getBaseType();
284 if (!baseRef)
285 return success();
286
287 return checkSymbolIsTypeOrAttribute(symbolTable, source: *this, symbol: *baseRef);
288}
289
290/// Parse a value with its variadicity first. By default, the variadicity is
291/// single.
292///
293/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
294static ParseResult
295parseValueWithVariadicity(OpAsmParser &p,
296 OpAsmParser::UnresolvedOperand &operand,
297 VariadicityAttr &variadicityAttr) {
298 MLIRContext *ctx = p.getBuilder().getContext();
299
300 // Parse the variadicity, if present
301 if (p.parseOptionalKeyword(keyword: "single").succeeded()) {
302 variadicityAttr = VariadicityAttr::get(context: ctx, value: Variadicity::single);
303 } else if (p.parseOptionalKeyword(keyword: "optional").succeeded()) {
304 variadicityAttr = VariadicityAttr::get(context: ctx, value: Variadicity::optional);
305 } else if (p.parseOptionalKeyword(keyword: "variadic").succeeded()) {
306 variadicityAttr = VariadicityAttr::get(context: ctx, value: Variadicity::variadic);
307 } else {
308 variadicityAttr = VariadicityAttr::get(context: ctx, value: Variadicity::single);
309 }
310
311 // Parse the value
312 if (p.parseOperand(result&: operand))
313 return failure();
314 return success();
315}
316
317static ParseResult parseNamedValueListImpl(
318 OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
319 ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
320 Builder &builder = p.getBuilder();
321 MLIRContext *ctx = builder.getContext();
322 SmallVector<Attribute> valueNames;
323 SmallVector<VariadicityAttr> variadicities;
324
325 // Parse a single value with its variadicity
326 auto parseOne = [&] {
327 StringRef name;
328 OpAsmParser::UnresolvedOperand operand;
329 VariadicityAttr variadicity;
330 if (p.parseKeyword(keyword: &name) || p.parseColon())
331 return failure();
332
333 if (variadicityAttr) {
334 if (parseValueWithVariadicity(p, operand, variadicityAttr&: variadicity))
335 return failure();
336 variadicities.push_back(Elt: variadicity);
337 } else {
338 if (p.parseOperand(result&: operand))
339 return failure();
340 }
341
342 valueNames.push_back(Elt: StringAttr::get(context: ctx, bytes: name));
343 operands.push_back(Elt: operand);
344 return success();
345 };
346
347 if (p.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: parseOne))
348 return failure();
349 valueNamesAttr = ArrayAttr::get(context: ctx, value: valueNames);
350 if (variadicityAttr)
351 *variadicityAttr = VariadicityArrayAttr::get(context: ctx, value: variadicities);
352 return success();
353}
354
355/// Parse a list of named values.
356///
357/// values ::=
358/// `(` (named-value (`,` named-value)*)? `)`
359/// named-value := bare-id `:` ssa-value
360static ParseResult
361parseNamedValueList(OpAsmParser &p,
362 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
363 ArrayAttr &valueNamesAttr) {
364 return parseNamedValueListImpl(p, operands, valueNamesAttr, variadicityAttr: nullptr);
365}
366
367/// Parse a list of named values with their variadicities first. By default, the
368/// variadicity is single.
369///
370/// values-with-variadicity ::=
371/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
372/// value-with-variadicity
373/// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
374static ParseResult parseNamedValueListWithVariadicity(
375 OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
376 ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
377 return parseNamedValueListImpl(p, operands, valueNamesAttr, variadicityAttr: &variadicityAttr);
378}
379
380static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op,
381 OperandRange operands,
382 ArrayAttr valueNamesAttr,
383 VariadicityArrayAttr variadicityAttr) {
384 p << "(";
385 interleaveComma(c: llvm::seq<int>(Begin: 0, End: operands.size()), os&: p, each_fn: [&](int i) {
386 p << llvm::cast<StringAttr>(Val: valueNamesAttr[i]).getValue() << ": ";
387 if (variadicityAttr) {
388 Variadicity variadicity = variadicityAttr[i].getValue();
389 if (variadicity != Variadicity::single) {
390 p << stringifyVariadicity(variadicity) << " ";
391 }
392 }
393 p << operands[i];
394 });
395 p << ")";
396}
397
398/// Print a list of named values.
399///
400/// values ::=
401/// `(` (named-value (`,` named-value)*)? `)`
402/// named-value := bare-id `:` ssa-value
403static void printNamedValueList(OpAsmPrinter &p, Operation *op,
404 OperandRange operands,
405 ArrayAttr valueNamesAttr) {
406 printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr: nullptr);
407}
408
409/// Print a list of named values with their variadicities first. By default, the
410/// variadicity is single.
411///
412/// values-with-variadicity ::=
413/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
414/// value-with-variadicity ::=
415/// bare-id `:` ("single" | "optional" | "variadic")? ssa-value
416static void printNamedValueListWithVariadicity(
417 OpAsmPrinter &p, Operation *op, OperandRange operands,
418 ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
419 printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
420}
421
422static ParseResult
423parseAttributesOp(OpAsmParser &p,
424 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
425 ArrayAttr &attrNamesAttr) {
426 Builder &builder = p.getBuilder();
427 SmallVector<Attribute> attrNames;
428 if (succeeded(Result: p.parseOptionalLBrace())) {
429 auto parseOperands = [&]() {
430 if (p.parseAttribute(result&: attrNames.emplace_back()) || p.parseEqual() ||
431 p.parseOperand(result&: attrOperands.emplace_back()))
432 return failure();
433 return success();
434 };
435 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
436 return failure();
437 }
438 attrNamesAttr = builder.getArrayAttr(value: attrNames);
439 return success();
440}
441
442static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
443 OperandRange attrArgs, ArrayAttr attrNames) {
444 if (attrNames.empty())
445 return;
446 p << "{";
447 interleaveComma(c: llvm::seq<int>(Begin: 0, End: attrNames.size()), os&: p,
448 each_fn: [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
449 p << '}';
450}
451
452LogicalResult RegionOp::verify() {
453 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
454 if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
455 return emitOpError(message: "the number of blocks is expected to be >= 1 but got ")
456 << number;
457 }
458 return success();
459}
460
461LogicalResult RegionsOp::verify() {
462 return verifyNames(op: *this, kindName: "region", names: getNames(), numOperands: getNumOperands());
463}
464
465#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
466
467#define GET_TYPEDEF_CLASSES
468#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
469
470#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
471
472#define GET_ATTRDEF_CLASSES
473#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
474
475#define GET_OP_CLASSES
476#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
477

// Source file: mlir/lib/Dialect/IRDL/IR/IRDL.cpp