1 | //===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/BuiltinAttributes.h" |
12 | #include "mlir/IR/Diagnostics.h" |
13 | #include "mlir/IR/DialectImplementation.h" |
14 | #include "mlir/IR/ExtensibleDialect.h" |
15 | #include "mlir/IR/OpDefinition.h" |
16 | #include "mlir/IR/OpImplementation.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | #include "mlir/Support/LogicalResult.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/ADT/TypeSwitch.h" |
22 | #include "llvm/IR/Metadata.h" |
23 | #include "llvm/Support/Casting.h" |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::irdl; |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // IRDL dialect. |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc" |
33 | |
34 | #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc" |
35 | |
36 | void IRDLDialect::initialize() { |
37 | addOperations< |
38 | #define GET_OP_LIST |
39 | #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc" |
40 | >(); |
41 | addTypes< |
42 | #define GET_TYPEDEF_LIST |
43 | #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc" |
44 | >(); |
45 | addAttributes< |
46 | #define GET_ATTRDEF_LIST |
47 | #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc" |
48 | >(); |
49 | } |
50 | |
51 | //===----------------------------------------------------------------------===// |
52 | // Parsing/Printing |
53 | //===----------------------------------------------------------------------===// |
54 | |
55 | /// Parse a region, and add a single block if the region is empty. |
56 | /// If no region is parsed, create a new region with a single empty block. |
57 | static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region ®ion) { |
58 | auto regionParseRes = p.parseOptionalRegion(region); |
59 | if (regionParseRes.has_value() && failed(result: regionParseRes.value())) |
60 | return failure(); |
61 | |
62 | // If the region is empty, add a single empty block. |
63 | if (region.empty()) |
64 | region.push_back(block: new Block()); |
65 | |
66 | return success(); |
67 | } |
68 | |
69 | static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, |
70 | Region ®ion) { |
71 | if (!region.getBlocks().front().empty()) |
72 | p.printRegion(blocks&: region); |
73 | } |
74 | |
75 | LogicalResult DialectOp::verify() { |
76 | if (!Dialect::isValidNamespace(getName())) |
77 | return emitOpError("invalid dialect name" ); |
78 | return success(); |
79 | } |
80 | |
81 | LogicalResult OperandsOp::verify() { |
82 | size_t numVariadicities = getVariadicity().size(); |
83 | size_t numOperands = getNumOperands(); |
84 | |
85 | if (numOperands != numVariadicities) |
86 | return emitOpError() |
87 | << "the number of operands and their variadicities must be " |
88 | "the same, but got " |
89 | << numOperands << " and " << numVariadicities << " respectively" ; |
90 | |
91 | return success(); |
92 | } |
93 | |
94 | LogicalResult ResultsOp::verify() { |
95 | size_t numVariadicities = getVariadicity().size(); |
96 | size_t numOperands = this->getNumOperands(); |
97 | |
98 | if (numOperands != numVariadicities) |
99 | return emitOpError() |
100 | << "the number of operands and their variadicities must be " |
101 | "the same, but got " |
102 | << numOperands << " and " << numVariadicities << " respectively" ; |
103 | |
104 | return success(); |
105 | } |
106 | |
107 | LogicalResult AttributesOp::verify() { |
108 | size_t namesSize = getAttributeValueNames().size(); |
109 | size_t valuesSize = getAttributeValues().size(); |
110 | |
111 | if (namesSize != valuesSize) |
112 | return emitOpError() |
113 | << "the number of attribute names and their constraints must be " |
114 | "the same but got " |
115 | << namesSize << " and " << valuesSize << " respectively" ; |
116 | |
117 | return success(); |
118 | } |
119 | |
120 | LogicalResult BaseOp::verify() { |
121 | std::optional<StringRef> baseName = getBaseName(); |
122 | std::optional<SymbolRefAttr> baseRef = getBaseRef(); |
123 | if (baseName.has_value() == baseRef.has_value()) |
124 | return emitOpError() << "the base type or attribute should be specified by " |
125 | "either a name or a reference" ; |
126 | |
127 | if (baseName && |
128 | (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#'))) |
129 | return emitOpError() << "the base type or attribute name should start with " |
130 | "'!' or '#'" ; |
131 | |
132 | return success(); |
133 | } |
134 | |
135 | LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
136 | std::optional<SymbolRefAttr> baseRef = getBaseRef(); |
137 | if (!baseRef) |
138 | return success(); |
139 | |
140 | TypeOp typeOp = symbolTable.lookupNearestSymbolFrom<TypeOp>(*this, *baseRef); |
141 | if (typeOp) |
142 | return success(); |
143 | |
144 | AttributeOp attrOp = |
145 | symbolTable.lookupNearestSymbolFrom<AttributeOp>(*this, *baseRef); |
146 | if (attrOp) |
147 | return success(); |
148 | |
149 | return emitOpError() << "'" << *baseRef |
150 | << "' does not refer to a type or attribute definition" ; |
151 | } |
152 | |
153 | /// Parse a value with its variadicity first. By default, the variadicity is |
154 | /// single. |
155 | /// |
156 | /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value |
157 | static ParseResult |
158 | parseValueWithVariadicity(OpAsmParser &p, |
159 | OpAsmParser::UnresolvedOperand &operand, |
160 | VariadicityAttr &variadicityAttr) { |
161 | MLIRContext *ctx = p.getBuilder().getContext(); |
162 | |
163 | // Parse the variadicity, if present |
164 | if (p.parseOptionalKeyword(keyword: "single" ).succeeded()) { |
165 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single); |
166 | } else if (p.parseOptionalKeyword(keyword: "optional" ).succeeded()) { |
167 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional); |
168 | } else if (p.parseOptionalKeyword(keyword: "variadic" ).succeeded()) { |
169 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic); |
170 | } else { |
171 | variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single); |
172 | } |
173 | |
174 | // Parse the value |
175 | if (p.parseOperand(result&: operand)) |
176 | return failure(); |
177 | return success(); |
178 | } |
179 | |
180 | /// Parse a list of values with their variadicities first. By default, the |
181 | /// variadicity is single. |
182 | /// |
183 | /// values-with-variadicity ::= |
184 | /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)` |
185 | /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value |
186 | static ParseResult parseValuesWithVariadicity( |
187 | OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, |
188 | VariadicityArrayAttr &variadicityAttr) { |
189 | Builder &builder = p.getBuilder(); |
190 | MLIRContext *ctx = builder.getContext(); |
191 | SmallVector<VariadicityAttr> variadicities; |
192 | |
193 | // Parse a single value with its variadicity |
194 | auto parseOne = [&] { |
195 | OpAsmParser::UnresolvedOperand operand; |
196 | VariadicityAttr variadicity; |
197 | if (parseValueWithVariadicity(p, operand, variadicity)) |
198 | return failure(); |
199 | operands.push_back(Elt: operand); |
200 | variadicities.push_back(variadicity); |
201 | return success(); |
202 | }; |
203 | |
204 | if (p.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: parseOne)) |
205 | return failure(); |
206 | variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities); |
207 | return success(); |
208 | } |
209 | |
210 | /// Print a list of values with their variadicities first. By default, the |
211 | /// variadicity is single. |
212 | /// |
213 | /// values-with-variadicity ::= |
214 | /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)` |
215 | /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value |
216 | static void printValuesWithVariadicity(OpAsmPrinter &p, Operation *op, |
217 | OperandRange operands, |
218 | VariadicityArrayAttr variadicityAttr) { |
219 | p << "(" ; |
220 | interleaveComma(c: llvm::seq<int>(Begin: 0, End: operands.size()), os&: p, each_fn: [&](int i) { |
221 | Variadicity variadicity = variadicityAttr[i].getValue(); |
222 | if (variadicity != Variadicity::single) { |
223 | p << stringifyVariadicity(variadicity) << " " ; |
224 | } |
225 | p << operands[i]; |
226 | }); |
227 | p << ")" ; |
228 | } |
229 | |
230 | static ParseResult |
231 | parseAttributesOp(OpAsmParser &p, |
232 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, |
233 | ArrayAttr &attrNamesAttr) { |
234 | Builder &builder = p.getBuilder(); |
235 | SmallVector<Attribute> attrNames; |
236 | if (succeeded(result: p.parseOptionalLBrace())) { |
237 | auto parseOperands = [&]() { |
238 | if (p.parseAttribute(result&: attrNames.emplace_back()) || p.parseEqual() || |
239 | p.parseOperand(result&: attrOperands.emplace_back())) |
240 | return failure(); |
241 | return success(); |
242 | }; |
243 | if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace()) |
244 | return failure(); |
245 | } |
246 | attrNamesAttr = builder.getArrayAttr(attrNames); |
247 | return success(); |
248 | } |
249 | |
250 | static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, |
251 | OperandRange attrArgs, ArrayAttr attrNames) { |
252 | if (attrNames.empty()) |
253 | return; |
254 | p << "{" ; |
255 | interleaveComma(llvm::seq<int>(0, attrNames.size()), p, |
256 | [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); |
257 | p << '}'; |
258 | } |
259 | |
260 | LogicalResult RegionOp::verify() { |
261 | if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr()) |
262 | if (int64_t number = numberOfBlocks.getInt(); number <= 0) { |
263 | return emitOpError("the number of blocks is expected to be >= 1 but got " ) |
264 | << number; |
265 | } |
266 | return success(); |
267 | } |
268 | |
269 | #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc" |
270 | |
271 | #define GET_TYPEDEF_CLASSES |
272 | #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc" |
273 | |
274 | #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc" |
275 | |
276 | #define GET_ATTRDEF_CLASSES |
277 | #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc" |
278 | |
279 | #define GET_OP_CLASSES |
280 | #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc" |
281 | |