1 | //===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // OpDefinitionsGen uses the description of operations to generate IRDL |
10 | // definitions for ops. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
15 | #include "mlir/IR/Attributes.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/BuiltinOps.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/Dialect.h" |
20 | #include "mlir/IR/MLIRContext.h" |
21 | #include "mlir/TableGen/AttrOrTypeDef.h" |
22 | #include "mlir/TableGen/GenInfo.h" |
23 | #include "mlir/TableGen/GenNameParser.h" |
24 | #include "mlir/TableGen/Interfaces.h" |
25 | #include "mlir/TableGen/Operator.h" |
26 | #include "llvm/Support/CommandLine.h" |
27 | #include "llvm/Support/InitLLVM.h" |
28 | #include "llvm/Support/raw_ostream.h" |
29 | #include "llvm/TableGen/Main.h" |
30 | #include "llvm/TableGen/Record.h" |
31 | #include "llvm/TableGen/TableGenBackend.h" |
32 | |
33 | using namespace llvm; |
34 | using namespace mlir; |
35 | using tblgen::NamedTypeConstraint; |
36 | |
37 | static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect" ); |
38 | llvm::cl::opt<std::string> |
39 | selectedDialect("dialect" , llvm::cl::desc("The dialect to gen for" ), |
40 | llvm::cl::cat(dialectGenCat), llvm::cl::Required); |
41 | |
42 | Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) { |
43 | MLIRContext *ctx = builder.getContext(); |
44 | const Record &predRec = constraint.getDef(); |
45 | |
46 | if (predRec.isSubClassOf(Name: "Variadic" ) || predRec.isSubClassOf(Name: "Optional" )) |
47 | return createConstraint(builder, constraint: predRec.getValueAsDef(FieldName: "baseType" )); |
48 | |
49 | if (predRec.getName() == "AnyType" ) { |
50 | auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx)); |
51 | return op.getOutput(); |
52 | } |
53 | |
54 | if (predRec.isSubClassOf(Name: "TypeDef" )) { |
55 | std::string typeName = ("!" + predRec.getValueAsString(FieldName: "typeName" )).str(); |
56 | auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx), |
57 | StringAttr::get(ctx, typeName)); |
58 | return op.getOutput(); |
59 | } |
60 | |
61 | if (predRec.isSubClassOf(Name: "AnyTypeOf" )) { |
62 | std::vector<Value> constraints; |
63 | for (Record *child : predRec.getValueAsListOfDefs(FieldName: "allowedTypes" )) { |
64 | constraints.push_back( |
65 | x: createConstraint(builder, constraint: tblgen::Constraint(child))); |
66 | } |
67 | auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints); |
68 | return op.getOutput(); |
69 | } |
70 | |
71 | if (predRec.isSubClassOf(Name: "AllOfType" )) { |
72 | std::vector<Value> constraints; |
73 | for (Record *child : predRec.getValueAsListOfDefs(FieldName: "allowedTypes" )) { |
74 | constraints.push_back( |
75 | x: createConstraint(builder, constraint: tblgen::Constraint(child))); |
76 | } |
77 | auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints); |
78 | return op.getOutput(); |
79 | } |
80 | |
81 | std::string condition = constraint.getPredicate().getCondition(); |
82 | // Build a CPredOp to match the C constraint built. |
83 | irdl::CPredOp op = builder.create<irdl::CPredOp>( |
84 | UnknownLoc::get(ctx), StringAttr::get(ctx, condition)); |
85 | return op; |
86 | } |
87 | |
88 | /// Returns the name of the operation without the dialect prefix. |
89 | static StringRef getOperatorName(tblgen::Operator &tblgenOp) { |
90 | StringRef opName = tblgenOp.getDef().getValueAsString(FieldName: "opName" ); |
91 | return opName; |
92 | } |
93 | |
94 | /// Extract an operation to IRDL. |
95 | irdl::OperationOp createIRDLOperation(OpBuilder &builder, |
96 | tblgen::Operator &tblgenOp) { |
97 | MLIRContext *ctx = builder.getContext(); |
98 | StringRef opName = getOperatorName(tblgenOp); |
99 | |
100 | irdl::OperationOp op = builder.create<irdl::OperationOp>( |
101 | UnknownLoc::get(ctx), StringAttr::get(ctx, opName)); |
102 | |
103 | // Add the block in the region. |
104 | Block &opBlock = op.getBody().emplaceBlock(); |
105 | OpBuilder consBuilder = OpBuilder::atBlockBegin(block: &opBlock); |
106 | |
107 | auto getValues = [&](tblgen::Operator::const_value_range namedCons) { |
108 | SmallVector<Value> operands; |
109 | SmallVector<irdl::VariadicityAttr> variadicity; |
110 | for (const NamedTypeConstraint &namedCons : namedCons) { |
111 | auto operand = createConstraint(builder&: consBuilder, constraint: namedCons.constraint); |
112 | operands.push_back(Elt: operand); |
113 | |
114 | irdl::VariadicityAttr var; |
115 | if (namedCons.isOptional()) |
116 | var = consBuilder.getAttr<irdl::VariadicityAttr>( |
117 | irdl::Variadicity::optional); |
118 | else if (namedCons.isVariadic()) |
119 | var = consBuilder.getAttr<irdl::VariadicityAttr>( |
120 | irdl::Variadicity::variadic); |
121 | else |
122 | var = consBuilder.getAttr<irdl::VariadicityAttr>( |
123 | irdl::Variadicity::single); |
124 | |
125 | variadicity.push_back(var); |
126 | } |
127 | return std::make_tuple(operands, variadicity); |
128 | }; |
129 | |
130 | auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands()); |
131 | auto [results, resultVariadicity] = getValues(tblgenOp.getResults()); |
132 | |
133 | // Create the operands and results operations. |
134 | consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands, |
135 | operandVariadicity); |
136 | consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results, |
137 | resultVariadicity); |
138 | |
139 | return op; |
140 | } |
141 | |
142 | static irdl::DialectOp createIRDLDialect(OpBuilder &builder) { |
143 | MLIRContext *ctx = builder.getContext(); |
144 | return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx), |
145 | StringAttr::get(ctx, selectedDialect)); |
146 | } |
147 | |
148 | static std::vector<llvm::Record *> |
149 | getOpDefinitions(const RecordKeeper &recordKeeper) { |
150 | if (!recordKeeper.getClass(Name: "Op" )) |
151 | return {}; |
152 | return recordKeeper.getAllDerivedDefinitions(ClassName: "Op" ); |
153 | } |
154 | |
155 | static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper, |
156 | raw_ostream &os) { |
157 | // Initialize. |
158 | MLIRContext ctx; |
159 | ctx.getOrLoadDialect<irdl::IRDLDialect>(); |
160 | OpBuilder builder(&ctx); |
161 | |
162 | // Create a module op and set it as the insertion point. |
163 | OwningOpRef<ModuleOp> module = |
164 | builder.create<ModuleOp>(UnknownLoc::get(&ctx)); |
165 | builder = builder.atBlockBegin(block: module->getBody()); |
166 | // Create the dialect and insert it. |
167 | irdl::DialectOp dialect = createIRDLDialect(builder); |
168 | // Set insertion point to start of DialectOp. |
169 | builder = builder.atBlockBegin(block: &dialect.getBody().emplaceBlock()); |
170 | |
171 | std::vector<Record *> defs = getOpDefinitions(recordKeeper); |
172 | for (auto *def : defs) { |
173 | tblgen::Operator tblgenOp(def); |
174 | if (tblgenOp.getDialectName() != selectedDialect) |
175 | continue; |
176 | |
177 | createIRDLOperation(builder, tblgenOp); |
178 | } |
179 | |
180 | // Print the module. |
181 | module->print(os); |
182 | |
183 | return false; |
184 | } |
185 | |
186 | static mlir::GenRegistration |
187 | genOpDefs("gen-dialect-irdl-defs" , "Generate IRDL dialect definitions" , |
188 | [](const RecordKeeper &records, raw_ostream &os) { |
189 | return emitDialectIRDLDefs(recordKeeper: records, os); |
190 | }); |
191 | |