1 | //===- MLIRGen.cpp --------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" |
10 | #include "mlir/AsmParser/AsmParser.h" |
11 | #include "mlir/Dialect/PDL/IR/PDL.h" |
12 | #include "mlir/Dialect/PDL/IR/PDLOps.h" |
13 | #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
14 | #include "mlir/IR/Builders.h" |
15 | #include "mlir/IR/BuiltinOps.h" |
16 | #include "mlir/IR/Verifier.h" |
17 | #include "mlir/Tools/PDLL/AST/Context.h" |
18 | #include "mlir/Tools/PDLL/AST/Nodes.h" |
19 | #include "mlir/Tools/PDLL/AST/Types.h" |
20 | #include "mlir/Tools/PDLL/ODS/Context.h" |
21 | #include "mlir/Tools/PDLL/ODS/Operation.h" |
22 | #include "llvm/ADT/ScopedHashTable.h" |
23 | #include "llvm/ADT/StringExtras.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::pdll; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // CodeGen |
32 | //===----------------------------------------------------------------------===// |
33 | |
namespace {
/// Lowers a parsed PDLL AST module into an MLIR module made of PDL dialect
/// operations. All IR is built through a single OpBuilder; PDLL variables are
/// tracked in a scoped map so that each variable is lowered once per scope.
class CodeGen {
public:
  /// Construct a generator building IR in `mlirContext`. `context` supplies the
  /// ODS context, and `sourceMgr` is used to turn AST source locations into
  /// MLIR locations. The PDL dialect is loaded into `mlirContext` eagerly.
  CodeGen(MLIRContext *mlirContext, const ast::Context &context,
          const llvm::SourceMgr &sourceMgr)
      : builder(mlirContext), odsContext(context.getODSContext()),
        sourceMgr(sourceMgr) {
    // Make sure that the PDL dialect is loaded.
    mlirContext->loadDialect<pdl::PDLDialect>();
  }

  /// Create a new MLIR module and emit PDL for every top-level decl of
  /// `module` into its body.
  OwningOpRef<ModuleOp> generate(const ast::Module &module);

private:
  /// Generate an MLIR location from the given source location.
  Location genLoc(llvm::SMLoc loc);
  /// Generate an MLIR location from the start of the given source range.
  Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }

  /// Generate an MLIR type from the given source type.
  Type genType(ast::Type type);

  /// Generate MLIR for the given AST node by dispatching on its dynamic kind
  /// (statement, decl, or expression).
  void gen(const ast::Node *node);

  //===--------------------------------------------------------------------===//
  // Statements
  //===--------------------------------------------------------------------===//

  void genImpl(const ast::CompoundStmt *stmt);
  void genImpl(const ast::EraseStmt *stmt);
  void genImpl(const ast::LetStmt *stmt);
  void genImpl(const ast::ReplaceStmt *stmt);
  void genImpl(const ast::RewriteStmt *stmt);
  void genImpl(const ast::ReturnStmt *stmt);

  //===--------------------------------------------------------------------===//
  // Decls
  //===--------------------------------------------------------------------===//

  void genImpl(const ast::UserConstraintDecl *decl);
  void genImpl(const ast::UserRewriteDecl *decl);
  void genImpl(const ast::PatternDecl *decl);

  /// Generate the set of MLIR values defined for the given variable decl, and
  /// apply any attached constraints. Results are cached in `variables`.
  SmallVector<Value> genVar(const ast::VariableDecl *varDecl);

  /// Generate the value for a variable that does not have an initializer
  /// expression, i.e. create the PDL value based on the type/constraints of the
  /// variable.
  Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);

  /// Apply the constraints of the given variable to `values`, which correspond
  /// to the MLIR values of the variable.
  void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);

  //===--------------------------------------------------------------------===//
  // Expressions
  //===--------------------------------------------------------------------===//

  /// Generate exactly one MLIR value for `expr`. Expressions that can yield
  /// several values (calls, decl references, tuples) must yield exactly one.
  Value genSingleExpr(const ast::Expr *expr);
  /// Generate all of the MLIR values produced by `expr`.
  SmallVector<Value> genExpr(const ast::Expr *expr);
  Value genExprImpl(const ast::AttributeExpr *expr);
  SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
  SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
  Value genExprImpl(const ast::MemberAccessExpr *expr);
  Value genExprImpl(const ast::OperationExpr *expr);
  Value genExprImpl(const ast::RangeExpr *expr);
  SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
  Value genExprImpl(const ast::TypeExpr *expr);

  /// Emit a call to a user constraint. Constraints declared on the
  /// constraint's inputs and results are applied to `inputs` and to the
  /// produced values, respectively.
  SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
                                       Location loc, ValueRange inputs,
                                       bool isNegated = false);
  /// Emit a call to a user rewrite with the given `inputs`.
  SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
                                    Location loc, ValueRange inputs);
  /// Shared lowering for constraint/rewrite calls. A native decl (no body)
  /// becomes a single `PDLOpT` op. A PDLL decl is inlined by generating its
  /// body with its parameters bound to `inputs`.
  template <typename PDLOpT, typename T>
  SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
                                                ValueRange inputs,
                                                bool isNegated = false);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The MLIR builder used for building the resultant IR.
  OpBuilder builder;

  /// A map from variable declarations to the MLIR equivalent. New scopes are
  /// opened for compound statements and for inlined PDLL calls.
  using VariableMapTy =
      llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
  VariableMapTy variables;

  /// A reference to the ODS context.
  /// NOTE(review): not referenced by the code generation visible in this
  /// file; confirm whether it is still needed.
  const ods::Context &odsContext;

  /// The source manager of the PDLL ast.
  const llvm::SourceMgr &sourceMgr;
};
} // namespace
134 | |
135 | OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) { |
136 | OwningOpRef<ModuleOp> mlirModule = |
137 | builder.create<ModuleOp>(genLoc(loc: module.getLoc())); |
138 | builder.setInsertionPointToStart(mlirModule->getBody()); |
139 | |
140 | // Generate code for each of the decls within the module. |
141 | for (const ast::Decl *decl : module.getChildren()) |
142 | gen(node: decl); |
143 | |
144 | return mlirModule; |
145 | } |
146 | |
147 | Location CodeGen::genLoc(llvm::SMLoc loc) { |
148 | unsigned fileID = sourceMgr.FindBufferContainingLoc(Loc: loc); |
149 | |
150 | // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can |
151 | // use it here. |
152 | auto &bufferInfo = sourceMgr.getBufferInfo(i: fileID); |
153 | unsigned lineNo = bufferInfo.getLineNumber(Ptr: loc.getPointer()); |
154 | unsigned column = |
155 | (loc.getPointer() - bufferInfo.getPointerForLineNumber(LineNo: lineNo)) + 1; |
156 | auto *buffer = sourceMgr.getMemoryBuffer(i: fileID); |
157 | |
158 | return FileLineColLoc::get(context: builder.getContext(), |
159 | fileName: buffer->getBufferIdentifier(), line: lineNo, column); |
160 | } |
161 | |
162 | Type CodeGen::genType(ast::Type type) { |
163 | return TypeSwitch<ast::Type, Type>(type) |
164 | .Case(caseFn: [&](ast::AttributeType astType) -> Type { |
165 | return builder.getType<pdl::AttributeType>(); |
166 | }) |
167 | .Case(caseFn: [&](ast::OperationType astType) -> Type { |
168 | return builder.getType<pdl::OperationType>(); |
169 | }) |
170 | .Case(caseFn: [&](ast::TypeType astType) -> Type { |
171 | return builder.getType<pdl::TypeType>(); |
172 | }) |
173 | .Case(caseFn: [&](ast::ValueType astType) -> Type { |
174 | return builder.getType<pdl::ValueType>(); |
175 | }) |
176 | .Case(caseFn: [&](ast::RangeType astType) -> Type { |
177 | return pdl::RangeType::get(genType(astType.getElementType())); |
178 | }); |
179 | } |
180 | |
181 | void CodeGen::gen(const ast::Node *node) { |
182 | TypeSwitch<const ast::Node *>(node) |
183 | .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt, |
184 | const ast::ReplaceStmt, const ast::RewriteStmt, |
185 | const ast::ReturnStmt, const ast::UserConstraintDecl, |
186 | const ast::UserRewriteDecl, const ast::PatternDecl>( |
187 | caseFn: [&](auto derivedNode) { this->genImpl(derivedNode); }) |
188 | .Case(caseFn: [&](const ast::Expr *expr) { genExpr(expr); }); |
189 | } |
190 | |
191 | //===----------------------------------------------------------------------===// |
192 | // CodeGen: Statements |
193 | //===----------------------------------------------------------------------===// |
194 | |
195 | void CodeGen::genImpl(const ast::CompoundStmt *stmt) { |
196 | VariableMapTy::ScopeTy varScope(variables); |
197 | for (const ast::Stmt *childStmt : stmt->getChildren()) |
198 | gen(node: childStmt); |
199 | } |
200 | |
201 | /// If the given builder is nested under a PDL PatternOp, build a rewrite |
202 | /// operation and update the builder to nest under it. This is necessary for |
203 | /// PDLL operation rewrite statements that are directly nested within a Pattern. |
204 | static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, |
205 | Location loc) { |
206 | if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) { |
207 | pdl::RewriteOp rewrite = |
208 | builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(), |
209 | /*externalArgs=*/ValueRange()); |
210 | builder.createBlock(&rewrite.getBodyRegion()); |
211 | } |
212 | } |
213 | |
214 | void CodeGen::genImpl(const ast::EraseStmt *stmt) { |
215 | OpBuilder::InsertionGuard insertGuard(builder); |
216 | Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr()); |
217 | Location loc = genLoc(loc: stmt->getLoc()); |
218 | |
219 | // Make sure we are nested in a RewriteOp. |
220 | OpBuilder::InsertionGuard guard(builder); |
221 | checkAndNestUnderRewriteOp(builder, rootExpr, loc); |
222 | builder.create<pdl::EraseOp>(loc, rootExpr); |
223 | } |
224 | |
225 | void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(varDecl: stmt->getVarDecl()); } |
226 | |
227 | void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { |
228 | OpBuilder::InsertionGuard insertGuard(builder); |
229 | Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr()); |
230 | Location loc = genLoc(loc: stmt->getLoc()); |
231 | |
232 | // Make sure we are nested in a RewriteOp. |
233 | OpBuilder::InsertionGuard guard(builder); |
234 | checkAndNestUnderRewriteOp(builder, rootExpr, loc); |
235 | |
236 | SmallVector<Value> replValues; |
237 | for (ast::Expr *replExpr : stmt->getReplExprs()) |
238 | replValues.push_back(Elt: genSingleExpr(expr: replExpr)); |
239 | |
240 | // Check to see if the statement has a replacement operation, or a range of |
241 | // replacement values. |
242 | bool usesReplOperation = |
243 | replValues.size() == 1 && |
244 | isa<pdl::OperationType>(replValues.front().getType()); |
245 | builder.create<pdl::ReplaceOp>( |
246 | loc, rootExpr, usesReplOperation ? replValues[0] : Value(), |
247 | usesReplOperation ? ValueRange() : ValueRange(replValues)); |
248 | } |
249 | |
250 | void CodeGen::genImpl(const ast::RewriteStmt *stmt) { |
251 | OpBuilder::InsertionGuard insertGuard(builder); |
252 | Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr()); |
253 | |
254 | // Make sure we are nested in a RewriteOp. |
255 | OpBuilder::InsertionGuard guard(builder); |
256 | checkAndNestUnderRewriteOp(builder, rootExpr, loc: genLoc(loc: stmt->getLoc())); |
257 | gen(node: stmt->getRewriteBody()); |
258 | } |
259 | |
260 | void CodeGen::genImpl(const ast::ReturnStmt *stmt) { |
261 | // ReturnStmt generation is handled by the respective constraint or rewrite |
262 | // parent node. |
263 | } |
264 | |
265 | //===----------------------------------------------------------------------===// |
266 | // CodeGen: Decls |
267 | //===----------------------------------------------------------------------===// |
268 | |
269 | void CodeGen::genImpl(const ast::UserConstraintDecl *decl) { |
270 | // All PDLL constraints get inlined when called, and the main native |
271 | // constraint declarations doesn't require any MLIR to be generated, only uses |
272 | // of it do. |
273 | } |
274 | |
275 | void CodeGen::genImpl(const ast::UserRewriteDecl *decl) { |
276 | // All PDLL rewrites get inlined when called, and the main native |
277 | // rewrite declarations doesn't require any MLIR to be generated, only uses |
278 | // of it do. |
279 | } |
280 | |
281 | void CodeGen::genImpl(const ast::PatternDecl *decl) { |
282 | const ast::Name *name = decl->getName(); |
283 | |
284 | // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it |
285 | // here. |
286 | pdl::PatternOp pattern = builder.create<pdl::PatternOp>( |
287 | genLoc(decl->getLoc()), decl->getBenefit(), |
288 | name ? std::optional<StringRef>(name->getName()) |
289 | : std::optional<StringRef>()); |
290 | |
291 | OpBuilder::InsertionGuard savedInsertPoint(builder); |
292 | builder.setInsertionPointToStart(pattern.getBody()); |
293 | gen(node: decl->getBody()); |
294 | } |
295 | |
296 | SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) { |
297 | auto it = variables.begin(Key: varDecl); |
298 | if (it != variables.end()) |
299 | return *it; |
300 | |
301 | // If the variable has an initial value, use that as the base value. |
302 | // Otherwise, generate a value using the constraint list. |
303 | SmallVector<Value> values; |
304 | if (const ast::Expr *initExpr = varDecl->getInitExpr()) |
305 | values = genExpr(expr: initExpr); |
306 | else |
307 | values.push_back(Elt: genNonInitializerVar(varDecl, loc: genLoc(loc: varDecl->getLoc()))); |
308 | |
309 | // Apply the constraints of the values of the variable. |
310 | applyVarConstraints(varDecl, values); |
311 | |
312 | variables.insert(Key: varDecl, Val: values); |
313 | return values; |
314 | } |
315 | |
316 | Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, |
317 | Location loc) { |
318 | // A functor used to generate expressions nested |
319 | auto getTypeConstraint = [&]() -> Value { |
320 | for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) { |
321 | Value typeValue = |
322 | TypeSwitch<const ast::Node *, Value>(constraint.constraint) |
323 | .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, |
324 | ast::ValueRangeConstraintDecl>( |
325 | caseFn: [&, this](auto *cst) -> Value { |
326 | if (auto *typeConstraintExpr = cst->getTypeExpr()) |
327 | return this->genSingleExpr(expr: typeConstraintExpr); |
328 | return Value(); |
329 | }) |
330 | .Default(defaultResult: Value()); |
331 | if (typeValue) |
332 | return typeValue; |
333 | } |
334 | return Value(); |
335 | }; |
336 | |
337 | // Generate a value based on the type of the variable. |
338 | ast::Type type = varDecl->getType(); |
339 | Type mlirType = genType(type); |
340 | if (isa<ast::ValueType>(type)) |
341 | return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint()); |
342 | if (isa<ast::TypeType>(type)) |
343 | return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr()); |
344 | if (isa<ast::AttributeType>(type)) |
345 | return builder.create<pdl::AttributeOp>(loc, getTypeConstraint()); |
346 | if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: type)) { |
347 | Value operands = builder.create<pdl::OperandsOp>( |
348 | loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()), |
349 | /*type=*/Value()); |
350 | Value results = builder.create<pdl::TypesOp>( |
351 | loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()), |
352 | /*types=*/ArrayAttr()); |
353 | return builder.create<pdl::OperationOp>( |
354 | loc, opType.getName(), operands, std::nullopt, ValueRange(), results); |
355 | } |
356 | |
357 | if (ast::RangeType rangeTy = dyn_cast<ast::RangeType>(Val&: type)) { |
358 | ast::Type eleTy = rangeTy.getElementType(); |
359 | if (isa<ast::ValueType>(eleTy)) |
360 | return builder.create<pdl::OperandsOp>(loc, mlirType, |
361 | getTypeConstraint()); |
362 | if (isa<ast::TypeType>(eleTy)) |
363 | return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr()); |
364 | } |
365 | |
366 | llvm_unreachable("invalid non-initialized variable type" ); |
367 | } |
368 | |
369 | void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl, |
370 | ValueRange values) { |
371 | // Generate calls to any user constraints that were attached via the |
372 | // constraint list. |
373 | for (const ast::ConstraintRef &ref : varDecl->getConstraints()) |
374 | if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint)) |
375 | genConstraintCall(decl: userCst, loc: genLoc(loc: ref.referenceLoc), inputs: values); |
376 | } |
377 | |
378 | //===----------------------------------------------------------------------===// |
379 | // CodeGen: Expressions |
380 | //===----------------------------------------------------------------------===// |
381 | |
382 | Value CodeGen::genSingleExpr(const ast::Expr *expr) { |
383 | return TypeSwitch<const ast::Expr *, Value>(expr) |
384 | .Case<const ast::AttributeExpr, const ast::MemberAccessExpr, |
385 | const ast::OperationExpr, const ast::RangeExpr, |
386 | const ast::TypeExpr>( |
387 | caseFn: [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) |
388 | .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( |
389 | caseFn: [&](auto derivedNode) { |
390 | return llvm::getSingleElement(this->genExprImpl(derivedNode)); |
391 | }); |
392 | } |
393 | |
394 | SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) { |
395 | return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr) |
396 | .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( |
397 | caseFn: [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) |
398 | .Default(defaultFn: [&](const ast::Expr *expr) -> SmallVector<Value> { |
399 | return {genSingleExpr(expr)}; |
400 | }); |
401 | } |
402 | |
403 | Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) { |
404 | Attribute attr = parseAttribute(attrStr: expr->getValue(), context: builder.getContext()); |
405 | assert(attr && "invalid MLIR attribute data" ); |
406 | return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr); |
407 | } |
408 | |
409 | SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) { |
410 | Location loc = genLoc(loc: expr->getLoc()); |
411 | SmallVector<Value> arguments; |
412 | for (const ast::Expr *arg : expr->getArguments()) |
413 | arguments.push_back(Elt: genSingleExpr(expr: arg)); |
414 | |
415 | // Resolve the callable expression of this call. |
416 | auto *callableExpr = dyn_cast<ast::DeclRefExpr>(Val: expr->getCallableExpr()); |
417 | assert(callableExpr && "unhandled CallExpr callable" ); |
418 | |
419 | // Generate the PDL based on the type of callable. |
420 | const ast::Decl *callable = callableExpr->getDecl(); |
421 | if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(Val: callable)) |
422 | return genConstraintCall(decl, loc, inputs: arguments, isNegated: expr->getIsNegated()); |
423 | if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(Val: callable)) |
424 | return genRewriteCall(decl, loc, inputs: arguments); |
425 | llvm_unreachable("unhandled CallExpr callable" ); |
426 | } |
427 | |
428 | SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) { |
429 | if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: expr->getDecl())) |
430 | return genVar(varDecl); |
431 | llvm_unreachable("unknown decl reference expression" ); |
432 | } |
433 | |
434 | Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { |
435 | Location loc = genLoc(loc: expr->getLoc()); |
436 | StringRef name = expr->getMemberName(); |
437 | SmallVector<Value> parentExprs = genExpr(expr: expr->getParentExpr()); |
438 | ast::Type parentType = expr->getParentExpr()->getType(); |
439 | |
440 | // Handle operation based member access. |
441 | if (ast::OperationType opType = dyn_cast<ast::OperationType>(Val&: parentType)) { |
442 | if (isa<ast::AllResultsMemberAccessExpr>(Val: expr)) { |
443 | Type mlirType = genType(type: expr->getType()); |
444 | if (isa<pdl::ValueType>(mlirType)) |
445 | return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0], |
446 | builder.getI32IntegerAttr(0)); |
447 | return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]); |
448 | } |
449 | |
450 | const ods::Operation *odsOp = opType.getODSOperation(); |
451 | if (!odsOp) { |
452 | assert(llvm::isDigit(name[0]) && |
453 | "unregistered op only allows numeric indexing" ); |
454 | unsigned resultIndex; |
455 | name.getAsInteger(/*Radix=*/10, Result&: resultIndex); |
456 | IntegerAttr index = builder.getI32IntegerAttr(resultIndex); |
457 | return builder.create<pdl::ResultOp>(loc, genType(expr->getType()), |
458 | parentExprs[0], index); |
459 | } |
460 | |
461 | // Find the result with the member name or by index. |
462 | ArrayRef<ods::OperandOrResult> results = odsOp->getResults(); |
463 | unsigned resultIndex = results.size(); |
464 | if (llvm::isDigit(C: name[0])) { |
465 | name.getAsInteger(/*Radix=*/10, Result&: resultIndex); |
466 | } else { |
467 | auto findFn = [&](const ods::OperandOrResult &result) { |
468 | return result.getName() == name; |
469 | }; |
470 | resultIndex = llvm::find_if(Range&: results, P: findFn) - results.begin(); |
471 | } |
472 | assert(resultIndex < results.size() && "invalid result index" ); |
473 | |
474 | // Generate the result access. |
475 | IntegerAttr index = builder.getI32IntegerAttr(resultIndex); |
476 | return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()), |
477 | parentExprs[0], index); |
478 | } |
479 | |
480 | // Handle tuple based member access. |
481 | if (auto tupleType = dyn_cast<ast::TupleType>(Val&: parentType)) { |
482 | auto elementNames = tupleType.getElementNames(); |
483 | |
484 | // The index is either a numeric index, or a name. |
485 | unsigned index = 0; |
486 | if (llvm::isDigit(C: name[0])) |
487 | name.getAsInteger(/*Radix=*/10, Result&: index); |
488 | else |
489 | index = llvm::find(Range&: elementNames, Val: name) - elementNames.begin(); |
490 | |
491 | assert(index < parentExprs.size() && "invalid result index" ); |
492 | return parentExprs[index]; |
493 | } |
494 | |
495 | llvm_unreachable("unhandled member access expression" ); |
496 | } |
497 | |
498 | Value CodeGen::genExprImpl(const ast::OperationExpr *expr) { |
499 | Location loc = genLoc(loc: expr->getLoc()); |
500 | std::optional<StringRef> opName = expr->getName(); |
501 | |
502 | // Operands. |
503 | SmallVector<Value> operands; |
504 | for (const ast::Expr *operand : expr->getOperands()) |
505 | operands.push_back(Elt: genSingleExpr(expr: operand)); |
506 | |
507 | // Attributes. |
508 | SmallVector<StringRef> attrNames; |
509 | SmallVector<Value> attrValues; |
510 | for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) { |
511 | attrNames.push_back(Elt: attr->getName().getName()); |
512 | attrValues.push_back(Elt: genSingleExpr(expr: attr->getValue())); |
513 | } |
514 | |
515 | // Results. |
516 | SmallVector<Value> results; |
517 | for (const ast::Expr *result : expr->getResultTypes()) |
518 | results.push_back(Elt: genSingleExpr(expr: result)); |
519 | |
520 | return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames, |
521 | attrValues, results); |
522 | } |
523 | |
524 | Value CodeGen::genExprImpl(const ast::RangeExpr *expr) { |
525 | SmallVector<Value> elements; |
526 | for (const ast::Expr *element : expr->getElements()) |
527 | llvm::append_range(C&: elements, R: genExpr(expr: element)); |
528 | |
529 | return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()), |
530 | genType(expr->getType()), elements); |
531 | } |
532 | |
533 | SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) { |
534 | SmallVector<Value> elements; |
535 | for (const ast::Expr *element : expr->getElements()) |
536 | elements.push_back(Elt: genSingleExpr(expr: element)); |
537 | return elements; |
538 | } |
539 | |
540 | Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { |
541 | Type type = parseType(typeStr: expr->getValue(), context: builder.getContext()); |
542 | assert(type && "invalid MLIR type data" ); |
543 | return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()), |
544 | builder.getType<pdl::TypeType>(), |
545 | TypeAttr::get(type)); |
546 | } |
547 | |
548 | SmallVector<Value> |
549 | CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, |
550 | ValueRange inputs, bool isNegated) { |
551 | // Apply any constraints defined on the arguments to the input values. |
552 | for (auto it : llvm::zip(t: decl->getInputs(), u&: inputs)) |
553 | applyVarConstraints(varDecl: std::get<0>(t&: it), values: std::get<1>(t&: it)); |
554 | |
555 | // Generate the constraint call. |
556 | SmallVector<Value> results = |
557 | genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>( |
558 | decl, loc, inputs, isNegated); |
559 | |
560 | // Apply any constraints defined on the results of the constraint. |
561 | for (auto it : llvm::zip(t: decl->getResults(), u&: results)) |
562 | applyVarConstraints(varDecl: std::get<0>(t&: it), values: std::get<1>(t&: it)); |
563 | return results; |
564 | } |
565 | |
566 | SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, |
567 | Location loc, ValueRange inputs) { |
568 | return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc, |
569 | inputs); |
570 | } |
571 | |
572 | template <typename PDLOpT, typename T> |
573 | SmallVector<Value> |
574 | CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc, |
575 | ValueRange inputs, bool isNegated) { |
576 | const ast::CompoundStmt *cstBody = decl->getBody(); |
577 | |
578 | // If the decl doesn't have a statement body, it is a native decl. |
579 | if (!cstBody) { |
580 | ast::Type declResultType = decl->getResultType(); |
581 | SmallVector<Type> resultTypes; |
582 | if (ast::TupleType tupleType = dyn_cast<ast::TupleType>(Val&: declResultType)) { |
583 | for (ast::Type type : tupleType.getElementTypes()) |
584 | resultTypes.push_back(Elt: genType(type)); |
585 | } else { |
586 | resultTypes.push_back(Elt: genType(type: declResultType)); |
587 | } |
588 | PDLOpT pdlOp = builder.create<PDLOpT>(loc, resultTypes, |
589 | decl->getName().getName(), inputs); |
590 | if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>) |
591 | cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true); |
592 | return pdlOp->getResults(); |
593 | } |
594 | |
595 | // Otherwise, this is a PDLL decl. |
596 | VariableMapTy::ScopeTy varScope(variables); |
597 | |
598 | // Map the inputs of the call to the decl arguments. |
599 | // Note: This is only valid because we do not support recursion, meaning |
600 | // we don't need to worry about conflicting mappings here. |
601 | for (auto it : llvm::zip(inputs, decl->getInputs())) |
602 | variables.insert(Key: std::get<1>(it), Val: {std::get<0>(it)}); |
603 | |
604 | // Visit the body of the call as normal. |
605 | gen(node: cstBody); |
606 | |
607 | // If the decl has no results, there is nothing to do. |
608 | if (cstBody->getChildren().empty()) |
609 | return SmallVector<Value>(); |
610 | auto *returnStmt = dyn_cast<ast::ReturnStmt>(Val: cstBody->getChildren().back()); |
611 | if (!returnStmt) |
612 | return SmallVector<Value>(); |
613 | |
614 | // Otherwise, grab the results from the return statement. |
615 | return genExpr(expr: returnStmt->getResultExpr()); |
616 | } |
617 | |
618 | //===----------------------------------------------------------------------===// |
619 | // MLIRGen |
620 | //===----------------------------------------------------------------------===// |
621 | |
622 | OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR( |
623 | MLIRContext *mlirContext, const ast::Context &context, |
624 | const llvm::SourceMgr &sourceMgr, const ast::Module &module) { |
625 | CodeGen codegen(mlirContext, context, sourceMgr); |
626 | OwningOpRef<ModuleOp> mlirModule = codegen.generate(module); |
627 | if (failed(verify(*mlirModule))) |
628 | return nullptr; |
629 | return mlirModule; |
630 | } |
631 | |