1//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/IRDL/IR/IRDL.h"
10#include "mlir/Dialect/IRDL/IRDLSymbols.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/ExtensibleDialect.h"
16#include "mlir/IR/OpDefinition.h"
17#include "mlir/IR/OpImplementation.h"
18#include "mlir/IR/Operation.h"
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SetOperations.h"
22#include "llvm/ADT/SmallString.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/IR/Metadata.h"
26#include "llvm/Support/Casting.h"
27
28using namespace mlir;
29using namespace mlir::irdl;
30
31//===----------------------------------------------------------------------===//
32// IRDL dialect.
33//===----------------------------------------------------------------------===//
34
35#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
36
37#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
38
/// Registers all IRDL operations, types and attributes with the dialect.
/// The class lists are expanded from the included generated `.cpp.inc` files,
/// selected through the GET_*_LIST macros defined just before each include.
void IRDLDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
      >();
}
53
54//===----------------------------------------------------------------------===//
55// Parsing/Printing/Verifying
56//===----------------------------------------------------------------------===//
57
58/// Parse a region, and add a single block if the region is empty.
59/// If no region is parsed, create a new region with a single empty block.
60static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
61 auto regionParseRes = p.parseOptionalRegion(region);
62 if (regionParseRes.has_value() && failed(Result: regionParseRes.value()))
63 return failure();
64
65 // If the region is empty, add a single empty block.
66 if (region.empty())
67 region.push_back(block: new Block());
68
69 return success();
70}
71
72static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op,
73 Region &region) {
74 if (!region.getBlocks().front().empty())
75 p.printRegion(blocks&: region);
76}
77static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc,
78 const Twine &label) {
79 if (in.empty())
80 return loc->emitError(message: "name of ") << label << " is empty";
81
82 bool allowUnderscore = false;
83 for (auto &elem : in) {
84 if (elem == '_') {
85 if (!allowUnderscore)
86 return loc->emitError(message: "name of ")
87 << label << " should not contain leading or double underscores";
88 } else {
89 if (!isalnum(elem))
90 return loc->emitError(message: "name of ")
91 << label
92 << " must contain only lowercase letters, digits and "
93 "underscores";
94
95 if (llvm::isUpper(C: elem))
96 return loc->emitError(message: "name of ")
97 << label << " should not contain uppercase letters";
98 }
99
100 allowUnderscore = elem != '_';
101 }
102
103 return success();
104}
105
106LogicalResult DialectOp::verify() {
107 if (!Dialect::isValidNamespace(getName()))
108 return emitOpError("invalid dialect name");
109 if (failed(isValidName(getSymName(), getOperation(), "dialect")))
110 return failure();
111
112 return success();
113}
114
115LogicalResult OperationOp::verify() {
116 return isValidName(getSymName(), getOperation(), "operation");
117}
118
119LogicalResult TypeOp::verify() {
120 auto symName = getSymName();
121 if (symName.front() == '!')
122 symName = symName.substr(1);
123 return isValidName(symName, getOperation(), "type");
124}
125
126LogicalResult AttributeOp::verify() {
127 auto symName = getSymName();
128 if (symName.front() == '#')
129 symName = symName.substr(1);
130 return isValidName(symName, getOperation(), "attribute");
131}
132
133LogicalResult OperationOp::verifyRegions() {
134 // Stores pairs of value kinds and the list of names of values of this kind in
135 // the operation.
136 SmallVector<std::tuple<StringRef, llvm::SmallDenseSet<StringRef>>> valueNames;
137
138 auto insertNames = [&](StringRef kind, ArrayAttr names) {
139 llvm::SmallDenseSet<StringRef> nameSet;
140 nameSet.reserve(names.size());
141 for (auto name : names)
142 nameSet.insert(llvm::cast<StringAttr>(name).getValue());
143 valueNames.emplace_back(kind, std::move(nameSet));
144 };
145
146 for (Operation &op : getBody().getOps()) {
147 TypeSwitch<Operation *>(&op)
148 .Case<OperandsOp>(
149 [&](OperandsOp op) { insertNames("operands", op.getNames()); })
150 .Case<ResultsOp>(
151 [&](ResultsOp op) { insertNames("results", op.getNames()); })
152 .Case<RegionsOp>(
153 [&](RegionsOp op) { insertNames("regions", op.getNames()); });
154 }
155
156 // Verify that no two operand, result or region share the same name.
157 // The absence of duplicates within each value kind is checked by the
158 // associated operation's verifier.
159 for (size_t i : llvm::seq(valueNames.size())) {
160 for (size_t j : llvm::seq(i + 1, valueNames.size())) {
161 auto [lhs, lhsSet] = valueNames[i];
162 auto &[rhs, rhsSet] = valueNames[j];
163 llvm::set_intersect(lhsSet, rhsSet);
164 if (!lhsSet.empty())
165 return emitOpError("contains a value named '")
166 << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
167 }
168 }
169
170 return success();
171}
172
173static LogicalResult verifyNames(Operation *op, StringRef kindName,
174 ArrayAttr names, size_t numOperands) {
175 if (numOperands != names.size())
176 return op->emitOpError()
177 << "the number of " << kindName
178 << "s and their names must be "
179 "the same, but got "
180 << numOperands << " and " << names.size() << " respectively";
181
182 DenseMap<StringRef, size_t> nameMap;
183 for (auto [i, name] : llvm::enumerate(names)) {
184 StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
185
186 if (failed(isValidName(nameRef, op, Twine(kindName) + " #" + Twine(i))))
187 return failure();
188
189 if (nameMap.contains(nameRef))
190 return op->emitOpError() << "name of " << kindName << " #" << i
191 << " is a duplicate of the name of " << kindName
192 << " #" << nameMap[nameRef];
193 nameMap.insert({nameRef, i});
194 }
195
196 return success();
197}
198
199LogicalResult ParametersOp::verify() {
200 return verifyNames(*this, "parameter", getNames(), getNumOperands());
201}
202
203template <typename ValueListOp>
204static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
205 StringRef kindName) {
206 size_t numVariadicities = op.getVariadicity().size();
207 size_t numOperands = op.getNumOperands();
208
209 if (numOperands != numVariadicities)
210 return op.emitOpError()
211 << "the number of " << kindName
212 << "s and their variadicities must be "
213 "the same, but got "
214 << numOperands << " and " << numVariadicities << " respectively";
215
216 return verifyNames(op, kindName, op.getNames(), numOperands);
217}
218
219LogicalResult OperandsOp::verify() {
220 return verifyOperandsResultsCommon(*this, "operand");
221}
222
223LogicalResult ResultsOp::verify() {
224 return verifyOperandsResultsCommon(*this, "result");
225}
226
227LogicalResult AttributesOp::verify() {
228 size_t namesSize = getAttributeValueNames().size();
229 size_t valuesSize = getAttributeValues().size();
230
231 if (namesSize != valuesSize)
232 return emitOpError()
233 << "the number of attribute names and their constraints must be "
234 "the same but got "
235 << namesSize << " and " << valuesSize << " respectively";
236
237 return success();
238}
239
240LogicalResult BaseOp::verify() {
241 std::optional<StringRef> baseName = getBaseName();
242 std::optional<SymbolRefAttr> baseRef = getBaseRef();
243 if (baseName.has_value() == baseRef.has_value())
244 return emitOpError() << "the base type or attribute should be specified by "
245 "either a name or a reference";
246
247 if (baseName &&
248 (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
249 return emitOpError() << "the base type or attribute name should start with "
250 "'!' or '#'";
251
252 return success();
253}
254
255/// Finds whether the provided symbol is an IRDL type or attribute definition.
256/// The source operation must be within a DialectOp.
257static LogicalResult
258checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
259 Operation *source, SymbolRefAttr symbol) {
260 Operation *targetOp =
261 irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
262
263 if (!targetOp)
264 return source->emitOpError() << "symbol '" << symbol << "' not found";
265
266 if (!isa<TypeOp, AttributeOp>(Val: targetOp))
267 return source->emitOpError() << "symbol '" << symbol
268 << "' does not refer to a type or attribute "
269 "definition (refers to '"
270 << targetOp->getName() << "')";
271
272 return success();
273}
274
275LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
276 std::optional<SymbolRefAttr> baseRef = getBaseRef();
277 if (!baseRef)
278 return success();
279
280 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
281}
282
283LogicalResult
284ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
285 std::optional<SymbolRefAttr> baseRef = getBaseType();
286 if (!baseRef)
287 return success();
288
289 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
290}
291
292/// Parse a value with its variadicity first. By default, the variadicity is
293/// single.
294///
295/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
296static ParseResult
297parseValueWithVariadicity(OpAsmParser &p,
298 OpAsmParser::UnresolvedOperand &operand,
299 VariadicityAttr &variadicityAttr) {
300 MLIRContext *ctx = p.getBuilder().getContext();
301
302 // Parse the variadicity, if present
303 if (p.parseOptionalKeyword(keyword: "single").succeeded()) {
304 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
305 } else if (p.parseOptionalKeyword(keyword: "optional").succeeded()) {
306 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
307 } else if (p.parseOptionalKeyword(keyword: "variadic").succeeded()) {
308 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
309 } else {
310 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
311 }
312
313 // Parse the value
314 if (p.parseOperand(result&: operand))
315 return failure();
316 return success();
317}
318
319static ParseResult parseNamedValueListImpl(
320 OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
321 ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
322 Builder &builder = p.getBuilder();
323 MLIRContext *ctx = builder.getContext();
324 SmallVector<Attribute> valueNames;
325 SmallVector<VariadicityAttr> variadicities;
326
327 // Parse a single value with its variadicity
328 auto parseOne = [&] {
329 StringRef name;
330 OpAsmParser::UnresolvedOperand operand;
331 VariadicityAttr variadicity;
332 if (p.parseKeyword(keyword: &name) || p.parseColon())
333 return failure();
334
335 if (variadicityAttr) {
336 if (parseValueWithVariadicity(p, operand, variadicity))
337 return failure();
338 variadicities.push_back(variadicity);
339 } else {
340 if (p.parseOperand(result&: operand))
341 return failure();
342 }
343
344 valueNames.push_back(StringAttr::get(ctx, name));
345 operands.push_back(Elt: operand);
346 return success();
347 };
348
349 if (p.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: parseOne))
350 return failure();
351 valueNamesAttr = ArrayAttr::get(ctx, valueNames);
352 if (variadicityAttr)
353 *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
354 return success();
355}
356
357/// Parse a list of named values.
358///
359/// values ::=
360/// `(` (named-value (`,` named-value)*)? `)`
361/// named-value := bare-id `:` ssa-value
362static ParseResult
363parseNamedValueList(OpAsmParser &p,
364 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
365 ArrayAttr &valueNamesAttr) {
366 return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr);
367}
368
369/// Parse a list of named values with their variadicities first. By default, the
370/// variadicity is single.
371///
372/// values-with-variadicity ::=
373/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
374/// value-with-variadicity
375/// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
376static ParseResult parseNamedValueListWithVariadicity(
377 OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
378 ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
379 return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr);
380}
381
382static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op,
383 OperandRange operands,
384 ArrayAttr valueNamesAttr,
385 VariadicityArrayAttr variadicityAttr) {
386 p << "(";
387 interleaveComma(c: llvm::seq<int>(Begin: 0, End: operands.size()), os&: p, each_fn: [&](int i) {
388 p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": ";
389 if (variadicityAttr) {
390 Variadicity variadicity = variadicityAttr[i].getValue();
391 if (variadicity != Variadicity::single) {
392 p << stringifyVariadicity(variadicity) << " ";
393 }
394 }
395 p << operands[i];
396 });
397 p << ")";
398}
399
400/// Print a list of named values.
401///
402/// values ::=
403/// `(` (named-value (`,` named-value)*)? `)`
404/// named-value := bare-id `:` ssa-value
405static void printNamedValueList(OpAsmPrinter &p, Operation *op,
406 OperandRange operands,
407 ArrayAttr valueNamesAttr) {
408 printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr);
409}
410
411/// Print a list of named values with their variadicities first. By default, the
412/// variadicity is single.
413///
414/// values-with-variadicity ::=
415/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
416/// value-with-variadicity ::=
417/// bare-id `:` ("single" | "optional" | "variadic")? ssa-value
418static void printNamedValueListWithVariadicity(
419 OpAsmPrinter &p, Operation *op, OperandRange operands,
420 ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
421 printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
422}
423
424static ParseResult
425parseAttributesOp(OpAsmParser &p,
426 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
427 ArrayAttr &attrNamesAttr) {
428 Builder &builder = p.getBuilder();
429 SmallVector<Attribute> attrNames;
430 if (succeeded(Result: p.parseOptionalLBrace())) {
431 auto parseOperands = [&]() {
432 if (p.parseAttribute(result&: attrNames.emplace_back()) || p.parseEqual() ||
433 p.parseOperand(result&: attrOperands.emplace_back()))
434 return failure();
435 return success();
436 };
437 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
438 return failure();
439 }
440 attrNamesAttr = builder.getArrayAttr(attrNames);
441 return success();
442}
443
444static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
445 OperandRange attrArgs, ArrayAttr attrNames) {
446 if (attrNames.empty())
447 return;
448 p << "{";
449 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
450 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
451 p << '}';
452}
453
454LogicalResult RegionOp::verify() {
455 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
456 if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
457 return emitOpError("the number of blocks is expected to be >= 1 but got ")
458 << number;
459 }
460 return success();
461}
462
463LogicalResult RegionsOp::verify() {
464 return verifyNames(*this, "region", getNames(), getNumOperands());
465}
466
467#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
468
469#define GET_TYPEDEF_CLASSES
470#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
471
472#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
473
474#define GET_ATTRDEF_CLASSES
475#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
476
477#define GET_OP_CLASSES
478#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
479

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/IRDL/IR/IRDL.cpp