1 | //===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // OpDefinitionsGen uses the description of operations to generate IRDL |
10 | // definitions for ops. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
15 | #include "mlir/IR/Attributes.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/BuiltinOps.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/Dialect.h" |
20 | #include "mlir/IR/MLIRContext.h" |
21 | #include "mlir/TableGen/AttrOrTypeDef.h" |
22 | #include "mlir/TableGen/GenInfo.h" |
23 | #include "mlir/TableGen/GenNameParser.h" |
24 | #include "mlir/TableGen/Interfaces.h" |
25 | #include "mlir/TableGen/Operator.h" |
26 | #include "llvm/ADT/StringExtras.h" |
27 | #include "llvm/Support/CommandLine.h" |
28 | #include "llvm/Support/InitLLVM.h" |
29 | #include "llvm/Support/raw_ostream.h" |
30 | #include "llvm/TableGen/Main.h" |
31 | #include "llvm/TableGen/Record.h" |
32 | #include "llvm/TableGen/TableGenBackend.h" |
33 | |
34 | using namespace llvm; |
35 | using namespace mlir; |
36 | using tblgen::NamedTypeConstraint; |
37 | |
38 | static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect" ); |
39 | static llvm::cl::opt<std::string> |
40 | selectedDialect("dialect" , llvm::cl::desc("The dialect to gen for" ), |
41 | llvm::cl::cat(dialectGenCat), llvm::cl::Required); |
42 | |
43 | Value createPredicate(OpBuilder &builder, tblgen::Pred pred) { |
44 | MLIRContext *ctx = builder.getContext(); |
45 | |
46 | if (pred.isCombined()) { |
47 | auto combiner = pred.getDef().getValueAsDef(FieldName: "kind" )->getName(); |
48 | if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr" ) { |
49 | std::vector<Value> constraints; |
50 | for (auto *child : pred.getDef().getValueAsListOfDefs(FieldName: "children" )) { |
51 | constraints.push_back(x: createPredicate(builder, pred: tblgen::Pred(child))); |
52 | } |
53 | if (combiner == "PredCombinerAnd" ) { |
54 | auto op = |
55 | builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); |
56 | return op.getOutput(); |
57 | } |
58 | auto op = |
59 | builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); |
60 | return op.getOutput(); |
61 | } |
62 | } |
63 | |
64 | std::string condition = pred.getCondition(); |
65 | // Build a CPredOp to match the C constraint built. |
66 | irdl::CPredOp op = builder.create<irdl::CPredOp>( |
67 | UnknownLoc::get(ctx), StringAttr::get(ctx, condition)); |
68 | return op; |
69 | } |
70 | |
71 | Value typeToConstraint(OpBuilder &builder, Type type) { |
72 | MLIRContext *ctx = builder.getContext(); |
73 | auto op = |
74 | builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type)); |
75 | return op.getOutput(); |
76 | } |
77 | |
78 | Value baseToConstraint(OpBuilder &builder, StringRef baseClass) { |
79 | MLIRContext *ctx = builder.getContext(); |
80 | auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), |
81 | StringAttr::get(ctx, baseClass)); |
82 | return op.getOutput(); |
83 | } |
84 | |
85 | std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) { |
86 | if (predRec.isSubClassOf(Name: "I" )) { |
87 | auto width = predRec.getValueAsInt(FieldName: "bitwidth" ); |
88 | return IntegerType::get(ctx, width, IntegerType::Signless); |
89 | } |
90 | |
91 | if (predRec.isSubClassOf(Name: "SI" )) { |
92 | auto width = predRec.getValueAsInt(FieldName: "bitwidth" ); |
93 | return IntegerType::get(ctx, width, IntegerType::Signed); |
94 | } |
95 | |
96 | if (predRec.isSubClassOf(Name: "UI" )) { |
97 | auto width = predRec.getValueAsInt(FieldName: "bitwidth" ); |
98 | return IntegerType::get(ctx, width, IntegerType::Unsigned); |
99 | } |
100 | |
101 | // Index type |
102 | if (predRec.getName() == "Index" ) { |
103 | return IndexType::get(ctx); |
104 | } |
105 | |
106 | // Float types |
107 | if (predRec.isSubClassOf(Name: "F" )) { |
108 | auto width = predRec.getValueAsInt(FieldName: "bitwidth" ); |
109 | switch (width) { |
110 | case 16: |
111 | return Float16Type::get(ctx); |
112 | case 32: |
113 | return Float32Type::get(ctx); |
114 | case 64: |
115 | return Float64Type::get(ctx); |
116 | case 80: |
117 | return Float80Type::get(ctx); |
118 | case 128: |
119 | return Float128Type::get(ctx); |
120 | } |
121 | } |
122 | |
123 | if (predRec.getName() == "NoneType" ) { |
124 | return NoneType::get(ctx); |
125 | } |
126 | |
127 | if (predRec.getName() == "BF16" ) { |
128 | return BFloat16Type::get(ctx); |
129 | } |
130 | |
131 | if (predRec.getName() == "TF32" ) { |
132 | return FloatTF32Type::get(ctx); |
133 | } |
134 | |
135 | if (predRec.getName() == "F8E4M3FN" ) { |
136 | return Float8E4M3FNType::get(ctx); |
137 | } |
138 | |
139 | if (predRec.getName() == "F8E5M2" ) { |
140 | return Float8E5M2Type::get(ctx); |
141 | } |
142 | |
143 | if (predRec.getName() == "F8E4M3" ) { |
144 | return Float8E4M3Type::get(ctx); |
145 | } |
146 | |
147 | if (predRec.getName() == "F8E4M3FNUZ" ) { |
148 | return Float8E4M3FNUZType::get(ctx); |
149 | } |
150 | |
151 | if (predRec.getName() == "F8E4M3B11FNUZ" ) { |
152 | return Float8E4M3B11FNUZType::get(ctx); |
153 | } |
154 | |
155 | if (predRec.getName() == "F8E5M2FNUZ" ) { |
156 | return Float8E5M2FNUZType::get(ctx); |
157 | } |
158 | |
159 | if (predRec.getName() == "F8E3M4" ) { |
160 | return Float8E3M4Type::get(ctx); |
161 | } |
162 | |
163 | if (predRec.isSubClassOf(Name: "Complex" )) { |
164 | const Record *elementRec = predRec.getValueAsDef(FieldName: "elementType" ); |
165 | auto elementType = recordToType(ctx, predRec: *elementRec); |
166 | if (elementType.has_value()) { |
167 | return ComplexType::get(elementType.value()); |
168 | } |
169 | } |
170 | |
171 | return std::nullopt; |
172 | } |
173 | |
174 | Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) { |
175 | MLIRContext *ctx = builder.getContext(); |
176 | const Record &predRec = constraint.getDef(); |
177 | |
178 | if (predRec.isSubClassOf(Name: "Variadic" ) || predRec.isSubClassOf(Name: "Optional" )) |
179 | return createTypeConstraint(builder, constraint: predRec.getValueAsDef(FieldName: "baseType" )); |
180 | |
181 | if (predRec.getName() == "AnyType" ) { |
182 | auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx)); |
183 | return op.getOutput(); |
184 | } |
185 | |
186 | if (predRec.isSubClassOf(Name: "TypeDef" )) { |
187 | auto dialect = predRec.getValueAsDef(FieldName: "dialect" )->getValueAsString(FieldName: "name" ); |
188 | if (dialect == selectedDialect) { |
189 | std::string combined = ("!" + predRec.getValueAsString(FieldName: "mnemonic" )).str(); |
190 | SmallVector<FlatSymbolRefAttr> nested = { |
191 | SymbolRefAttr::get(ctx, combined)}; |
192 | auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested); |
193 | auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol); |
194 | return op.getOutput(); |
195 | } |
196 | std::string typeName = ("!" + predRec.getValueAsString(FieldName: "typeName" )).str(); |
197 | auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), |
198 | StringAttr::get(ctx, typeName)); |
199 | return op.getOutput(); |
200 | } |
201 | |
202 | if (predRec.isSubClassOf(Name: "AnyTypeOf" )) { |
203 | std::vector<Value> constraints; |
204 | for (const Record *child : predRec.getValueAsListOfDefs(FieldName: "allowedTypes" )) { |
205 | constraints.push_back( |
206 | x: createTypeConstraint(builder, constraint: tblgen::Constraint(child))); |
207 | } |
208 | auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); |
209 | return op.getOutput(); |
210 | } |
211 | |
212 | if (predRec.isSubClassOf(Name: "AllOfType" )) { |
213 | std::vector<Value> constraints; |
214 | for (const Record *child : predRec.getValueAsListOfDefs(FieldName: "allowedTypes" )) { |
215 | constraints.push_back( |
216 | x: createTypeConstraint(builder, constraint: tblgen::Constraint(child))); |
217 | } |
218 | auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); |
219 | return op.getOutput(); |
220 | } |
221 | |
222 | // Integer types |
223 | if (predRec.getName() == "AnyInteger" ) { |
224 | auto op = builder.create<irdl::BaseOp>( |
225 | UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer" )); |
226 | return op.getOutput(); |
227 | } |
228 | |
229 | if (predRec.isSubClassOf(Name: "AnyI" )) { |
230 | auto width = predRec.getValueAsInt(FieldName: "bitwidth" ); |
231 | std::vector<Value> types = { |
232 | typeToConstraint(builder, |
233 | IntegerType::get(ctx, width, IntegerType::Signless)), |
234 | typeToConstraint(builder, |
235 | IntegerType::get(ctx, width, IntegerType::Signed)), |
236 | typeToConstraint(builder, |
237 | IntegerType::get(ctx, width, IntegerType::Unsigned))}; |
238 | auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types); |
239 | return op.getOutput(); |
240 | } |
241 | |
242 | auto type = recordToType(ctx, predRec); |
243 | |
244 | if (type.has_value()) { |
245 | return typeToConstraint(builder, type: type.value()); |
246 | } |
247 | |
248 | // Confined type |
249 | if (predRec.isSubClassOf(Name: "ConfinedType" )) { |
250 | std::vector<Value> constraints; |
251 | constraints.push_back(x: createTypeConstraint( |
252 | builder, constraint: tblgen::Constraint(predRec.getValueAsDef(FieldName: "baseType" )))); |
253 | for (const Record *child : predRec.getValueAsListOfDefs(FieldName: "predicateList" )) { |
254 | constraints.push_back(x: createPredicate(builder, pred: tblgen::Pred(child))); |
255 | } |
256 | auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); |
257 | return op.getOutput(); |
258 | } |
259 | |
260 | return createPredicate(builder, pred: constraint.getPredicate()); |
261 | } |
262 | |
263 | Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) { |
264 | MLIRContext *ctx = builder.getContext(); |
265 | const Record &predRec = constraint.getDef(); |
266 | |
267 | if (predRec.isSubClassOf(Name: "DefaultValuedAttr" ) || |
268 | predRec.isSubClassOf(Name: "DefaultValuedOptionalAttr" ) || |
269 | predRec.isSubClassOf(Name: "OptionalAttr" )) { |
270 | return createAttrConstraint(builder, constraint: predRec.getValueAsDef(FieldName: "baseAttr" )); |
271 | } |
272 | |
273 | if (predRec.isSubClassOf(Name: "ConfinedAttr" )) { |
274 | std::vector<Value> constraints; |
275 | constraints.push_back(x: createAttrConstraint( |
276 | builder, constraint: tblgen::Constraint(predRec.getValueAsDef(FieldName: "baseAttr" )))); |
277 | for (const Record *child : |
278 | predRec.getValueAsListOfDefs(FieldName: "attrConstraints" )) { |
279 | constraints.push_back(x: createPredicate( |
280 | builder, pred: tblgen::Pred(child->getValueAsDef(FieldName: "predicate" )))); |
281 | } |
282 | auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); |
283 | return op.getOutput(); |
284 | } |
285 | |
286 | if (predRec.isSubClassOf(Name: "AnyAttrOf" )) { |
287 | std::vector<Value> constraints; |
288 | for (const Record *child : |
289 | predRec.getValueAsListOfDefs(FieldName: "allowedAttributes" )) { |
290 | constraints.push_back( |
291 | x: createAttrConstraint(builder, constraint: tblgen::Constraint(child))); |
292 | } |
293 | auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); |
294 | return op.getOutput(); |
295 | } |
296 | |
297 | if (predRec.getName() == "AnyAttr" ) { |
298 | auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx)); |
299 | return op.getOutput(); |
300 | } |
301 | |
302 | if (predRec.isSubClassOf(Name: "AnyIntegerAttrBase" ) || |
303 | predRec.isSubClassOf(Name: "SignlessIntegerAttrBase" ) || |
304 | predRec.isSubClassOf(Name: "SignedIntegerAttrBase" ) || |
305 | predRec.isSubClassOf(Name: "UnsignedIntegerAttrBase" ) || |
306 | predRec.isSubClassOf(Name: "BoolAttr" )) { |
307 | return baseToConstraint(builder, baseClass: "!builtin.integer" ); |
308 | } |
309 | |
310 | if (predRec.isSubClassOf(Name: "FloatAttrBase" )) { |
311 | return baseToConstraint(builder, baseClass: "!builtin.float" ); |
312 | } |
313 | |
314 | if (predRec.isSubClassOf(Name: "StringBasedAttr" )) { |
315 | return baseToConstraint(builder, baseClass: "!builtin.string" ); |
316 | } |
317 | |
318 | if (predRec.getName() == "UnitAttr" ) { |
319 | auto op = |
320 | builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx)); |
321 | return op.getOutput(); |
322 | } |
323 | |
324 | if (predRec.isSubClassOf(Name: "AttrDef" )) { |
325 | auto dialect = predRec.getValueAsDef(FieldName: "dialect" )->getValueAsString(FieldName: "name" ); |
326 | if (dialect == selectedDialect) { |
327 | std::string combined = ("#" + predRec.getValueAsString(FieldName: "mnemonic" )).str(); |
328 | SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined) |
329 | |
330 | }; |
331 | auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested); |
332 | auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol); |
333 | return op.getOutput(); |
334 | } |
335 | std::string typeName = ("#" + predRec.getValueAsString(FieldName: "attrName" )).str(); |
336 | auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), |
337 | StringAttr::get(ctx, typeName)); |
338 | return op.getOutput(); |
339 | } |
340 | |
341 | return createPredicate(builder, pred: constraint.getPredicate()); |
342 | } |
343 | |
344 | Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) { |
345 | MLIRContext *ctx = builder.getContext(); |
346 | const Record &predRec = constraint.getDef(); |
347 | |
348 | if (predRec.getName() == "AnyRegion" ) { |
349 | ValueRange entryBlockArgs = {}; |
350 | auto op = |
351 | builder.create<irdl::RegionOp>(UnknownLoc::get(ctx), entryBlockArgs); |
352 | return op.getResult(); |
353 | } |
354 | |
355 | if (predRec.isSubClassOf(Name: "SizedRegion" )) { |
356 | ValueRange entryBlockArgs = {}; |
357 | auto ty = IntegerType::get(ctx, 32); |
358 | auto op = builder.create<irdl::RegionOp>( |
359 | UnknownLoc::get(ctx), entryBlockArgs, |
360 | IntegerAttr::get(ty, predRec.getValueAsInt("blocks" ))); |
361 | return op.getResult(); |
362 | } |
363 | |
364 | return createPredicate(builder, pred: constraint.getPredicate()); |
365 | } |
366 | |
367 | /// Returns the name of the operation without the dialect prefix. |
368 | static StringRef getOperatorName(tblgen::Operator &tblgenOp) { |
369 | StringRef opName = tblgenOp.getDef().getValueAsString(FieldName: "opName" ); |
370 | return opName; |
371 | } |
372 | |
373 | /// Returns the name of the type without the dialect prefix. |
374 | static StringRef getTypeName(tblgen::TypeDef &tblgenType) { |
375 | StringRef opName = tblgenType.getDef()->getValueAsString(FieldName: "mnemonic" ); |
376 | return opName; |
377 | } |
378 | |
379 | /// Returns the name of the attr without the dialect prefix. |
380 | static StringRef getAttrName(tblgen::AttrDef &tblgenType) { |
381 | StringRef opName = tblgenType.getDef()->getValueAsString(FieldName: "mnemonic" ); |
382 | return opName; |
383 | } |
384 | |
385 | /// Extract an operation to IRDL. |
386 | irdl::OperationOp createIRDLOperation(OpBuilder &builder, |
387 | tblgen::Operator &tblgenOp) { |
388 | MLIRContext *ctx = builder.getContext(); |
389 | StringRef opName = getOperatorName(tblgenOp); |
390 | |
391 | irdl::OperationOp op = builder.create<irdl::OperationOp>( |
392 | UnknownLoc::get(ctx), StringAttr::get(ctx, opName)); |
393 | |
394 | // Add the block in the region. |
395 | Block &opBlock = op.getBody().emplaceBlock(); |
396 | OpBuilder consBuilder = OpBuilder::atBlockBegin(block: &opBlock); |
397 | |
398 | SmallDenseSet<StringRef> usedNames; |
399 | for (auto &namedCons : tblgenOp.getOperands()) |
400 | usedNames.insert(V: namedCons.name); |
401 | for (auto &namedCons : tblgenOp.getResults()) |
402 | usedNames.insert(V: namedCons.name); |
403 | for (auto &namedReg : tblgenOp.getRegions()) |
404 | usedNames.insert(V: namedReg.name); |
405 | |
406 | size_t generateCounter = 0; |
407 | auto generateName = [&](StringRef prefix) -> StringAttr { |
408 | SmallString<16> candidate; |
409 | do { |
410 | candidate.clear(); |
411 | raw_svector_ostream candidateStream(candidate); |
412 | candidateStream << prefix << generateCounter; |
413 | generateCounter++; |
414 | } while (usedNames.contains(V: candidate)); |
415 | return StringAttr::get(ctx, candidate); |
416 | }; |
417 | auto normalizeName = [&](StringRef name) -> StringAttr { |
418 | if (name == "" ) |
419 | return generateName("unnamed" ); |
420 | return StringAttr::get(ctx, name); |
421 | }; |
422 | |
423 | auto getValues = [&](tblgen::Operator::const_value_range namedCons) { |
424 | SmallVector<Value> operands; |
425 | SmallVector<Attribute> names; |
426 | SmallVector<irdl::VariadicityAttr> variadicity; |
427 | |
428 | for (const NamedTypeConstraint &namedCons : namedCons) { |
429 | auto operand = createTypeConstraint(builder&: consBuilder, constraint: namedCons.constraint); |
430 | operands.push_back(Elt: operand); |
431 | |
432 | names.push_back(Elt: normalizeName(namedCons.name)); |
433 | |
434 | irdl::VariadicityAttr var; |
435 | if (namedCons.isOptional()) |
436 | var = consBuilder.getAttr<irdl::VariadicityAttr>( |
437 | irdl::Variadicity::optional); |
438 | else if (namedCons.isVariadic()) |
439 | var = consBuilder.getAttr<irdl::VariadicityAttr>( |
440 | irdl::Variadicity::variadic); |
441 | else |
442 | var = consBuilder.getAttr<irdl::VariadicityAttr>( |
443 | irdl::Variadicity::single); |
444 | |
445 | variadicity.push_back(var); |
446 | } |
447 | return std::make_tuple(operands, names, variadicity); |
448 | }; |
449 | |
450 | auto [operands, operandNames, operandVariadicity] = |
451 | getValues(tblgenOp.getOperands()); |
452 | auto [results, resultNames, resultVariadicity] = |
453 | getValues(tblgenOp.getResults()); |
454 | |
455 | SmallVector<Value> attributes; |
456 | SmallVector<Attribute> attrNames; |
457 | for (auto namedAttr : tblgenOp.getAttributes()) { |
458 | if (namedAttr.attr.isOptional()) |
459 | continue; |
460 | attributes.push_back(Elt: createAttrConstraint(builder&: consBuilder, constraint: namedAttr.attr)); |
461 | attrNames.push_back(StringAttr::get(ctx, namedAttr.name)); |
462 | } |
463 | |
464 | SmallVector<Value> regions; |
465 | SmallVector<Attribute> regionNames; |
466 | for (auto namedRegion : tblgenOp.getRegions()) { |
467 | regions.push_back( |
468 | Elt: createRegionConstraint(builder&: consBuilder, constraint: namedRegion.constraint)); |
469 | regionNames.push_back(Elt: normalizeName(namedRegion.name)); |
470 | } |
471 | |
472 | // Create the operands and results operations. |
473 | if (!operands.empty()) |
474 | consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands, |
475 | ArrayAttr::get(ctx, operandNames), |
476 | operandVariadicity); |
477 | if (!results.empty()) |
478 | consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results, |
479 | ArrayAttr::get(ctx, resultNames), |
480 | resultVariadicity); |
481 | if (!attributes.empty()) |
482 | consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes, |
483 | ArrayAttr::get(ctx, attrNames)); |
484 | if (!regions.empty()) |
485 | consBuilder.create<irdl::RegionsOp>(UnknownLoc::get(ctx), regions, |
486 | ArrayAttr::get(ctx, regionNames)); |
487 | |
488 | return op; |
489 | } |
490 | |
491 | irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) { |
492 | MLIRContext *ctx = builder.getContext(); |
493 | StringRef typeName = getTypeName(tblgenType); |
494 | std::string combined = ("!" + typeName).str(); |
495 | |
496 | irdl::TypeOp op = builder.create<irdl::TypeOp>( |
497 | UnknownLoc::get(ctx), StringAttr::get(ctx, combined)); |
498 | |
499 | op.getBody().emplaceBlock(); |
500 | |
501 | return op; |
502 | } |
503 | |
504 | irdl::AttributeOp createIRDLAttr(OpBuilder &builder, |
505 | tblgen::AttrDef &tblgenAttr) { |
506 | MLIRContext *ctx = builder.getContext(); |
507 | StringRef attrName = getAttrName(tblgenType&: tblgenAttr); |
508 | std::string combined = ("#" + attrName).str(); |
509 | |
510 | irdl::AttributeOp op = builder.create<irdl::AttributeOp>( |
511 | UnknownLoc::get(ctx), StringAttr::get(ctx, combined)); |
512 | |
513 | op.getBody().emplaceBlock(); |
514 | |
515 | return op; |
516 | } |
517 | |
518 | static irdl::DialectOp createIRDLDialect(OpBuilder &builder) { |
519 | MLIRContext *ctx = builder.getContext(); |
520 | return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx), |
521 | StringAttr::get(ctx, selectedDialect)); |
522 | } |
523 | |
524 | static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) { |
525 | // Initialize. |
526 | MLIRContext ctx; |
527 | ctx.getOrLoadDialect<irdl::IRDLDialect>(); |
528 | OpBuilder builder(&ctx); |
529 | |
530 | // Create a module op and set it as the insertion point. |
531 | OwningOpRef<ModuleOp> module = |
532 | builder.create<ModuleOp>(UnknownLoc::get(&ctx)); |
533 | builder = builder.atBlockBegin(block: module->getBody()); |
534 | // Create the dialect and insert it. |
535 | irdl::DialectOp dialect = createIRDLDialect(builder); |
536 | // Set insertion point to start of DialectOp. |
537 | builder = builder.atBlockBegin(block: &dialect.getBody().emplaceBlock()); |
538 | |
539 | for (const Record *type : |
540 | records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeDef" )) { |
541 | tblgen::TypeDef tblgenType(type); |
542 | if (tblgenType.getDialect().getName() != selectedDialect) |
543 | continue; |
544 | createIRDLType(builder, tblgenType); |
545 | } |
546 | |
547 | for (const Record *attr : |
548 | records.getAllDerivedDefinitionsIfDefined(ClassName: "AttrDef" )) { |
549 | tblgen::AttrDef tblgenAttr(attr); |
550 | if (tblgenAttr.getDialect().getName() != selectedDialect) |
551 | continue; |
552 | createIRDLAttr(builder, tblgenAttr); |
553 | } |
554 | |
555 | for (const Record *def : records.getAllDerivedDefinitionsIfDefined(ClassName: "Op" )) { |
556 | tblgen::Operator tblgenOp(def); |
557 | if (tblgenOp.getDialectName() != selectedDialect) |
558 | continue; |
559 | |
560 | createIRDLOperation(builder, tblgenOp); |
561 | } |
562 | |
563 | // Print the module. |
564 | module->print(os); |
565 | |
566 | return false; |
567 | } |
568 | |
569 | static mlir::GenRegistration |
570 | genOpDefs("gen-dialect-irdl-defs" , "Generate IRDL dialect definitions" , |
571 | [](const RecordKeeper &records, raw_ostream &os) { |
572 | return emitDialectIRDLDefs(records, os); |
573 | }); |
574 | |