1//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/IRDL/IR/IRDL.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/BuiltinAttributes.h"
12#include "mlir/IR/Diagnostics.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "mlir/IR/ExtensibleDialect.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Support/LLVM.h"
19#include "mlir/Support/LogicalResult.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/IR/Metadata.h"
23#include "llvm/Support/Casting.h"
24
25using namespace mlir;
26using namespace mlir::irdl;
27
28//===----------------------------------------------------------------------===//
29// IRDL dialect.
30//===----------------------------------------------------------------------===//
31
32#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
33
34#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
35
/// Registers every IRDL operation, type, and attribute with this dialect.
///
/// Each list is produced by a TableGen-generated `.cpp.inc` file: defining the
/// matching `GET_*_LIST` macro before the include makes the file expand to the
/// template-argument list of generated classes consumed by `addOperations`,
/// `addTypes`, and `addAttributes` respectively.
void IRDLDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
      >();
}
50
51//===----------------------------------------------------------------------===//
52// Parsing/Printing
53//===----------------------------------------------------------------------===//
54
55/// Parse a region, and add a single block if the region is empty.
56/// If no region is parsed, create a new region with a single empty block.
57static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
58 auto regionParseRes = p.parseOptionalRegion(region);
59 if (regionParseRes.has_value() && failed(result: regionParseRes.value()))
60 return failure();
61
62 // If the region is empty, add a single empty block.
63 if (region.empty())
64 region.push_back(block: new Block());
65
66 return success();
67}
68
69static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op,
70 Region &region) {
71 if (!region.getBlocks().front().empty())
72 p.printRegion(blocks&: region);
73}
74
75LogicalResult DialectOp::verify() {
76 if (!Dialect::isValidNamespace(getName()))
77 return emitOpError("invalid dialect name");
78 return success();
79}
80
81LogicalResult OperandsOp::verify() {
82 size_t numVariadicities = getVariadicity().size();
83 size_t numOperands = getNumOperands();
84
85 if (numOperands != numVariadicities)
86 return emitOpError()
87 << "the number of operands and their variadicities must be "
88 "the same, but got "
89 << numOperands << " and " << numVariadicities << " respectively";
90
91 return success();
92}
93
94LogicalResult ResultsOp::verify() {
95 size_t numVariadicities = getVariadicity().size();
96 size_t numOperands = this->getNumOperands();
97
98 if (numOperands != numVariadicities)
99 return emitOpError()
100 << "the number of operands and their variadicities must be "
101 "the same, but got "
102 << numOperands << " and " << numVariadicities << " respectively";
103
104 return success();
105}
106
107LogicalResult AttributesOp::verify() {
108 size_t namesSize = getAttributeValueNames().size();
109 size_t valuesSize = getAttributeValues().size();
110
111 if (namesSize != valuesSize)
112 return emitOpError()
113 << "the number of attribute names and their constraints must be "
114 "the same but got "
115 << namesSize << " and " << valuesSize << " respectively";
116
117 return success();
118}
119
120LogicalResult BaseOp::verify() {
121 std::optional<StringRef> baseName = getBaseName();
122 std::optional<SymbolRefAttr> baseRef = getBaseRef();
123 if (baseName.has_value() == baseRef.has_value())
124 return emitOpError() << "the base type or attribute should be specified by "
125 "either a name or a reference";
126
127 if (baseName &&
128 (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
129 return emitOpError() << "the base type or attribute name should start with "
130 "'!' or '#'";
131
132 return success();
133}
134
135LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
136 std::optional<SymbolRefAttr> baseRef = getBaseRef();
137 if (!baseRef)
138 return success();
139
140 TypeOp typeOp = symbolTable.lookupNearestSymbolFrom<TypeOp>(*this, *baseRef);
141 if (typeOp)
142 return success();
143
144 AttributeOp attrOp =
145 symbolTable.lookupNearestSymbolFrom<AttributeOp>(*this, *baseRef);
146 if (attrOp)
147 return success();
148
149 return emitOpError() << "'" << *baseRef
150 << "' does not refer to a type or attribute definition";
151}
152
153/// Parse a value with its variadicity first. By default, the variadicity is
154/// single.
155///
156/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
157static ParseResult
158parseValueWithVariadicity(OpAsmParser &p,
159 OpAsmParser::UnresolvedOperand &operand,
160 VariadicityAttr &variadicityAttr) {
161 MLIRContext *ctx = p.getBuilder().getContext();
162
163 // Parse the variadicity, if present
164 if (p.parseOptionalKeyword(keyword: "single").succeeded()) {
165 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
166 } else if (p.parseOptionalKeyword(keyword: "optional").succeeded()) {
167 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
168 } else if (p.parseOptionalKeyword(keyword: "variadic").succeeded()) {
169 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
170 } else {
171 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
172 }
173
174 // Parse the value
175 if (p.parseOperand(result&: operand))
176 return failure();
177 return success();
178}
179
180/// Parse a list of values with their variadicities first. By default, the
181/// variadicity is single.
182///
183/// values-with-variadicity ::=
184/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
185/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
186static ParseResult parseValuesWithVariadicity(
187 OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
188 VariadicityArrayAttr &variadicityAttr) {
189 Builder &builder = p.getBuilder();
190 MLIRContext *ctx = builder.getContext();
191 SmallVector<VariadicityAttr> variadicities;
192
193 // Parse a single value with its variadicity
194 auto parseOne = [&] {
195 OpAsmParser::UnresolvedOperand operand;
196 VariadicityAttr variadicity;
197 if (parseValueWithVariadicity(p, operand, variadicity))
198 return failure();
199 operands.push_back(Elt: operand);
200 variadicities.push_back(variadicity);
201 return success();
202 };
203
204 if (p.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: parseOne))
205 return failure();
206 variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
207 return success();
208}
209
210/// Print a list of values with their variadicities first. By default, the
211/// variadicity is single.
212///
213/// values-with-variadicity ::=
214/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
215/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
216static void printValuesWithVariadicity(OpAsmPrinter &p, Operation *op,
217 OperandRange operands,
218 VariadicityArrayAttr variadicityAttr) {
219 p << "(";
220 interleaveComma(c: llvm::seq<int>(Begin: 0, End: operands.size()), os&: p, each_fn: [&](int i) {
221 Variadicity variadicity = variadicityAttr[i].getValue();
222 if (variadicity != Variadicity::single) {
223 p << stringifyVariadicity(variadicity) << " ";
224 }
225 p << operands[i];
226 });
227 p << ")";
228}
229
230static ParseResult
231parseAttributesOp(OpAsmParser &p,
232 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
233 ArrayAttr &attrNamesAttr) {
234 Builder &builder = p.getBuilder();
235 SmallVector<Attribute> attrNames;
236 if (succeeded(result: p.parseOptionalLBrace())) {
237 auto parseOperands = [&]() {
238 if (p.parseAttribute(result&: attrNames.emplace_back()) || p.parseEqual() ||
239 p.parseOperand(result&: attrOperands.emplace_back()))
240 return failure();
241 return success();
242 };
243 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
244 return failure();
245 }
246 attrNamesAttr = builder.getArrayAttr(attrNames);
247 return success();
248}
249
250static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
251 OperandRange attrArgs, ArrayAttr attrNames) {
252 if (attrNames.empty())
253 return;
254 p << "{";
255 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
256 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
257 p << '}';
258}
259
260LogicalResult RegionOp::verify() {
261 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
262 if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
263 return emitOpError("the number of blocks is expected to be >= 1 but got ")
264 << number;
265 }
266 return success();
267}
268
269#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
270
271#define GET_TYPEDEF_CLASSES
272#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
273
274#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
275
276#define GET_ATTRDEF_CLASSES
277#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
278
279#define GET_OP_CLASSES
280#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
281

source code of mlir/lib/Dialect/IRDL/IR/IRDL.cpp