1 | //===- MLIRGen.cpp --------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" |
10 | #include "mlir/AsmParser/AsmParser.h" |
11 | #include "mlir/Dialect/PDL/IR/PDL.h" |
12 | #include "mlir/Dialect/PDL/IR/PDLOps.h" |
13 | #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
14 | #include "mlir/IR/Builders.h" |
15 | #include "mlir/IR/BuiltinOps.h" |
16 | #include "mlir/IR/Verifier.h" |
17 | #include "mlir/Tools/PDLL/AST/Context.h" |
18 | #include "mlir/Tools/PDLL/AST/Nodes.h" |
19 | #include "mlir/Tools/PDLL/AST/Types.h" |
20 | #include "mlir/Tools/PDLL/ODS/Context.h" |
21 | #include "mlir/Tools/PDLL/ODS/Operation.h" |
22 | #include "llvm/ADT/ScopedHashTable.h" |
23 | #include "llvm/ADT/StringExtras.h" |
24 | #include "llvm/ADT/TypeSwitch.h" |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::pdll; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // CodeGen |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | namespace { |
35 | class CodeGen { |
36 | public: |
37 | CodeGen(MLIRContext *mlirContext, const ast::Context &context, |
38 | const llvm::SourceMgr &sourceMgr) |
39 | : builder(mlirContext), odsContext(context.getODSContext()), |
40 | sourceMgr(sourceMgr) { |
41 | // Make sure that the PDL dialect is loaded. |
42 | mlirContext->loadDialect<pdl::PDLDialect>(); |
43 | } |
44 | |
45 | OwningOpRef<ModuleOp> generate(const ast::Module &module); |
46 | |
47 | private: |
48 | /// Generate an MLIR location from the given source location. |
49 | Location genLoc(llvm::SMLoc loc); |
50 | Location genLoc(llvm::SMRange loc) { return genLoc(loc: loc.Start); } |
51 | |
52 | /// Generate an MLIR type from the given source type. |
53 | Type genType(ast::Type type); |
54 | |
55 | /// Generate MLIR for the given AST node. |
56 | void gen(const ast::Node *node); |
57 | |
58 | //===--------------------------------------------------------------------===// |
59 | // Statements |
60 | //===--------------------------------------------------------------------===// |
61 | |
62 | void genImpl(const ast::CompoundStmt *stmt); |
63 | void genImpl(const ast::EraseStmt *stmt); |
64 | void genImpl(const ast::LetStmt *stmt); |
65 | void genImpl(const ast::ReplaceStmt *stmt); |
66 | void genImpl(const ast::RewriteStmt *stmt); |
67 | void genImpl(const ast::ReturnStmt *stmt); |
68 | |
69 | //===--------------------------------------------------------------------===// |
70 | // Decls |
71 | //===--------------------------------------------------------------------===// |
72 | |
73 | void genImpl(const ast::UserConstraintDecl *decl); |
74 | void genImpl(const ast::UserRewriteDecl *decl); |
75 | void genImpl(const ast::PatternDecl *decl); |
76 | |
77 | /// Generate the set of MLIR values defined for the given variable decl, and |
78 | /// apply any attached constraints. |
79 | SmallVector<Value> genVar(const ast::VariableDecl *varDecl); |
80 | |
81 | /// Generate the value for a variable that does not have an initializer |
82 | /// expression, i.e. create the PDL value based on the type/constraints of the |
83 | /// variable. |
84 | Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc); |
85 | |
86 | /// Apply the constraints of the given variable to `values`, which correspond |
87 | /// to the MLIR values of the variable. |
88 | void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values); |
89 | |
90 | //===--------------------------------------------------------------------===// |
91 | // Expressions |
92 | //===--------------------------------------------------------------------===// |
93 | |
94 | Value genSingleExpr(const ast::Expr *expr); |
95 | SmallVector<Value> genExpr(const ast::Expr *expr); |
96 | Value genExprImpl(const ast::AttributeExpr *expr); |
97 | SmallVector<Value> genExprImpl(const ast::CallExpr *expr); |
98 | SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr); |
99 | Value genExprImpl(const ast::MemberAccessExpr *expr); |
100 | Value genExprImpl(const ast::OperationExpr *expr); |
101 | Value genExprImpl(const ast::RangeExpr *expr); |
102 | SmallVector<Value> genExprImpl(const ast::TupleExpr *expr); |
103 | Value genExprImpl(const ast::TypeExpr *expr); |
104 | |
105 | SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl, |
106 | Location loc, ValueRange inputs, |
107 | bool isNegated = false); |
108 | SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl, |
109 | Location loc, ValueRange inputs); |
110 | template <typename PDLOpT, typename T> |
111 | SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc, |
112 | ValueRange inputs, |
113 | bool isNegated = false); |
114 | |
115 | //===--------------------------------------------------------------------===// |
116 | // Fields |
117 | //===--------------------------------------------------------------------===// |
118 | |
119 | /// The MLIR builder used for building the resultant IR. |
120 | OpBuilder builder; |
121 | |
122 | /// A map from variable declarations to the MLIR equivalent. |
123 | using VariableMapTy = |
124 | llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>; |
125 | VariableMapTy variables; |
126 | |
127 | /// A reference to the ODS context. |
128 | const ods::Context &odsContext; |
129 | |
130 | /// The source manager of the PDLL ast. |
131 | const llvm::SourceMgr &sourceMgr; |
132 | }; |
133 | } // namespace |
134 | |
135 | OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) { |
136 | OwningOpRef<ModuleOp> mlirModule = |
137 | builder.create<ModuleOp>(genLoc(loc: module.getLoc())); |
138 | builder.setInsertionPointToStart(mlirModule->getBody()); |
139 | |
140 | // Generate code for each of the decls within the module. |
141 | for (const ast::Decl *decl : module.getChildren()) |
142 | gen(node: decl); |
143 | |
144 | return mlirModule; |
145 | } |
146 | |
147 | Location CodeGen::genLoc(llvm::SMLoc loc) { |
148 | unsigned fileID = sourceMgr.FindBufferContainingLoc(Loc: loc); |
149 | |
150 | // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can |
151 | // use it here. |
152 | auto &bufferInfo = sourceMgr.getBufferInfo(i: fileID); |
153 | unsigned lineNo = bufferInfo.getLineNumber(Ptr: loc.getPointer()); |
154 | unsigned column = |
155 | (loc.getPointer() - bufferInfo.getPointerForLineNumber(LineNo: lineNo)) + 1; |
156 | auto *buffer = sourceMgr.getMemoryBuffer(i: fileID); |
157 | |
158 | return FileLineColLoc::get(builder.getContext(), |
159 | buffer->getBufferIdentifier(), lineNo, column); |
160 | } |
161 | |
162 | Type CodeGen::genType(ast::Type type) { |
163 | return TypeSwitch<ast::Type, Type>(type) |
164 | .Case(caseFn: [&](ast::AttributeType astType) -> Type { |
165 | return builder.getType<pdl::AttributeType>(); |
166 | }) |
167 | .Case(caseFn: [&](ast::OperationType astType) -> Type { |
168 | return builder.getType<pdl::OperationType>(); |
169 | }) |
170 | .Case(caseFn: [&](ast::TypeType astType) -> Type { |
171 | return builder.getType<pdl::TypeType>(); |
172 | }) |
173 | .Case(caseFn: [&](ast::ValueType astType) -> Type { |
174 | return builder.getType<pdl::ValueType>(); |
175 | }) |
176 | .Case(caseFn: [&](ast::RangeType astType) -> Type { |
177 | return pdl::RangeType::get(genType(astType.getElementType())); |
178 | }); |
179 | } |
180 | |
181 | void CodeGen::gen(const ast::Node *node) { |
182 | TypeSwitch<const ast::Node *>(node) |
183 | .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt, |
184 | const ast::ReplaceStmt, const ast::RewriteStmt, |
185 | const ast::ReturnStmt, const ast::UserConstraintDecl, |
186 | const ast::UserRewriteDecl, const ast::PatternDecl>( |
187 | caseFn: [&](auto derivedNode) { this->genImpl(derivedNode); }) |
188 | .Case(caseFn: [&](const ast::Expr *expr) { genExpr(expr); }); |
189 | } |
190 | |
191 | //===----------------------------------------------------------------------===// |
192 | // CodeGen: Statements |
193 | //===----------------------------------------------------------------------===// |
194 | |
195 | void CodeGen::genImpl(const ast::CompoundStmt *stmt) { |
196 | VariableMapTy::ScopeTy varScope(variables); |
197 | for (const ast::Stmt *childStmt : stmt->getChildren()) |
198 | gen(node: childStmt); |
199 | } |
200 | |
201 | /// If the given builder is nested under a PDL PatternOp, build a rewrite |
202 | /// operation and update the builder to nest under it. This is necessary for |
203 | /// PDLL operation rewrite statements that are directly nested within a Pattern. |
204 | static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, |
205 | Location loc) { |
206 | if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) { |
207 | pdl::RewriteOp rewrite = |
208 | builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(), |
209 | /*externalArgs=*/ValueRange()); |
210 | builder.createBlock(&rewrite.getBodyRegion()); |
211 | } |
212 | } |
213 | |
214 | void CodeGen::genImpl(const ast::EraseStmt *stmt) { |
215 | OpBuilder::InsertionGuard insertGuard(builder); |
216 | Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr()); |
217 | Location loc = genLoc(loc: stmt->getLoc()); |
218 | |
219 | // Make sure we are nested in a RewriteOp. |
220 | OpBuilder::InsertionGuard guard(builder); |
221 | checkAndNestUnderRewriteOp(builder, rootExpr, loc); |
222 | builder.create<pdl::EraseOp>(loc, rootExpr); |
223 | } |
224 | |
225 | void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(varDecl: stmt->getVarDecl()); } |
226 | |
227 | void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { |
228 | OpBuilder::InsertionGuard insertGuard(builder); |
229 | Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr()); |
230 | Location loc = genLoc(loc: stmt->getLoc()); |
231 | |
232 | // Make sure we are nested in a RewriteOp. |
233 | OpBuilder::InsertionGuard guard(builder); |
234 | checkAndNestUnderRewriteOp(builder, rootExpr, loc); |
235 | |
236 | SmallVector<Value> replValues; |
237 | for (ast::Expr *replExpr : stmt->getReplExprs()) |
238 | replValues.push_back(Elt: genSingleExpr(expr: replExpr)); |
239 | |
240 | // Check to see if the statement has a replacement operation, or a range of |
241 | // replacement values. |
242 | bool usesReplOperation = |
243 | replValues.size() == 1 && |
244 | isa<pdl::OperationType>(replValues.front().getType()); |
245 | builder.create<pdl::ReplaceOp>( |
246 | loc, rootExpr, usesReplOperation ? replValues[0] : Value(), |
247 | usesReplOperation ? ValueRange() : ValueRange(replValues)); |
248 | } |
249 | |
250 | void CodeGen::genImpl(const ast::RewriteStmt *stmt) { |
251 | OpBuilder::InsertionGuard insertGuard(builder); |
252 | Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr()); |
253 | |
254 | // Make sure we are nested in a RewriteOp. |
255 | OpBuilder::InsertionGuard guard(builder); |
256 | checkAndNestUnderRewriteOp(builder, rootExpr, loc: genLoc(loc: stmt->getLoc())); |
257 | gen(node: stmt->getRewriteBody()); |
258 | } |
259 | |
260 | void CodeGen::genImpl(const ast::ReturnStmt *stmt) { |
261 | // ReturnStmt generation is handled by the respective constraint or rewrite |
262 | // parent node. |
263 | } |
264 | |
265 | //===----------------------------------------------------------------------===// |
266 | // CodeGen: Decls |
267 | //===----------------------------------------------------------------------===// |
268 | |
269 | void CodeGen::genImpl(const ast::UserConstraintDecl *decl) { |
270 | // All PDLL constraints get inlined when called, and the main native |
271 | // constraint declarations doesn't require any MLIR to be generated, only uses |
272 | // of it do. |
273 | } |
274 | |
275 | void CodeGen::genImpl(const ast::UserRewriteDecl *decl) { |
276 | // All PDLL rewrites get inlined when called, and the main native |
277 | // rewrite declarations doesn't require any MLIR to be generated, only uses |
278 | // of it do. |
279 | } |
280 | |
281 | void CodeGen::genImpl(const ast::PatternDecl *decl) { |
282 | const ast::Name *name = decl->getName(); |
283 | |
284 | // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it |
285 | // here. |
286 | pdl::PatternOp pattern = builder.create<pdl::PatternOp>( |
287 | genLoc(decl->getLoc()), decl->getBenefit(), |
288 | name ? std::optional<StringRef>(name->getName()) |
289 | : std::optional<StringRef>()); |
290 | |
291 | OpBuilder::InsertionGuard savedInsertPoint(builder); |
292 | builder.setInsertionPointToStart(pattern.getBody()); |
293 | gen(node: decl->getBody()); |
294 | } |
295 | |
296 | SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) { |
297 | auto it = variables.begin(Key: varDecl); |
298 | if (it != variables.end()) |
299 | return *it; |
300 | |
301 | // If the variable has an initial value, use that as the base value. |
302 | // Otherwise, generate a value using the constraint list. |
303 | SmallVector<Value> values; |
304 | if (const ast::Expr *initExpr = varDecl->getInitExpr()) |
305 | values = genExpr(expr: initExpr); |
306 | else |
307 | values.push_back(Elt: genNonInitializerVar(varDecl, loc: genLoc(loc: varDecl->getLoc()))); |
308 | |
309 | // Apply the constraints of the values of the variable. |
310 | applyVarConstraints(varDecl, values); |
311 | |
312 | variables.insert(Key: varDecl, Val: values); |
313 | return values; |
314 | } |
315 | |
316 | Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, |
317 | Location loc) { |
318 | // A functor used to generate expressions nested |
319 | auto getTypeConstraint = [&]() -> Value { |
320 | for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) { |
321 | Value typeValue = |
322 | TypeSwitch<const ast::Node *, Value>(constraint.constraint) |
323 | .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl, |
324 | ast::ValueRangeConstraintDecl>( |
325 | caseFn: [&, this](auto *cst) -> Value { |
326 | if (auto *typeConstraintExpr = cst->getTypeExpr()) |
327 | return this->genSingleExpr(expr: typeConstraintExpr); |
328 | return Value(); |
329 | }) |
330 | .Default(defaultResult: Value()); |
331 | if (typeValue) |
332 | return typeValue; |
333 | } |
334 | return Value(); |
335 | }; |
336 | |
337 | // Generate a value based on the type of the variable. |
338 | ast::Type type = varDecl->getType(); |
339 | Type mlirType = genType(type); |
340 | if (type.isa<ast::ValueType>()) |
341 | return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint()); |
342 | if (type.isa<ast::TypeType>()) |
343 | return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr()); |
344 | if (type.isa<ast::AttributeType>()) |
345 | return builder.create<pdl::AttributeOp>(loc, getTypeConstraint()); |
346 | if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) { |
347 | Value operands = builder.create<pdl::OperandsOp>( |
348 | loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()), |
349 | /*type=*/Value()); |
350 | Value results = builder.create<pdl::TypesOp>( |
351 | loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()), |
352 | /*types=*/ArrayAttr()); |
353 | return builder.create<pdl::OperationOp>( |
354 | loc, opType.getName(), operands, std::nullopt, ValueRange(), results); |
355 | } |
356 | |
357 | if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) { |
358 | ast::Type eleTy = rangeTy.getElementType(); |
359 | if (eleTy.isa<ast::ValueType>()) |
360 | return builder.create<pdl::OperandsOp>(loc, mlirType, |
361 | getTypeConstraint()); |
362 | if (eleTy.isa<ast::TypeType>()) |
363 | return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr()); |
364 | } |
365 | |
366 | llvm_unreachable("invalid non-initialized variable type" ); |
367 | } |
368 | |
369 | void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl, |
370 | ValueRange values) { |
371 | // Generate calls to any user constraints that were attached via the |
372 | // constraint list. |
373 | for (const ast::ConstraintRef &ref : varDecl->getConstraints()) |
374 | if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint)) |
375 | genConstraintCall(decl: userCst, loc: genLoc(loc: ref.referenceLoc), inputs: values); |
376 | } |
377 | |
378 | //===----------------------------------------------------------------------===// |
379 | // CodeGen: Expressions |
380 | //===----------------------------------------------------------------------===// |
381 | |
382 | Value CodeGen::genSingleExpr(const ast::Expr *expr) { |
383 | return TypeSwitch<const ast::Expr *, Value>(expr) |
384 | .Case<const ast::AttributeExpr, const ast::MemberAccessExpr, |
385 | const ast::OperationExpr, const ast::RangeExpr, |
386 | const ast::TypeExpr>( |
387 | caseFn: [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) |
388 | .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( |
389 | caseFn: [&](auto derivedNode) { |
390 | SmallVector<Value> results = this->genExprImpl(derivedNode); |
391 | assert(results.size() == 1 && "expected single expression result" ); |
392 | return results[0]; |
393 | }); |
394 | } |
395 | |
396 | SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) { |
397 | return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr) |
398 | .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>( |
399 | caseFn: [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) |
400 | .Default(defaultFn: [&](const ast::Expr *expr) -> SmallVector<Value> { |
401 | return {genSingleExpr(expr)}; |
402 | }); |
403 | } |
404 | |
405 | Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) { |
406 | Attribute attr = parseAttribute(attrStr: expr->getValue(), context: builder.getContext()); |
407 | assert(attr && "invalid MLIR attribute data" ); |
408 | return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr); |
409 | } |
410 | |
411 | SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) { |
412 | Location loc = genLoc(loc: expr->getLoc()); |
413 | SmallVector<Value> arguments; |
414 | for (const ast::Expr *arg : expr->getArguments()) |
415 | arguments.push_back(Elt: genSingleExpr(expr: arg)); |
416 | |
417 | // Resolve the callable expression of this call. |
418 | auto *callableExpr = dyn_cast<ast::DeclRefExpr>(Val: expr->getCallableExpr()); |
419 | assert(callableExpr && "unhandled CallExpr callable" ); |
420 | |
421 | // Generate the PDL based on the type of callable. |
422 | const ast::Decl *callable = callableExpr->getDecl(); |
423 | if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(Val: callable)) |
424 | return genConstraintCall(decl, loc, inputs: arguments, isNegated: expr->getIsNegated()); |
425 | if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(Val: callable)) |
426 | return genRewriteCall(decl, loc, inputs: arguments); |
427 | llvm_unreachable("unhandled CallExpr callable" ); |
428 | } |
429 | |
430 | SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) { |
431 | if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: expr->getDecl())) |
432 | return genVar(varDecl); |
433 | llvm_unreachable("unknown decl reference expression" ); |
434 | } |
435 | |
436 | Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { |
437 | Location loc = genLoc(loc: expr->getLoc()); |
438 | StringRef name = expr->getMemberName(); |
439 | SmallVector<Value> parentExprs = genExpr(expr: expr->getParentExpr()); |
440 | ast::Type parentType = expr->getParentExpr()->getType(); |
441 | |
442 | // Handle operation based member access. |
443 | if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) { |
444 | if (isa<ast::AllResultsMemberAccessExpr>(Val: expr)) { |
445 | Type mlirType = genType(type: expr->getType()); |
446 | if (isa<pdl::ValueType>(mlirType)) |
447 | return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0], |
448 | builder.getI32IntegerAttr(0)); |
449 | return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]); |
450 | } |
451 | |
452 | const ods::Operation *odsOp = opType.getODSOperation(); |
453 | if (!odsOp) { |
454 | assert(llvm::isDigit(name[0]) && |
455 | "unregistered op only allows numeric indexing" ); |
456 | unsigned resultIndex; |
457 | name.getAsInteger(/*Radix=*/10, Result&: resultIndex); |
458 | IntegerAttr index = builder.getI32IntegerAttr(resultIndex); |
459 | return builder.create<pdl::ResultOp>(loc, genType(expr->getType()), |
460 | parentExprs[0], index); |
461 | } |
462 | |
463 | // Find the result with the member name or by index. |
464 | ArrayRef<ods::OperandOrResult> results = odsOp->getResults(); |
465 | unsigned resultIndex = results.size(); |
466 | if (llvm::isDigit(C: name[0])) { |
467 | name.getAsInteger(/*Radix=*/10, Result&: resultIndex); |
468 | } else { |
469 | auto findFn = [&](const ods::OperandOrResult &result) { |
470 | return result.getName() == name; |
471 | }; |
472 | resultIndex = llvm::find_if(Range&: results, P: findFn) - results.begin(); |
473 | } |
474 | assert(resultIndex < results.size() && "invalid result index" ); |
475 | |
476 | // Generate the result access. |
477 | IntegerAttr index = builder.getI32IntegerAttr(resultIndex); |
478 | return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()), |
479 | parentExprs[0], index); |
480 | } |
481 | |
482 | // Handle tuple based member access. |
483 | if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) { |
484 | auto elementNames = tupleType.getElementNames(); |
485 | |
486 | // The index is either a numeric index, or a name. |
487 | unsigned index = 0; |
488 | if (llvm::isDigit(C: name[0])) |
489 | name.getAsInteger(/*Radix=*/10, Result&: index); |
490 | else |
491 | index = llvm::find(Range&: elementNames, Val: name) - elementNames.begin(); |
492 | |
493 | assert(index < parentExprs.size() && "invalid result index" ); |
494 | return parentExprs[index]; |
495 | } |
496 | |
497 | llvm_unreachable("unhandled member access expression" ); |
498 | } |
499 | |
500 | Value CodeGen::genExprImpl(const ast::OperationExpr *expr) { |
501 | Location loc = genLoc(loc: expr->getLoc()); |
502 | std::optional<StringRef> opName = expr->getName(); |
503 | |
504 | // Operands. |
505 | SmallVector<Value> operands; |
506 | for (const ast::Expr *operand : expr->getOperands()) |
507 | operands.push_back(Elt: genSingleExpr(expr: operand)); |
508 | |
509 | // Attributes. |
510 | SmallVector<StringRef> attrNames; |
511 | SmallVector<Value> attrValues; |
512 | for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) { |
513 | attrNames.push_back(Elt: attr->getName().getName()); |
514 | attrValues.push_back(Elt: genSingleExpr(expr: attr->getValue())); |
515 | } |
516 | |
517 | // Results. |
518 | SmallVector<Value> results; |
519 | for (const ast::Expr *result : expr->getResultTypes()) |
520 | results.push_back(Elt: genSingleExpr(expr: result)); |
521 | |
522 | return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames, |
523 | attrValues, results); |
524 | } |
525 | |
526 | Value CodeGen::genExprImpl(const ast::RangeExpr *expr) { |
527 | SmallVector<Value> elements; |
528 | for (const ast::Expr *element : expr->getElements()) |
529 | llvm::append_range(C&: elements, R: genExpr(expr: element)); |
530 | |
531 | return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()), |
532 | genType(expr->getType()), elements); |
533 | } |
534 | |
535 | SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) { |
536 | SmallVector<Value> elements; |
537 | for (const ast::Expr *element : expr->getElements()) |
538 | elements.push_back(Elt: genSingleExpr(expr: element)); |
539 | return elements; |
540 | } |
541 | |
542 | Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { |
543 | Type type = parseType(typeStr: expr->getValue(), context: builder.getContext()); |
544 | assert(type && "invalid MLIR type data" ); |
545 | return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()), |
546 | builder.getType<pdl::TypeType>(), |
547 | TypeAttr::get(type)); |
548 | } |
549 | |
550 | SmallVector<Value> |
551 | CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, |
552 | ValueRange inputs, bool isNegated) { |
553 | // Apply any constraints defined on the arguments to the input values. |
554 | for (auto it : llvm::zip(t: decl->getInputs(), u&: inputs)) |
555 | applyVarConstraints(varDecl: std::get<0>(t&: it), values: std::get<1>(t&: it)); |
556 | |
557 | // Generate the constraint call. |
558 | SmallVector<Value> results = |
559 | genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>( |
560 | decl, loc, inputs, isNegated); |
561 | |
562 | // Apply any constraints defined on the results of the constraint. |
563 | for (auto it : llvm::zip(t: decl->getResults(), u&: results)) |
564 | applyVarConstraints(varDecl: std::get<0>(t&: it), values: std::get<1>(t&: it)); |
565 | return results; |
566 | } |
567 | |
568 | SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, |
569 | Location loc, ValueRange inputs) { |
570 | return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc, |
571 | inputs); |
572 | } |
573 | |
574 | template <typename PDLOpT, typename T> |
575 | SmallVector<Value> |
576 | CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc, |
577 | ValueRange inputs, bool isNegated) { |
578 | const ast::CompoundStmt *cstBody = decl->getBody(); |
579 | |
580 | // If the decl doesn't have a statement body, it is a native decl. |
581 | if (!cstBody) { |
582 | ast::Type declResultType = decl->getResultType(); |
583 | SmallVector<Type> resultTypes; |
584 | if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) { |
585 | for (ast::Type type : tupleType.getElementTypes()) |
586 | resultTypes.push_back(Elt: genType(type)); |
587 | } else { |
588 | resultTypes.push_back(Elt: genType(type: declResultType)); |
589 | } |
590 | PDLOpT pdlOp = builder.create<PDLOpT>( |
591 | loc, resultTypes, decl->getName().getName(), inputs); |
592 | if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>) |
593 | cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true); |
594 | return pdlOp->getResults(); |
595 | } |
596 | |
597 | // Otherwise, this is a PDLL decl. |
598 | VariableMapTy::ScopeTy varScope(variables); |
599 | |
600 | // Map the inputs of the call to the decl arguments. |
601 | // Note: This is only valid because we do not support recursion, meaning |
602 | // we don't need to worry about conflicting mappings here. |
603 | for (auto it : llvm::zip(inputs, decl->getInputs())) |
604 | variables.insert(Key: std::get<1>(it), Val: {std::get<0>(it)}); |
605 | |
606 | // Visit the body of the call as normal. |
607 | gen(node: cstBody); |
608 | |
609 | // If the decl has no results, there is nothing to do. |
610 | if (cstBody->getChildren().empty()) |
611 | return SmallVector<Value>(); |
612 | auto *returnStmt = dyn_cast<ast::ReturnStmt>(Val: cstBody->getChildren().back()); |
613 | if (!returnStmt) |
614 | return SmallVector<Value>(); |
615 | |
616 | // Otherwise, grab the results from the return statement. |
617 | return genExpr(expr: returnStmt->getResultExpr()); |
618 | } |
619 | |
620 | //===----------------------------------------------------------------------===// |
621 | // MLIRGen |
622 | //===----------------------------------------------------------------------===// |
623 | |
624 | OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR( |
625 | MLIRContext *mlirContext, const ast::Context &context, |
626 | const llvm::SourceMgr &sourceMgr, const ast::Module &module) { |
627 | CodeGen codegen(mlirContext, context, sourceMgr); |
628 | OwningOpRef<ModuleOp> mlirModule = codegen.generate(module); |
629 | if (failed(verify(*mlirModule))) |
630 | return nullptr; |
631 | return mlirModule; |
632 | } |
633 | |