//===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the TableGen descriptions of a dialect's operations,
// types, and attributes to generate the corresponding IRDL definitions.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/IRDL/IR/IRDL.h"
15#include "mlir/IR/Attributes.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/Dialect.h"
20#include "mlir/IR/MLIRContext.h"
21#include "mlir/TableGen/AttrOrTypeDef.h"
22#include "mlir/TableGen/GenInfo.h"
23#include "mlir/TableGen/GenNameParser.h"
24#include "mlir/TableGen/Interfaces.h"
25#include "mlir/TableGen/Operator.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/Support/CommandLine.h"
28#include "llvm/Support/InitLLVM.h"
29#include "llvm/Support/raw_ostream.h"
30#include "llvm/TableGen/Main.h"
31#include "llvm/TableGen/Record.h"
32#include "llvm/TableGen/TableGenBackend.h"
33
34using namespace llvm;
35using namespace mlir;
36using tblgen::NamedTypeConstraint;
37
38static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect");
39static llvm::cl::opt<std::string>
40 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
41 llvm::cl::cat(dialectGenCat), llvm::cl::Required);
42
43Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
44 MLIRContext *ctx = builder.getContext();
45
46 if (pred.isCombined()) {
47 auto combiner = pred.getDef().getValueAsDef(FieldName: "kind")->getName();
48 if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
49 std::vector<Value> constraints;
50 for (auto *child : pred.getDef().getValueAsListOfDefs(FieldName: "children")) {
51 constraints.push_back(x: createPredicate(builder, pred: tblgen::Pred(child)));
52 }
53 if (combiner == "PredCombinerAnd") {
54 auto op =
55 builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
56 return op.getOutput();
57 }
58 auto op =
59 builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
60 return op.getOutput();
61 }
62 }
63
64 std::string condition = pred.getCondition();
65 // Build a CPredOp to match the C constraint built.
66 irdl::CPredOp op = builder.create<irdl::CPredOp>(
67 UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
68 return op;
69}
70
71Value typeToConstraint(OpBuilder &builder, Type type) {
72 MLIRContext *ctx = builder.getContext();
73 auto op =
74 builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
75 return op.getOutput();
76}
77
78Value baseToConstraint(OpBuilder &builder, StringRef baseClass) {
79 MLIRContext *ctx = builder.getContext();
80 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
81 StringAttr::get(ctx, baseClass));
82 return op.getOutput();
83}
84
85std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
86 if (predRec.isSubClassOf(Name: "I")) {
87 auto width = predRec.getValueAsInt(FieldName: "bitwidth");
88 return IntegerType::get(ctx, width, IntegerType::Signless);
89 }
90
91 if (predRec.isSubClassOf(Name: "SI")) {
92 auto width = predRec.getValueAsInt(FieldName: "bitwidth");
93 return IntegerType::get(ctx, width, IntegerType::Signed);
94 }
95
96 if (predRec.isSubClassOf(Name: "UI")) {
97 auto width = predRec.getValueAsInt(FieldName: "bitwidth");
98 return IntegerType::get(ctx, width, IntegerType::Unsigned);
99 }
100
101 // Index type
102 if (predRec.getName() == "Index") {
103 return IndexType::get(ctx);
104 }
105
106 // Float types
107 if (predRec.isSubClassOf(Name: "F")) {
108 auto width = predRec.getValueAsInt(FieldName: "bitwidth");
109 switch (width) {
110 case 16:
111 return Float16Type::get(ctx);
112 case 32:
113 return Float32Type::get(ctx);
114 case 64:
115 return Float64Type::get(ctx);
116 case 80:
117 return Float80Type::get(ctx);
118 case 128:
119 return Float128Type::get(ctx);
120 }
121 }
122
123 if (predRec.getName() == "NoneType") {
124 return NoneType::get(ctx);
125 }
126
127 if (predRec.getName() == "BF16") {
128 return BFloat16Type::get(ctx);
129 }
130
131 if (predRec.getName() == "TF32") {
132 return FloatTF32Type::get(ctx);
133 }
134
135 if (predRec.getName() == "F8E4M3FN") {
136 return Float8E4M3FNType::get(ctx);
137 }
138
139 if (predRec.getName() == "F8E5M2") {
140 return Float8E5M2Type::get(ctx);
141 }
142
143 if (predRec.getName() == "F8E4M3") {
144 return Float8E4M3Type::get(ctx);
145 }
146
147 if (predRec.getName() == "F8E4M3FNUZ") {
148 return Float8E4M3FNUZType::get(ctx);
149 }
150
151 if (predRec.getName() == "F8E4M3B11FNUZ") {
152 return Float8E4M3B11FNUZType::get(ctx);
153 }
154
155 if (predRec.getName() == "F8E5M2FNUZ") {
156 return Float8E5M2FNUZType::get(ctx);
157 }
158
159 if (predRec.getName() == "F8E3M4") {
160 return Float8E3M4Type::get(ctx);
161 }
162
163 if (predRec.isSubClassOf(Name: "Complex")) {
164 const Record *elementRec = predRec.getValueAsDef(FieldName: "elementType");
165 auto elementType = recordToType(ctx, predRec: *elementRec);
166 if (elementType.has_value()) {
167 return ComplexType::get(elementType.value());
168 }
169 }
170
171 return std::nullopt;
172}
173
174Value createTypeConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
175 MLIRContext *ctx = builder.getContext();
176 const Record &predRec = constraint.getDef();
177
178 if (predRec.isSubClassOf(Name: "Variadic") || predRec.isSubClassOf(Name: "Optional"))
179 return createTypeConstraint(builder, constraint: predRec.getValueAsDef(FieldName: "baseType"));
180
181 if (predRec.getName() == "AnyType") {
182 auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
183 return op.getOutput();
184 }
185
186 if (predRec.isSubClassOf(Name: "TypeDef")) {
187 auto dialect = predRec.getValueAsDef(FieldName: "dialect")->getValueAsString(FieldName: "name");
188 if (dialect == selectedDialect) {
189 std::string combined = ("!" + predRec.getValueAsString(FieldName: "mnemonic")).str();
190 SmallVector<FlatSymbolRefAttr> nested = {
191 SymbolRefAttr::get(ctx, combined)};
192 auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
193 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
194 return op.getOutput();
195 }
196 std::string typeName = ("!" + predRec.getValueAsString(FieldName: "typeName")).str();
197 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
198 StringAttr::get(ctx, typeName));
199 return op.getOutput();
200 }
201
202 if (predRec.isSubClassOf(Name: "AnyTypeOf")) {
203 std::vector<Value> constraints;
204 for (const Record *child : predRec.getValueAsListOfDefs(FieldName: "allowedTypes")) {
205 constraints.push_back(
206 x: createTypeConstraint(builder, constraint: tblgen::Constraint(child)));
207 }
208 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
209 return op.getOutput();
210 }
211
212 if (predRec.isSubClassOf(Name: "AllOfType")) {
213 std::vector<Value> constraints;
214 for (const Record *child : predRec.getValueAsListOfDefs(FieldName: "allowedTypes")) {
215 constraints.push_back(
216 x: createTypeConstraint(builder, constraint: tblgen::Constraint(child)));
217 }
218 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
219 return op.getOutput();
220 }
221
222 // Integer types
223 if (predRec.getName() == "AnyInteger") {
224 auto op = builder.create<irdl::BaseOp>(
225 UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
226 return op.getOutput();
227 }
228
229 if (predRec.isSubClassOf(Name: "AnyI")) {
230 auto width = predRec.getValueAsInt(FieldName: "bitwidth");
231 std::vector<Value> types = {
232 typeToConstraint(builder,
233 IntegerType::get(ctx, width, IntegerType::Signless)),
234 typeToConstraint(builder,
235 IntegerType::get(ctx, width, IntegerType::Signed)),
236 typeToConstraint(builder,
237 IntegerType::get(ctx, width, IntegerType::Unsigned))};
238 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
239 return op.getOutput();
240 }
241
242 auto type = recordToType(ctx, predRec);
243
244 if (type.has_value()) {
245 return typeToConstraint(builder, type: type.value());
246 }
247
248 // Confined type
249 if (predRec.isSubClassOf(Name: "ConfinedType")) {
250 std::vector<Value> constraints;
251 constraints.push_back(x: createTypeConstraint(
252 builder, constraint: tblgen::Constraint(predRec.getValueAsDef(FieldName: "baseType"))));
253 for (const Record *child : predRec.getValueAsListOfDefs(FieldName: "predicateList")) {
254 constraints.push_back(x: createPredicate(builder, pred: tblgen::Pred(child)));
255 }
256 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
257 return op.getOutput();
258 }
259
260 return createPredicate(builder, pred: constraint.getPredicate());
261}
262
263Value createAttrConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
264 MLIRContext *ctx = builder.getContext();
265 const Record &predRec = constraint.getDef();
266
267 if (predRec.isSubClassOf(Name: "DefaultValuedAttr") ||
268 predRec.isSubClassOf(Name: "DefaultValuedOptionalAttr") ||
269 predRec.isSubClassOf(Name: "OptionalAttr")) {
270 return createAttrConstraint(builder, constraint: predRec.getValueAsDef(FieldName: "baseAttr"));
271 }
272
273 if (predRec.isSubClassOf(Name: "ConfinedAttr")) {
274 std::vector<Value> constraints;
275 constraints.push_back(x: createAttrConstraint(
276 builder, constraint: tblgen::Constraint(predRec.getValueAsDef(FieldName: "baseAttr"))));
277 for (const Record *child :
278 predRec.getValueAsListOfDefs(FieldName: "attrConstraints")) {
279 constraints.push_back(x: createPredicate(
280 builder, pred: tblgen::Pred(child->getValueAsDef(FieldName: "predicate"))));
281 }
282 auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
283 return op.getOutput();
284 }
285
286 if (predRec.isSubClassOf(Name: "AnyAttrOf")) {
287 std::vector<Value> constraints;
288 for (const Record *child :
289 predRec.getValueAsListOfDefs(FieldName: "allowedAttributes")) {
290 constraints.push_back(
291 x: createAttrConstraint(builder, constraint: tblgen::Constraint(child)));
292 }
293 auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
294 return op.getOutput();
295 }
296
297 if (predRec.getName() == "AnyAttr") {
298 auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
299 return op.getOutput();
300 }
301
302 if (predRec.isSubClassOf(Name: "AnyIntegerAttrBase") ||
303 predRec.isSubClassOf(Name: "SignlessIntegerAttrBase") ||
304 predRec.isSubClassOf(Name: "SignedIntegerAttrBase") ||
305 predRec.isSubClassOf(Name: "UnsignedIntegerAttrBase") ||
306 predRec.isSubClassOf(Name: "BoolAttr")) {
307 return baseToConstraint(builder, baseClass: "!builtin.integer");
308 }
309
310 if (predRec.isSubClassOf(Name: "FloatAttrBase")) {
311 return baseToConstraint(builder, baseClass: "!builtin.float");
312 }
313
314 if (predRec.isSubClassOf(Name: "StringBasedAttr")) {
315 return baseToConstraint(builder, baseClass: "!builtin.string");
316 }
317
318 if (predRec.getName() == "UnitAttr") {
319 auto op =
320 builder.create<irdl::IsOp>(UnknownLoc::get(ctx), UnitAttr::get(ctx));
321 return op.getOutput();
322 }
323
324 if (predRec.isSubClassOf(Name: "AttrDef")) {
325 auto dialect = predRec.getValueAsDef(FieldName: "dialect")->getValueAsString(FieldName: "name");
326 if (dialect == selectedDialect) {
327 std::string combined = ("#" + predRec.getValueAsString(FieldName: "mnemonic")).str();
328 SmallVector<FlatSymbolRefAttr> nested = {SymbolRefAttr::get(ctx, combined)
329
330 };
331 auto typeSymbol = SymbolRefAttr::get(ctx, dialect, nested);
332 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), typeSymbol);
333 return op.getOutput();
334 }
335 std::string typeName = ("#" + predRec.getValueAsString(FieldName: "attrName")).str();
336 auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
337 StringAttr::get(ctx, typeName));
338 return op.getOutput();
339 }
340
341 return createPredicate(builder, pred: constraint.getPredicate());
342}
343
344Value createRegionConstraint(OpBuilder &builder, tblgen::Region constraint) {
345 MLIRContext *ctx = builder.getContext();
346 const Record &predRec = constraint.getDef();
347
348 if (predRec.getName() == "AnyRegion") {
349 ValueRange entryBlockArgs = {};
350 auto op =
351 builder.create<irdl::RegionOp>(UnknownLoc::get(ctx), entryBlockArgs);
352 return op.getResult();
353 }
354
355 if (predRec.isSubClassOf(Name: "SizedRegion")) {
356 ValueRange entryBlockArgs = {};
357 auto ty = IntegerType::get(ctx, 32);
358 auto op = builder.create<irdl::RegionOp>(
359 UnknownLoc::get(ctx), entryBlockArgs,
360 IntegerAttr::get(ty, predRec.getValueAsInt("blocks")));
361 return op.getResult();
362 }
363
364 return createPredicate(builder, pred: constraint.getPredicate());
365}
366
367/// Returns the name of the operation without the dialect prefix.
368static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
369 StringRef opName = tblgenOp.getDef().getValueAsString(FieldName: "opName");
370 return opName;
371}
372
373/// Returns the name of the type without the dialect prefix.
374static StringRef getTypeName(tblgen::TypeDef &tblgenType) {
375 StringRef opName = tblgenType.getDef()->getValueAsString(FieldName: "mnemonic");
376 return opName;
377}
378
379/// Returns the name of the attr without the dialect prefix.
380static StringRef getAttrName(tblgen::AttrDef &tblgenType) {
381 StringRef opName = tblgenType.getDef()->getValueAsString(FieldName: "mnemonic");
382 return opName;
383}
384
385/// Extract an operation to IRDL.
386irdl::OperationOp createIRDLOperation(OpBuilder &builder,
387 tblgen::Operator &tblgenOp) {
388 MLIRContext *ctx = builder.getContext();
389 StringRef opName = getOperatorName(tblgenOp);
390
391 irdl::OperationOp op = builder.create<irdl::OperationOp>(
392 UnknownLoc::get(ctx), StringAttr::get(ctx, opName));
393
394 // Add the block in the region.
395 Block &opBlock = op.getBody().emplaceBlock();
396 OpBuilder consBuilder = OpBuilder::atBlockBegin(block: &opBlock);
397
398 SmallDenseSet<StringRef> usedNames;
399 for (auto &namedCons : tblgenOp.getOperands())
400 usedNames.insert(V: namedCons.name);
401 for (auto &namedCons : tblgenOp.getResults())
402 usedNames.insert(V: namedCons.name);
403 for (auto &namedReg : tblgenOp.getRegions())
404 usedNames.insert(V: namedReg.name);
405
406 size_t generateCounter = 0;
407 auto generateName = [&](StringRef prefix) -> StringAttr {
408 SmallString<16> candidate;
409 do {
410 candidate.clear();
411 raw_svector_ostream candidateStream(candidate);
412 candidateStream << prefix << generateCounter;
413 generateCounter++;
414 } while (usedNames.contains(V: candidate));
415 return StringAttr::get(ctx, candidate);
416 };
417 auto normalizeName = [&](StringRef name) -> StringAttr {
418 if (name == "")
419 return generateName("unnamed");
420 return StringAttr::get(ctx, name);
421 };
422
423 auto getValues = [&](tblgen::Operator::const_value_range namedCons) {
424 SmallVector<Value> operands;
425 SmallVector<Attribute> names;
426 SmallVector<irdl::VariadicityAttr> variadicity;
427
428 for (const NamedTypeConstraint &namedCons : namedCons) {
429 auto operand = createTypeConstraint(builder&: consBuilder, constraint: namedCons.constraint);
430 operands.push_back(Elt: operand);
431
432 names.push_back(Elt: normalizeName(namedCons.name));
433
434 irdl::VariadicityAttr var;
435 if (namedCons.isOptional())
436 var = consBuilder.getAttr<irdl::VariadicityAttr>(
437 irdl::Variadicity::optional);
438 else if (namedCons.isVariadic())
439 var = consBuilder.getAttr<irdl::VariadicityAttr>(
440 irdl::Variadicity::variadic);
441 else
442 var = consBuilder.getAttr<irdl::VariadicityAttr>(
443 irdl::Variadicity::single);
444
445 variadicity.push_back(var);
446 }
447 return std::make_tuple(operands, names, variadicity);
448 };
449
450 auto [operands, operandNames, operandVariadicity] =
451 getValues(tblgenOp.getOperands());
452 auto [results, resultNames, resultVariadicity] =
453 getValues(tblgenOp.getResults());
454
455 SmallVector<Value> attributes;
456 SmallVector<Attribute> attrNames;
457 for (auto namedAttr : tblgenOp.getAttributes()) {
458 if (namedAttr.attr.isOptional())
459 continue;
460 attributes.push_back(Elt: createAttrConstraint(builder&: consBuilder, constraint: namedAttr.attr));
461 attrNames.push_back(StringAttr::get(ctx, namedAttr.name));
462 }
463
464 SmallVector<Value> regions;
465 SmallVector<Attribute> regionNames;
466 for (auto namedRegion : tblgenOp.getRegions()) {
467 regions.push_back(
468 Elt: createRegionConstraint(builder&: consBuilder, constraint: namedRegion.constraint));
469 regionNames.push_back(Elt: normalizeName(namedRegion.name));
470 }
471
472 // Create the operands and results operations.
473 if (!operands.empty())
474 consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
475 ArrayAttr::get(ctx, operandNames),
476 operandVariadicity);
477 if (!results.empty())
478 consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
479 ArrayAttr::get(ctx, resultNames),
480 resultVariadicity);
481 if (!attributes.empty())
482 consBuilder.create<irdl::AttributesOp>(UnknownLoc::get(ctx), attributes,
483 ArrayAttr::get(ctx, attrNames));
484 if (!regions.empty())
485 consBuilder.create<irdl::RegionsOp>(UnknownLoc::get(ctx), regions,
486 ArrayAttr::get(ctx, regionNames));
487
488 return op;
489}
490
491irdl::TypeOp createIRDLType(OpBuilder &builder, tblgen::TypeDef &tblgenType) {
492 MLIRContext *ctx = builder.getContext();
493 StringRef typeName = getTypeName(tblgenType);
494 std::string combined = ("!" + typeName).str();
495
496 irdl::TypeOp op = builder.create<irdl::TypeOp>(
497 UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
498
499 op.getBody().emplaceBlock();
500
501 return op;
502}
503
504irdl::AttributeOp createIRDLAttr(OpBuilder &builder,
505 tblgen::AttrDef &tblgenAttr) {
506 MLIRContext *ctx = builder.getContext();
507 StringRef attrName = getAttrName(tblgenType&: tblgenAttr);
508 std::string combined = ("#" + attrName).str();
509
510 irdl::AttributeOp op = builder.create<irdl::AttributeOp>(
511 UnknownLoc::get(ctx), StringAttr::get(ctx, combined));
512
513 op.getBody().emplaceBlock();
514
515 return op;
516}
517
518static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
519 MLIRContext *ctx = builder.getContext();
520 return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
521 StringAttr::get(ctx, selectedDialect));
522}
523
524static bool emitDialectIRDLDefs(const RecordKeeper &records, raw_ostream &os) {
525 // Initialize.
526 MLIRContext ctx;
527 ctx.getOrLoadDialect<irdl::IRDLDialect>();
528 OpBuilder builder(&ctx);
529
530 // Create a module op and set it as the insertion point.
531 OwningOpRef<ModuleOp> module =
532 builder.create<ModuleOp>(UnknownLoc::get(&ctx));
533 builder = builder.atBlockBegin(block: module->getBody());
534 // Create the dialect and insert it.
535 irdl::DialectOp dialect = createIRDLDialect(builder);
536 // Set insertion point to start of DialectOp.
537 builder = builder.atBlockBegin(block: &dialect.getBody().emplaceBlock());
538
539 for (const Record *type :
540 records.getAllDerivedDefinitionsIfDefined(ClassName: "TypeDef")) {
541 tblgen::TypeDef tblgenType(type);
542 if (tblgenType.getDialect().getName() != selectedDialect)
543 continue;
544 createIRDLType(builder, tblgenType);
545 }
546
547 for (const Record *attr :
548 records.getAllDerivedDefinitionsIfDefined(ClassName: "AttrDef")) {
549 tblgen::AttrDef tblgenAttr(attr);
550 if (tblgenAttr.getDialect().getName() != selectedDialect)
551 continue;
552 createIRDLAttr(builder, tblgenAttr);
553 }
554
555 for (const Record *def : records.getAllDerivedDefinitionsIfDefined(ClassName: "Op")) {
556 tblgen::Operator tblgenOp(def);
557 if (tblgenOp.getDialectName() != selectedDialect)
558 continue;
559
560 createIRDLOperation(builder, tblgenOp);
561 }
562
563 // Print the module.
564 module->print(os);
565
566 return false;
567}
568
569static mlir::GenRegistration
570 genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
571 [](const RecordKeeper &records, raw_ostream &os) {
572 return emitDialectIRDLDefs(records, os);
573 });
574

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp