1//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10#include "mlir/Dialect/EmitC/IR/EmitC.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/Dialect.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/IR/SymbolTable.h"
17#include "mlir/Support/IndentedOstream.h"
18#include "mlir/Support/LLVM.h"
19#include "mlir/Target/Cpp/CppEmitter.h"
20#include "llvm/ADT/DenseMap.h"
21#include "llvm/ADT/ScopedHashTable.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringMap.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/FormatVariadic.h"
27#include <stack>
28#include <utility>
29
30#define DEBUG_TYPE "translate-to-cpp"
31
32using namespace mlir;
33using namespace mlir::emitc;
34using llvm::formatv;
35
36/// Convenience functions to produce interleaved output with functions returning
37/// a LogicalResult. This is different than those in STLExtras as functions used
38/// on each element doesn't return a string.
39template <typename ForwardIterator, typename UnaryFunctor,
40 typename NullaryFunctor>
41inline LogicalResult
42interleaveWithError(ForwardIterator begin, ForwardIterator end,
43 UnaryFunctor eachFn, NullaryFunctor betweenFn) {
44 if (begin == end)
45 return success();
46 if (failed(eachFn(*begin)))
47 return failure();
48 ++begin;
49 for (; begin != end; ++begin) {
50 betweenFn();
51 if (failed(eachFn(*begin)))
52 return failure();
53 }
54 return success();
55}
56
57template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
58inline LogicalResult interleaveWithError(const Container &c,
59 UnaryFunctor eachFn,
60 NullaryFunctor betweenFn) {
61 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
62}
63
64template <typename Container, typename UnaryFunctor>
65inline LogicalResult interleaveCommaWithError(const Container &c,
66 raw_ostream &os,
67 UnaryFunctor eachFn) {
68 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
69}
70
71/// Return the precedence of a operator as an integer, higher values
72/// imply higher precedence.
73static FailureOr<int> getOperatorPrecedence(Operation *operation) {
74 return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation)
75 .Case<emitc::AddOp>([&](auto op) { return 12; })
76 .Case<emitc::ApplyOp>([&](auto op) { return 15; })
77 .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
78 .Case<emitc::BitwiseLeftShiftOp>([&](auto op) { return 11; })
79 .Case<emitc::BitwiseNotOp>([&](auto op) { return 15; })
80 .Case<emitc::BitwiseOrOp>([&](auto op) { return 5; })
81 .Case<emitc::BitwiseRightShiftOp>([&](auto op) { return 11; })
82 .Case<emitc::BitwiseXorOp>([&](auto op) { return 6; })
83 .Case<emitc::CallOp>([&](auto op) { return 16; })
84 .Case<emitc::CallOpaqueOp>([&](auto op) { return 16; })
85 .Case<emitc::CastOp>([&](auto op) { return 15; })
86 .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
87 switch (op.getPredicate()) {
88 case emitc::CmpPredicate::eq:
89 case emitc::CmpPredicate::ne:
90 return 8;
91 case emitc::CmpPredicate::lt:
92 case emitc::CmpPredicate::le:
93 case emitc::CmpPredicate::gt:
94 case emitc::CmpPredicate::ge:
95 return 9;
96 case emitc::CmpPredicate::three_way:
97 return 10;
98 }
99 return op->emitError("unsupported cmp predicate");
100 })
101 .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
102 .Case<emitc::DivOp>([&](auto op) { return 13; })
103 .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
104 .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
105 .Case<emitc::LogicalOrOp>([&](auto op) { return 3; })
106 .Case<emitc::MulOp>([&](auto op) { return 13; })
107 .Case<emitc::RemOp>([&](auto op) { return 13; })
108 .Case<emitc::SubOp>([&](auto op) { return 12; })
109 .Case<emitc::UnaryMinusOp>([&](auto op) { return 15; })
110 .Case<emitc::UnaryPlusOp>([&](auto op) { return 15; })
111 .Default([](auto op) { return op->emitError("unsupported operation"); });
112}
113
114namespace {
115/// Emitter that uses dialect specific emitters to emit C++ code.
116struct CppEmitter {
117 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
118
119 /// Emits attribute or returns failure.
120 LogicalResult emitAttribute(Location loc, Attribute attr);
121
122 /// Emits operation 'op' with/without training semicolon or returns failure.
123 LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
124
125 /// Emits type 'type' or returns failure.
126 LogicalResult emitType(Location loc, Type type);
127
128 /// Emits array of types as a std::tuple of the emitted types.
129 /// - emits void for an empty array;
130 /// - emits the type of the only element for arrays of size one;
131 /// - emits a std::tuple otherwise;
132 LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
133
134 /// Emits array of types as a std::tuple of the emitted types independently of
135 /// the array size.
136 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
137
138 /// Emits an assignment for a variable which has been declared previously.
139 LogicalResult emitVariableAssignment(OpResult result);
140
141 /// Emits a variable declaration for a result of an operation.
142 LogicalResult emitVariableDeclaration(OpResult result,
143 bool trailingSemicolon);
144
145 /// Emits a declaration of a variable with the given type and name.
146 LogicalResult emitVariableDeclaration(Location loc, Type type,
147 StringRef name);
148
149 /// Emits the variable declaration and assignment prefix for 'op'.
150 /// - emits separate variable followed by std::tie for multi-valued operation;
151 /// - emits single type followed by variable for single result;
152 /// - emits nothing if no value produced by op;
153 /// Emits final '=' operator where a type is produced. Returns failure if
154 /// any result type could not be converted.
155 LogicalResult emitAssignPrefix(Operation &op);
156
157 /// Emits a global variable declaration or definition.
158 LogicalResult emitGlobalVariable(GlobalOp op);
159
160 /// Emits a label for the block.
161 LogicalResult emitLabel(Block &block);
162
163 /// Emits the operands and atttributes of the operation. All operands are
164 /// emitted first and then all attributes in alphabetical order.
165 LogicalResult emitOperandsAndAttributes(Operation &op,
166 ArrayRef<StringRef> exclude = {});
167
168 /// Emits the operands of the operation. All operands are emitted in order.
169 LogicalResult emitOperands(Operation &op);
170
171 /// Emits value as an operands of an operation
172 LogicalResult emitOperand(Value value);
173
174 /// Emit an expression as a C expression.
175 LogicalResult emitExpression(ExpressionOp expressionOp);
176
177 /// Return the existing or a new name for a Value.
178 StringRef getOrCreateName(Value val);
179
180 // Returns the textual representation of a subscript operation.
181 std::string getSubscriptName(emitc::SubscriptOp op);
182
183 /// Return the existing or a new label of a Block.
184 StringRef getOrCreateName(Block &block);
185
186 /// Whether to map an mlir integer to a unsigned integer in C++.
187 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
188
189 /// RAII helper function to manage entering/exiting C++ scopes.
190 struct Scope {
191 Scope(CppEmitter &emitter)
192 : valueMapperScope(emitter.valueMapper),
193 blockMapperScope(emitter.blockMapper), emitter(emitter) {
194 emitter.valueInScopeCount.push(x: emitter.valueInScopeCount.top());
195 emitter.labelInScopeCount.push(x: emitter.labelInScopeCount.top());
196 }
197 ~Scope() {
198 emitter.valueInScopeCount.pop();
199 emitter.labelInScopeCount.pop();
200 }
201
202 private:
203 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
204 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
205 CppEmitter &emitter;
206 };
207
208 /// Returns wether the Value is assigned to a C++ variable in the scope.
209 bool hasValueInScope(Value val);
210
211 // Returns whether a label is assigned to the block.
212 bool hasBlockLabel(Block &block);
213
214 /// Returns the output stream.
215 raw_indented_ostream &ostream() { return os; };
216
217 /// Returns if all variables for op results and basic block arguments need to
218 /// be declared at the beginning of a function.
219 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
220
221 /// Get expression currently being emitted.
222 ExpressionOp getEmittedExpression() { return emittedExpression; }
223
224 /// Determine whether given value is part of the expression potentially being
225 /// emitted.
226 bool isPartOfCurrentExpression(Value value) {
227 if (!emittedExpression)
228 return false;
229 Operation *def = value.getDefiningOp();
230 if (!def)
231 return false;
232 auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
233 return operandExpression == emittedExpression;
234 };
235
236private:
237 using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
238 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
239
240 /// Output stream to emit to.
241 raw_indented_ostream os;
242
243 /// Boolean to enforce that all variables for op results and block
244 /// arguments are declared at the beginning of the function. This also
245 /// includes results from ops located in nested regions.
246 bool declareVariablesAtTop;
247
248 /// Map from value to name of C++ variable that contain the name.
249 ValueMapper valueMapper;
250
251 /// Map from block to name of C++ label.
252 BlockMapper blockMapper;
253
254 /// The number of values in the current scope. This is used to declare the
255 /// names of values in a scope.
256 std::stack<int64_t> valueInScopeCount;
257 std::stack<int64_t> labelInScopeCount;
258
259 /// State of the current expression being emitted.
260 ExpressionOp emittedExpression;
261 SmallVector<int> emittedExpressionPrecedence;
262
263 void pushExpressionPrecedence(int precedence) {
264 emittedExpressionPrecedence.push_back(Elt: precedence);
265 }
266 void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
267 static int lowestPrecedence() { return 0; }
268 int getExpressionPrecedence() {
269 if (emittedExpressionPrecedence.empty())
270 return lowestPrecedence();
271 return emittedExpressionPrecedence.back();
272 }
273};
274} // namespace
275
276/// Determine whether expression \p expressionOp should be emitted inline, i.e.
277/// as part of its user. This function recommends inlining of any expressions
278/// that can be inlined unless it is used by another expression, under the
279/// assumption that any expression fusion/re-materialization was taken care of
280/// by transformations run by the backend.
281static bool shouldBeInlined(ExpressionOp expressionOp) {
282 // Do not inline if expression is marked as such.
283 if (expressionOp.getDoNotInline())
284 return false;
285
286 // Do not inline expressions with side effects to prevent side-effect
287 // reordering.
288 if (expressionOp.hasSideEffects())
289 return false;
290
291 // Do not inline expressions with multiple uses.
292 Value result = expressionOp.getResult();
293 if (!result.hasOneUse())
294 return false;
295
296 // Do not inline expressions used by other expressions, as any desired
297 // expression folding was taken care of by transformations.
298 Operation *user = *result.getUsers().begin();
299 return !user->getParentOfType<ExpressionOp>();
300}
301
302static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
303 Attribute value) {
304 OpResult result = operation->getResult(idx: 0);
305
306 // Only emit an assignment as the variable was already declared when printing
307 // the FuncOp.
308 if (emitter.shouldDeclareVariablesAtTop()) {
309 // Skip the assignment if the emitc.constant has no value.
310 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
311 if (oAttr.getValue().empty())
312 return success();
313 }
314
315 if (failed(result: emitter.emitVariableAssignment(result)))
316 return failure();
317 return emitter.emitAttribute(loc: operation->getLoc(), attr: value);
318 }
319
320 // Emit a variable declaration for an emitc.constant op without value.
321 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
322 if (oAttr.getValue().empty())
323 // The semicolon gets printed by the emitOperation function.
324 return emitter.emitVariableDeclaration(result,
325 /*trailingSemicolon=*/false);
326 }
327
328 // Emit a variable declaration.
329 if (failed(result: emitter.emitAssignPrefix(op&: *operation)))
330 return failure();
331 return emitter.emitAttribute(loc: operation->getLoc(), attr: value);
332}
333
334static LogicalResult printOperation(CppEmitter &emitter,
335 emitc::ConstantOp constantOp) {
336 Operation *operation = constantOp.getOperation();
337 Attribute value = constantOp.getValue();
338
339 return printConstantOp(emitter, operation, value);
340}
341
342static LogicalResult printOperation(CppEmitter &emitter,
343 emitc::VariableOp variableOp) {
344 Operation *operation = variableOp.getOperation();
345 Attribute value = variableOp.getValue();
346
347 return printConstantOp(emitter, operation, value);
348}
349
350static LogicalResult printOperation(CppEmitter &emitter,
351 emitc::GlobalOp globalOp) {
352
353 return emitter.emitGlobalVariable(globalOp);
354}
355
356static LogicalResult printOperation(CppEmitter &emitter,
357 emitc::AssignOp assignOp) {
358 OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
359
360 if (failed(result: emitter.emitVariableAssignment(result)))
361 return failure();
362
363 return emitter.emitOperand(value: assignOp.getValue());
364}
365
366static LogicalResult printOperation(CppEmitter &emitter,
367 emitc::GetGlobalOp op) {
368 // Add name to cache so that `hasValueInScope` works.
369 emitter.getOrCreateName(op.getResult());
370 return success();
371}
372
373static LogicalResult printOperation(CppEmitter &emitter,
374 emitc::SubscriptOp subscriptOp) {
375 // Add name to cache so that `hasValueInScope` works.
376 emitter.getOrCreateName(subscriptOp.getResult());
377 return success();
378}
379
380static LogicalResult printBinaryOperation(CppEmitter &emitter,
381 Operation *operation,
382 StringRef binaryOperator) {
383 raw_ostream &os = emitter.ostream();
384
385 if (failed(result: emitter.emitAssignPrefix(op&: *operation)))
386 return failure();
387
388 if (failed(result: emitter.emitOperand(value: operation->getOperand(idx: 0))))
389 return failure();
390
391 os << " " << binaryOperator << " ";
392
393 if (failed(result: emitter.emitOperand(value: operation->getOperand(idx: 1))))
394 return failure();
395
396 return success();
397}
398
399static LogicalResult printUnaryOperation(CppEmitter &emitter,
400 Operation *operation,
401 StringRef unaryOperator) {
402 raw_ostream &os = emitter.ostream();
403
404 if (failed(result: emitter.emitAssignPrefix(op&: *operation)))
405 return failure();
406
407 os << unaryOperator;
408
409 if (failed(result: emitter.emitOperand(value: operation->getOperand(idx: 0))))
410 return failure();
411
412 return success();
413}
414
415static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
416 Operation *operation = addOp.getOperation();
417
418 return printBinaryOperation(emitter, operation, binaryOperator: "+");
419}
420
421static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
422 Operation *operation = divOp.getOperation();
423
424 return printBinaryOperation(emitter, operation, binaryOperator: "/");
425}
426
427static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
428 Operation *operation = mulOp.getOperation();
429
430 return printBinaryOperation(emitter, operation, binaryOperator: "*");
431}
432
433static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
434 Operation *operation = remOp.getOperation();
435
436 return printBinaryOperation(emitter, operation, binaryOperator: "%");
437}
438
439static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
440 Operation *operation = subOp.getOperation();
441
442 return printBinaryOperation(emitter, operation, binaryOperator: "-");
443}
444
445static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
446 Operation *operation = cmpOp.getOperation();
447
448 StringRef binaryOperator;
449
450 switch (cmpOp.getPredicate()) {
451 case emitc::CmpPredicate::eq:
452 binaryOperator = "==";
453 break;
454 case emitc::CmpPredicate::ne:
455 binaryOperator = "!=";
456 break;
457 case emitc::CmpPredicate::lt:
458 binaryOperator = "<";
459 break;
460 case emitc::CmpPredicate::le:
461 binaryOperator = "<=";
462 break;
463 case emitc::CmpPredicate::gt:
464 binaryOperator = ">";
465 break;
466 case emitc::CmpPredicate::ge:
467 binaryOperator = ">=";
468 break;
469 case emitc::CmpPredicate::three_way:
470 binaryOperator = "<=>";
471 break;
472 }
473
474 return printBinaryOperation(emitter, operation, binaryOperator);
475}
476
477static LogicalResult printOperation(CppEmitter &emitter,
478 emitc::ConditionalOp conditionalOp) {
479 raw_ostream &os = emitter.ostream();
480
481 if (failed(emitter.emitAssignPrefix(op&: *conditionalOp)))
482 return failure();
483
484 if (failed(emitter.emitOperand(value: conditionalOp.getCondition())))
485 return failure();
486
487 os << " ? ";
488
489 if (failed(emitter.emitOperand(value: conditionalOp.getTrueValue())))
490 return failure();
491
492 os << " : ";
493
494 if (failed(emitter.emitOperand(value: conditionalOp.getFalseValue())))
495 return failure();
496
497 return success();
498}
499
500static LogicalResult printOperation(CppEmitter &emitter,
501 emitc::VerbatimOp verbatimOp) {
502 raw_ostream &os = emitter.ostream();
503
504 os << verbatimOp.getValue();
505
506 return success();
507}
508
509static LogicalResult printOperation(CppEmitter &emitter,
510 cf::BranchOp branchOp) {
511 raw_ostream &os = emitter.ostream();
512 Block &successor = *branchOp.getSuccessor();
513
514 for (auto pair :
515 llvm::zip(branchOp.getOperands(), successor.getArguments())) {
516 Value &operand = std::get<0>(pair);
517 BlockArgument &argument = std::get<1>(pair);
518 os << emitter.getOrCreateName(argument) << " = "
519 << emitter.getOrCreateName(operand) << ";\n";
520 }
521
522 os << "goto ";
523 if (!(emitter.hasBlockLabel(block&: successor)))
524 return branchOp.emitOpError("unable to find label for successor block");
525 os << emitter.getOrCreateName(block&: successor);
526 return success();
527}
528
529static LogicalResult printOperation(CppEmitter &emitter,
530 cf::CondBranchOp condBranchOp) {
531 raw_indented_ostream &os = emitter.ostream();
532 Block &trueSuccessor = *condBranchOp.getTrueDest();
533 Block &falseSuccessor = *condBranchOp.getFalseDest();
534
535 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
536 << ") {\n";
537
538 os.indent();
539
540 // If condition is true.
541 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
542 trueSuccessor.getArguments())) {
543 Value &operand = std::get<0>(pair);
544 BlockArgument &argument = std::get<1>(pair);
545 os << emitter.getOrCreateName(argument) << " = "
546 << emitter.getOrCreateName(operand) << ";\n";
547 }
548
549 os << "goto ";
550 if (!(emitter.hasBlockLabel(block&: trueSuccessor))) {
551 return condBranchOp.emitOpError("unable to find label for successor block");
552 }
553 os << emitter.getOrCreateName(block&: trueSuccessor) << ";\n";
554 os.unindent() << "} else {\n";
555 os.indent();
556 // If condition is false.
557 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
558 falseSuccessor.getArguments())) {
559 Value &operand = std::get<0>(pair);
560 BlockArgument &argument = std::get<1>(pair);
561 os << emitter.getOrCreateName(argument) << " = "
562 << emitter.getOrCreateName(operand) << ";\n";
563 }
564
565 os << "goto ";
566 if (!(emitter.hasBlockLabel(block&: falseSuccessor))) {
567 return condBranchOp.emitOpError()
568 << "unable to find label for successor block";
569 }
570 os << emitter.getOrCreateName(block&: falseSuccessor) << ";\n";
571 os.unindent() << "}";
572 return success();
573}
574
575static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
576 StringRef callee) {
577 if (failed(result: emitter.emitAssignPrefix(op&: *callOp)))
578 return failure();
579
580 raw_ostream &os = emitter.ostream();
581 os << callee << "(";
582 if (failed(result: emitter.emitOperands(op&: *callOp)))
583 return failure();
584 os << ")";
585 return success();
586}
587
588static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
589 Operation *operation = callOp.getOperation();
590 StringRef callee = callOp.getCallee();
591
592 return printCallOperation(emitter, callOp: operation, callee);
593}
594
595static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
596 Operation *operation = callOp.getOperation();
597 StringRef callee = callOp.getCallee();
598
599 return printCallOperation(emitter, callOp: operation, callee);
600}
601
602static LogicalResult printOperation(CppEmitter &emitter,
603 emitc::CallOpaqueOp callOpaqueOp) {
604 raw_ostream &os = emitter.ostream();
605 Operation &op = *callOpaqueOp.getOperation();
606
607 if (failed(result: emitter.emitAssignPrefix(op)))
608 return failure();
609 os << callOpaqueOp.getCallee();
610
611 auto emitArgs = [&](Attribute attr) -> LogicalResult {
612 if (auto t = dyn_cast<IntegerAttr>(attr)) {
613 // Index attributes are treated specially as operand index.
614 if (t.getType().isIndex()) {
615 int64_t idx = t.getInt();
616 Value operand = op.getOperand(idx);
617 auto literalDef =
618 dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
619 if (!literalDef && !emitter.hasValueInScope(val: operand))
620 return op.emitOpError(message: "operand ")
621 << idx << "'s value not defined in scope";
622 os << emitter.getOrCreateName(val: operand);
623 return success();
624 }
625 }
626 if (failed(result: emitter.emitAttribute(loc: op.getLoc(), attr)))
627 return failure();
628
629 return success();
630 };
631
632 if (callOpaqueOp.getTemplateArgs()) {
633 os << "<";
634 if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
635 emitArgs)))
636 return failure();
637 os << ">";
638 }
639
640 os << "(";
641
642 LogicalResult emittedArgs =
643 callOpaqueOp.getArgs()
644 ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
645 : emitter.emitOperands(op);
646 if (failed(result: emittedArgs))
647 return failure();
648 os << ")";
649 return success();
650}
651
652static LogicalResult printOperation(CppEmitter &emitter,
653 emitc::ApplyOp applyOp) {
654 raw_ostream &os = emitter.ostream();
655 Operation &op = *applyOp.getOperation();
656
657 if (failed(result: emitter.emitAssignPrefix(op)))
658 return failure();
659 os << applyOp.getApplicableOperator();
660 os << emitter.getOrCreateName(applyOp.getOperand());
661
662 return success();
663}
664
665static LogicalResult printOperation(CppEmitter &emitter,
666 emitc::BitwiseAndOp bitwiseAndOp) {
667 Operation *operation = bitwiseAndOp.getOperation();
668 return printBinaryOperation(emitter, operation, binaryOperator: "&");
669}
670
671static LogicalResult
672printOperation(CppEmitter &emitter,
673 emitc::BitwiseLeftShiftOp bitwiseLeftShiftOp) {
674 Operation *operation = bitwiseLeftShiftOp.getOperation();
675 return printBinaryOperation(emitter, operation, binaryOperator: "<<");
676}
677
678static LogicalResult printOperation(CppEmitter &emitter,
679 emitc::BitwiseNotOp bitwiseNotOp) {
680 Operation *operation = bitwiseNotOp.getOperation();
681 return printUnaryOperation(emitter, operation, unaryOperator: "~");
682}
683
684static LogicalResult printOperation(CppEmitter &emitter,
685 emitc::BitwiseOrOp bitwiseOrOp) {
686 Operation *operation = bitwiseOrOp.getOperation();
687 return printBinaryOperation(emitter, operation, binaryOperator: "|");
688}
689
690static LogicalResult
691printOperation(CppEmitter &emitter,
692 emitc::BitwiseRightShiftOp bitwiseRightShiftOp) {
693 Operation *operation = bitwiseRightShiftOp.getOperation();
694 return printBinaryOperation(emitter, operation, binaryOperator: ">>");
695}
696
697static LogicalResult printOperation(CppEmitter &emitter,
698 emitc::BitwiseXorOp bitwiseXorOp) {
699 Operation *operation = bitwiseXorOp.getOperation();
700 return printBinaryOperation(emitter, operation, binaryOperator: "^");
701}
702
703static LogicalResult printOperation(CppEmitter &emitter,
704 emitc::UnaryPlusOp unaryPlusOp) {
705 Operation *operation = unaryPlusOp.getOperation();
706 return printUnaryOperation(emitter, operation, unaryOperator: "+");
707}
708
709static LogicalResult printOperation(CppEmitter &emitter,
710 emitc::UnaryMinusOp unaryMinusOp) {
711 Operation *operation = unaryMinusOp.getOperation();
712 return printUnaryOperation(emitter, operation, unaryOperator: "-");
713}
714
715static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
716 raw_ostream &os = emitter.ostream();
717 Operation &op = *castOp.getOperation();
718
719 if (failed(result: emitter.emitAssignPrefix(op)))
720 return failure();
721 os << "(";
722 if (failed(result: emitter.emitType(loc: op.getLoc(), type: op.getResult(idx: 0).getType())))
723 return failure();
724 os << ") ";
725 return emitter.emitOperand(value: castOp.getOperand());
726}
727
728static LogicalResult printOperation(CppEmitter &emitter,
729 emitc::ExpressionOp expressionOp) {
730 if (shouldBeInlined(expressionOp))
731 return success();
732
733 Operation &op = *expressionOp.getOperation();
734
735 if (failed(result: emitter.emitAssignPrefix(op)))
736 return failure();
737
738 return emitter.emitExpression(expressionOp);
739}
740
741static LogicalResult printOperation(CppEmitter &emitter,
742 emitc::IncludeOp includeOp) {
743 raw_ostream &os = emitter.ostream();
744
745 os << "#include ";
746 if (includeOp.getIsStandardInclude())
747 os << "<" << includeOp.getInclude() << ">";
748 else
749 os << "\"" << includeOp.getInclude() << "\"";
750
751 return success();
752}
753
754static LogicalResult printOperation(CppEmitter &emitter,
755 emitc::LogicalAndOp logicalAndOp) {
756 Operation *operation = logicalAndOp.getOperation();
757 return printBinaryOperation(emitter, operation, binaryOperator: "&&");
758}
759
760static LogicalResult printOperation(CppEmitter &emitter,
761 emitc::LogicalNotOp logicalNotOp) {
762 Operation *operation = logicalNotOp.getOperation();
763 return printUnaryOperation(emitter, operation, unaryOperator: "!");
764}
765
766static LogicalResult printOperation(CppEmitter &emitter,
767 emitc::LogicalOrOp logicalOrOp) {
768 Operation *operation = logicalOrOp.getOperation();
769 return printBinaryOperation(emitter, operation, binaryOperator: "||");
770}
771
772static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
773
774 raw_indented_ostream &os = emitter.ostream();
775
776 // Utility function to determine whether a value is an expression that will be
777 // inlined, and as such should be wrapped in parentheses in order to guarantee
778 // its precedence and associativity.
779 auto requiresParentheses = [&](Value value) {
780 auto expressionOp =
781 dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
782 if (!expressionOp)
783 return false;
784 return shouldBeInlined(expressionOp);
785 };
786
787 os << "for (";
788 if (failed(
789 emitter.emitType(loc: forOp.getLoc(), type: forOp.getInductionVar().getType())))
790 return failure();
791 os << " ";
792 os << emitter.getOrCreateName(forOp.getInductionVar());
793 os << " = ";
794 if (failed(emitter.emitOperand(value: forOp.getLowerBound())))
795 return failure();
796 os << "; ";
797 os << emitter.getOrCreateName(forOp.getInductionVar());
798 os << " < ";
799 Value upperBound = forOp.getUpperBound();
800 bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
801 if (upperBoundRequiresParentheses)
802 os << "(";
803 if (failed(result: emitter.emitOperand(value: upperBound)))
804 return failure();
805 if (upperBoundRequiresParentheses)
806 os << ")";
807 os << "; ";
808 os << emitter.getOrCreateName(forOp.getInductionVar());
809 os << " += ";
810 if (failed(emitter.emitOperand(value: forOp.getStep())))
811 return failure();
812 os << ") {\n";
813 os.indent();
814
815 Region &forRegion = forOp.getRegion();
816 auto regionOps = forRegion.getOps();
817
818 // We skip the trailing yield op.
819 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
820 if (failed(emitter.emitOperation(op&: *it, /*trailingSemicolon=*/true)))
821 return failure();
822 }
823
824 os.unindent() << "}";
825
826 return success();
827}
828
829static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
830 raw_indented_ostream &os = emitter.ostream();
831
832 // Helper function to emit all ops except the last one, expected to be
833 // emitc::yield.
834 auto emitAllExceptLast = [&emitter](Region &region) {
835 Region::OpIterator it = region.op_begin(), end = region.op_end();
836 for (; std::next(x: it) != end; ++it) {
837 if (failed(result: emitter.emitOperation(op&: *it, /*trailingSemicolon=*/true)))
838 return failure();
839 }
840 assert(isa<emitc::YieldOp>(*it) &&
841 "Expected last operation in the region to be emitc::yield");
842 return success();
843 };
844
845 os << "if (";
846 if (failed(emitter.emitOperand(value: ifOp.getCondition())))
847 return failure();
848 os << ") {\n";
849 os.indent();
850 if (failed(emitAllExceptLast(ifOp.getThenRegion())))
851 return failure();
852 os.unindent() << "}";
853
854 Region &elseRegion = ifOp.getElseRegion();
855 if (!elseRegion.empty()) {
856 os << " else {\n";
857 os.indent();
858 if (failed(result: emitAllExceptLast(elseRegion)))
859 return failure();
860 os.unindent() << "}";
861 }
862
863 return success();
864}
865
866static LogicalResult printOperation(CppEmitter &emitter,
867 func::ReturnOp returnOp) {
868 raw_ostream &os = emitter.ostream();
869 os << "return";
870 switch (returnOp.getNumOperands()) {
871 case 0:
872 return success();
873 case 1:
874 os << " ";
875 if (failed(emitter.emitOperand(value: returnOp.getOperand(0))))
876 return failure();
877 return success();
878 default:
879 os << " std::make_tuple(";
880 if (failed(emitter.emitOperandsAndAttributes(op&: *returnOp.getOperation())))
881 return failure();
882 os << ")";
883 return success();
884 }
885}
886
887static LogicalResult printOperation(CppEmitter &emitter,
888 emitc::ReturnOp returnOp) {
889 raw_ostream &os = emitter.ostream();
890 os << "return";
891 if (returnOp.getNumOperands() == 0)
892 return success();
893
894 os << " ";
895 if (failed(emitter.emitOperand(value: returnOp.getOperand())))
896 return failure();
897 return success();
898}
899
900static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
901 CppEmitter::Scope scope(emitter);
902
903 for (Operation &op : moduleOp) {
904 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
905 return failure();
906 }
907 return success();
908}
909
910static LogicalResult printFunctionArgs(CppEmitter &emitter,
911 Operation *functionOp,
912 ArrayRef<Type> arguments) {
913 raw_indented_ostream &os = emitter.ostream();
914
915 return (
916 interleaveCommaWithError(c: arguments, os, eachFn: [&](Type arg) -> LogicalResult {
917 return emitter.emitType(loc: functionOp->getLoc(), type: arg);
918 }));
919}
920
921static LogicalResult printFunctionArgs(CppEmitter &emitter,
922 Operation *functionOp,
923 Region::BlockArgListType arguments) {
924 raw_indented_ostream &os = emitter.ostream();
925
926 return (interleaveCommaWithError(
927 c: arguments, os, eachFn: [&](BlockArgument arg) -> LogicalResult {
928 return emitter.emitVariableDeclaration(
929 loc: functionOp->getLoc(), type: arg.getType(), name: emitter.getOrCreateName(val: arg));
930 }));
931}
932
933static LogicalResult printFunctionBody(CppEmitter &emitter,
934 Operation *functionOp,
935 Region::BlockListType &blocks) {
936 raw_indented_ostream &os = emitter.ostream();
937 os.indent();
938
939 if (emitter.shouldDeclareVariablesAtTop()) {
940 // Declare all variables that hold op results including those from nested
941 // regions.
942 WalkResult result =
943 functionOp->walk<WalkOrder::PreOrder>(callback: [&](Operation *op) -> WalkResult {
944 if (isa<emitc::LiteralOp>(op) ||
945 isa<emitc::ExpressionOp>(op->getParentOp()) ||
946 (isa<emitc::ExpressionOp>(op) &&
947 shouldBeInlined(cast<emitc::ExpressionOp>(op))))
948 return WalkResult::skip();
949 for (OpResult result : op->getResults()) {
950 if (failed(result: emitter.emitVariableDeclaration(
951 result, /*trailingSemicolon=*/true))) {
952 return WalkResult(
953 op->emitError(message: "unable to declare result variable for op"));
954 }
955 }
956 return WalkResult::advance();
957 });
958 if (result.wasInterrupted())
959 return failure();
960 }
961
962 // Create label names for basic blocks.
963 for (Block &block : blocks) {
964 emitter.getOrCreateName(block);
965 }
966
967 // Declare variables for basic block arguments.
968 for (Block &block : llvm::drop_begin(RangeOrContainer&: blocks)) {
969 for (BlockArgument &arg : block.getArguments()) {
970 if (emitter.hasValueInScope(val: arg))
971 return functionOp->emitOpError(message: " block argument #")
972 << arg.getArgNumber() << " is out of scope";
973 if (isa<ArrayType>(arg.getType()))
974 return functionOp->emitOpError(message: "cannot emit block argument #")
975 << arg.getArgNumber() << " with array type";
976 if (failed(
977 result: emitter.emitType(loc: block.getParentOp()->getLoc(), type: arg.getType()))) {
978 return failure();
979 }
980 os << " " << emitter.getOrCreateName(val: arg) << ";\n";
981 }
982 }
983
984 for (Block &block : blocks) {
985 // Only print a label if the block has predecessors.
986 if (!block.hasNoPredecessors()) {
987 if (failed(result: emitter.emitLabel(block)))
988 return failure();
989 }
990 for (Operation &op : block.getOperations()) {
991 // When generating code for an emitc.if or cf.cond_br op no semicolon
992 // needs to be printed after the closing brace.
993 // When generating code for an emitc.for and emitc.verbatim op, printing a
994 // trailing semicolon is handled within the printOperation function.
995 bool trailingSemicolon =
996 !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
997 emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
998
999 if (failed(result: emitter.emitOperation(
1000 op, /*trailingSemicolon=*/trailingSemicolon)))
1001 return failure();
1002 }
1003 }
1004
1005 os.unindent();
1006
1007 return success();
1008}
1009
1010static LogicalResult printOperation(CppEmitter &emitter,
1011 func::FuncOp functionOp) {
1012 // We need to declare variables at top if the function has multiple blocks.
1013 if (!emitter.shouldDeclareVariablesAtTop() &&
1014 functionOp.getBlocks().size() > 1) {
1015 return functionOp.emitOpError(
1016 "with multiple blocks needs variables declared at top");
1017 }
1018
1019 if (llvm::any_of(functionOp.getResultTypes(), llvm::IsaPred<ArrayType>)) {
1020 return functionOp.emitOpError() << "cannot emit array type as result type";
1021 }
1022
1023 CppEmitter::Scope scope(emitter);
1024 raw_indented_ostream &os = emitter.ostream();
1025 if (failed(emitter.emitTypes(loc: functionOp.getLoc(),
1026 types: functionOp.getFunctionType().getResults())))
1027 return failure();
1028 os << " " << functionOp.getName();
1029
1030 os << "(";
1031 Operation *operation = functionOp.getOperation();
1032 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1033 return failure();
1034 os << ") {\n";
1035 if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1036 return failure();
1037 os << "}\n";
1038
1039 return success();
1040}
1041
1042static LogicalResult printOperation(CppEmitter &emitter,
1043 emitc::FuncOp functionOp) {
1044 // We need to declare variables at top if the function has multiple blocks.
1045 if (!emitter.shouldDeclareVariablesAtTop() &&
1046 functionOp.getBlocks().size() > 1) {
1047 return functionOp.emitOpError(
1048 "with multiple blocks needs variables declared at top");
1049 }
1050
1051 CppEmitter::Scope scope(emitter);
1052 raw_indented_ostream &os = emitter.ostream();
1053 if (functionOp.getSpecifiers()) {
1054 for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1055 os << cast<StringAttr>(specifier).str() << " ";
1056 }
1057 }
1058
1059 if (failed(emitter.emitTypes(loc: functionOp.getLoc(),
1060 types: functionOp.getFunctionType().getResults())))
1061 return failure();
1062 os << " " << functionOp.getName();
1063
1064 os << "(";
1065 Operation *operation = functionOp.getOperation();
1066 if (functionOp.isExternal()) {
1067 if (failed(printFunctionArgs(emitter, operation,
1068 functionOp.getArgumentTypes())))
1069 return failure();
1070 os << ");";
1071 return success();
1072 }
1073 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1074 return failure();
1075 os << ") {\n";
1076 if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
1077 return failure();
1078 os << "}\n";
1079
1080 return success();
1081}
1082
1083static LogicalResult printOperation(CppEmitter &emitter,
1084 DeclareFuncOp declareFuncOp) {
1085 CppEmitter::Scope scope(emitter);
1086 raw_indented_ostream &os = emitter.ostream();
1087
1088 auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
1089 declareFuncOp, declareFuncOp.getSymNameAttr());
1090
1091 if (!functionOp)
1092 return failure();
1093
1094 if (functionOp.getSpecifiers()) {
1095 for (Attribute specifier : functionOp.getSpecifiersAttr()) {
1096 os << cast<StringAttr>(specifier).str() << " ";
1097 }
1098 }
1099
1100 if (failed(emitter.emitTypes(loc: functionOp.getLoc(),
1101 types: functionOp.getFunctionType().getResults())))
1102 return failure();
1103 os << " " << functionOp.getName();
1104
1105 os << "(";
1106 Operation *operation = functionOp.getOperation();
1107 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
1108 return failure();
1109 os << ");";
1110
1111 return success();
1112}
1113
1114CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
1115 : os(os), declareVariablesAtTop(declareVariablesAtTop) {
1116 valueInScopeCount.push(x: 0);
1117 labelInScopeCount.push(x: 0);
1118}
1119
1120std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1121 std::string out;
1122 llvm::raw_string_ostream ss(out);
1123 ss << getOrCreateName(op.getValue());
1124 for (auto index : op.getIndices()) {
1125 ss << "[" << getOrCreateName(index) << "]";
1126 }
1127 return out;
1128}
1129
1130/// Return the existing or a new name for a Value.
1131StringRef CppEmitter::getOrCreateName(Value val) {
1132 if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
1133 return literal.getValue();
1134 if (!valueMapper.count(Key: val)) {
1135 if (auto subscript =
1136 dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
1137 valueMapper.insert(val, getSubscriptName(subscript));
1138 } else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
1139 val.getDefiningOp())) {
1140 valueMapper.insert(Key: val, Val: getGlobal.getName().str());
1141 } else {
1142 valueMapper.insert(Key: val, Val: formatv(Fmt: "v{0}", Vals&: ++valueInScopeCount.top()));
1143 }
1144 }
1145 return *valueMapper.begin(Key: val);
1146}
1147
1148/// Return the existing or a new label for a Block.
1149StringRef CppEmitter::getOrCreateName(Block &block) {
1150 if (!blockMapper.count(Key: &block))
1151 blockMapper.insert(Key: &block, Val: formatv(Fmt: "label{0}", Vals&: ++labelInScopeCount.top()));
1152 return *blockMapper.begin(Key: &block);
1153}
1154
1155bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
1156 switch (val) {
1157 case IntegerType::Signless:
1158 return false;
1159 case IntegerType::Signed:
1160 return false;
1161 case IntegerType::Unsigned:
1162 return true;
1163 }
1164 llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1165}
1166
1167bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(Key: val); }
1168
1169bool CppEmitter::hasBlockLabel(Block &block) {
1170 return blockMapper.count(Key: &block);
1171}
1172
1173LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1174 auto printInt = [&](const APInt &val, bool isUnsigned) {
1175 if (val.getBitWidth() == 1) {
1176 if (val.getBoolValue())
1177 os << "true";
1178 else
1179 os << "false";
1180 } else {
1181 SmallString<128> strValue;
1182 val.toString(Str&: strValue, Radix: 10, Signed: !isUnsigned, formatAsCLiteral: false);
1183 os << strValue;
1184 }
1185 };
1186
1187 auto printFloat = [&](const APFloat &val) {
1188 if (val.isFinite()) {
1189 SmallString<128> strValue;
1190 // Use default values of toString except don't truncate zeros.
1191 val.toString(Str&: strValue, FormatPrecision: 0, FormatMaxPadding: 0, TruncateZero: false);
1192 os << strValue;
1193 switch (llvm::APFloatBase::SemanticsToEnum(Sem: val.getSemantics())) {
1194 case llvm::APFloatBase::S_IEEEsingle:
1195 os << "f";
1196 break;
1197 case llvm::APFloatBase::S_IEEEdouble:
1198 break;
1199 default:
1200 llvm_unreachable("unsupported floating point type");
1201 };
1202 } else if (val.isNaN()) {
1203 os << "NAN";
1204 } else if (val.isInfinity()) {
1205 if (val.isNegative())
1206 os << "-";
1207 os << "INFINITY";
1208 }
1209 };
1210
1211 // Print floating point attributes.
1212 if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1213 if (!isa<Float32Type, Float64Type>(fAttr.getType())) {
1214 return emitError(loc,
1215 message: "expected floating point attribute to be f32 or f64");
1216 }
1217 printFloat(fAttr.getValue());
1218 return success();
1219 }
1220 if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1221 if (!isa<Float32Type, Float64Type>(dense.getElementType())) {
1222 return emitError(loc,
1223 message: "expected floating point attribute to be f32 or f64");
1224 }
1225 os << '{';
1226 interleaveComma(c: dense, os, each_fn: [&](const APFloat &val) { printFloat(val); });
1227 os << '}';
1228 return success();
1229 }
1230
1231 // Print integer attributes.
1232 if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1233 if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1234 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1235 return success();
1236 }
1237 if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1238 printInt(iAttr.getValue(), false);
1239 return success();
1240 }
1241 }
1242 if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1243 if (auto iType = dyn_cast<IntegerType>(
1244 cast<TensorType>(dense.getType()).getElementType())) {
1245 os << '{';
1246 interleaveComma(c: dense, os, each_fn: [&](const APInt &val) {
1247 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1248 });
1249 os << '}';
1250 return success();
1251 }
1252 if (auto iType = dyn_cast<IndexType>(
1253 cast<TensorType>(dense.getType()).getElementType())) {
1254 os << '{';
1255 interleaveComma(c: dense, os,
1256 each_fn: [&](const APInt &val) { printInt(val, false); });
1257 os << '}';
1258 return success();
1259 }
1260 }
1261
1262 // Print opaque attributes.
1263 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1264 os << oAttr.getValue();
1265 return success();
1266 }
1267
1268 // Print symbolic reference attributes.
1269 if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1270 if (sAttr.getNestedReferences().size() > 1)
1271 return emitError(loc, message: "attribute has more than 1 nested reference");
1272 os << sAttr.getRootReference().getValue();
1273 return success();
1274 }
1275
1276 // Print type attributes.
1277 if (auto type = dyn_cast<TypeAttr>(attr))
1278 return emitType(loc, type: type.getValue());
1279
1280 return emitError(loc, message: "cannot emit attribute: ") << attr;
1281}
1282
1283LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1284 assert(emittedExpressionPrecedence.empty() &&
1285 "Expected precedence stack to be empty");
1286 Operation *rootOp = expressionOp.getRootOp();
1287
1288 emittedExpression = expressionOp;
1289 FailureOr<int> precedence = getOperatorPrecedence(operation: rootOp);
1290 if (failed(result: precedence))
1291 return failure();
1292 pushExpressionPrecedence(precedence: precedence.value());
1293
1294 if (failed(result: emitOperation(op&: *rootOp, /*trailingSemicolon=*/false)))
1295 return failure();
1296
1297 popExpressionPrecedence();
1298 assert(emittedExpressionPrecedence.empty() &&
1299 "Expected precedence stack to be empty");
1300 emittedExpression = nullptr;
1301
1302 return success();
1303}
1304
1305LogicalResult CppEmitter::emitOperand(Value value) {
1306 if (isPartOfCurrentExpression(value)) {
1307 Operation *def = value.getDefiningOp();
1308 assert(def && "Expected operand to be defined by an operation");
1309 FailureOr<int> precedence = getOperatorPrecedence(operation: def);
1310 if (failed(result: precedence))
1311 return failure();
1312 bool encloseInParenthesis = precedence.value() < getExpressionPrecedence();
1313 if (encloseInParenthesis) {
1314 os << "(";
1315 pushExpressionPrecedence(precedence: lowestPrecedence());
1316 } else
1317 pushExpressionPrecedence(precedence: precedence.value());
1318
1319 if (failed(result: emitOperation(op&: *def, /*trailingSemicolon=*/false)))
1320 return failure();
1321
1322 if (encloseInParenthesis)
1323 os << ")";
1324
1325 popExpressionPrecedence();
1326 return success();
1327 }
1328
1329 auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1330 if (expressionOp && shouldBeInlined(expressionOp))
1331 return emitExpression(expressionOp);
1332
1333 auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
1334 if (!literalOp && !hasValueInScope(val: value))
1335 return failure();
1336 os << getOrCreateName(val: value);
1337 return success();
1338}
1339
1340LogicalResult CppEmitter::emitOperands(Operation &op) {
1341 return interleaveCommaWithError(c: op.getOperands(), os, eachFn: [&](Value operand) {
1342 // If an expression is being emitted, push lowest precedence as these
1343 // operands are either wrapped by parenthesis.
1344 if (getEmittedExpression())
1345 pushExpressionPrecedence(precedence: lowestPrecedence());
1346 if (failed(result: emitOperand(value: operand)))
1347 return failure();
1348 if (getEmittedExpression())
1349 popExpressionPrecedence();
1350 return success();
1351 });
1352}
1353
1354LogicalResult
1355CppEmitter::emitOperandsAndAttributes(Operation &op,
1356 ArrayRef<StringRef> exclude) {
1357 if (failed(result: emitOperands(op)))
1358 return failure();
1359 // Insert comma in between operands and non-filtered attributes if needed.
1360 if (op.getNumOperands() > 0) {
1361 for (NamedAttribute attr : op.getAttrs()) {
1362 if (!llvm::is_contained(exclude, attr.getName().strref())) {
1363 os << ", ";
1364 break;
1365 }
1366 }
1367 }
1368 // Emit attributes.
1369 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1370 if (llvm::is_contained(exclude, attr.getName().strref()))
1371 return success();
1372 os << "/* " << attr.getName().getValue() << " */";
1373 if (failed(result: emitAttribute(loc: op.getLoc(), attr: attr.getValue())))
1374 return failure();
1375 return success();
1376 };
1377 return interleaveCommaWithError(c: op.getAttrs(), os, eachFn: emitNamedAttribute);
1378}
1379
1380LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1381 if (!hasValueInScope(val: result)) {
1382 return result.getDefiningOp()->emitOpError(
1383 message: "result variable for the operation has not been declared");
1384 }
1385 os << getOrCreateName(val: result) << " = ";
1386 return success();
1387}
1388
1389LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1390 bool trailingSemicolon) {
1391 if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
1392 return success();
1393 if (hasValueInScope(val: result)) {
1394 return result.getDefiningOp()->emitError(
1395 message: "result variable for the operation already declared");
1396 }
1397 if (failed(result: emitVariableDeclaration(loc: result.getOwner()->getLoc(),
1398 type: result.getType(),
1399 name: getOrCreateName(val: result))))
1400 return failure();
1401 if (trailingSemicolon)
1402 os << ";\n";
1403 return success();
1404}
1405
1406LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
1407 if (op.getExternSpecifier())
1408 os << "extern ";
1409 else if (op.getStaticSpecifier())
1410 os << "static ";
1411 if (op.getConstSpecifier())
1412 os << "const ";
1413
1414 if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
1415 op.getSymName()))) {
1416 return failure();
1417 }
1418
1419 std::optional<Attribute> initialValue = op.getInitialValue();
1420 if (initialValue) {
1421 os << " = ";
1422 if (failed(emitAttribute(loc: op->getLoc(), attr: *initialValue)))
1423 return failure();
1424 }
1425
1426 os << ";";
1427 return success();
1428}
1429
1430LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1431 // If op is being emitted as part of an expression, bail out.
1432 if (getEmittedExpression())
1433 return success();
1434
1435 switch (op.getNumResults()) {
1436 case 0:
1437 break;
1438 case 1: {
1439 OpResult result = op.getResult(idx: 0);
1440 if (shouldDeclareVariablesAtTop()) {
1441 if (failed(result: emitVariableAssignment(result)))
1442 return failure();
1443 } else {
1444 if (failed(result: emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1445 return failure();
1446 os << " = ";
1447 }
1448 break;
1449 }
1450 default:
1451 if (!shouldDeclareVariablesAtTop()) {
1452 for (OpResult result : op.getResults()) {
1453 if (failed(result: emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1454 return failure();
1455 }
1456 }
1457 os << "std::tie(";
1458 interleaveComma(c: op.getResults(), os,
1459 each_fn: [&](Value result) { os << getOrCreateName(val: result); });
1460 os << ") = ";
1461 }
1462 return success();
1463}
1464
1465LogicalResult CppEmitter::emitLabel(Block &block) {
1466 if (!hasBlockLabel(block))
1467 return block.getParentOp()->emitError(message: "label for block not found");
1468 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1469 // label instead of using `getOStream`.
1470 os.getOStream() << getOrCreateName(block) << ":\n";
1471 return success();
1472}
1473
1474LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1475 LogicalResult status =
1476 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
1477 // Builtin ops.
1478 .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1479 // CF ops.
1480 .Case<cf::BranchOp, cf::CondBranchOp>(
1481 [&](auto op) { return printOperation(*this, op); })
1482 // EmitC ops.
1483 .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
1484 emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
1485 emitc::BitwiseNotOp, emitc::BitwiseOrOp,
1486 emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
1487 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1488 emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1489 emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1490 emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
1491 emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1492 emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1493 emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
1494 emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
1495 [&](auto op) { return printOperation(*this, op); })
1496 // Func ops.
1497 .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
1498 [&](auto op) { return printOperation(*this, op); })
1499 .Case<emitc::LiteralOp>([&](auto op) { return success(); })
1500 .Default([&](Operation *) {
1501 return op.emitOpError("unable to find printer for op");
1502 });
1503
1504 if (failed(result: status))
1505 return failure();
1506
1507 if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
1508 return success();
1509
1510 if (getEmittedExpression() ||
1511 (isa<emitc::ExpressionOp>(op) &&
1512 shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1513 return success();
1514
1515 os << (trailingSemicolon ? ";\n" : "\n");
1516
1517 return success();
1518}
1519
1520LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1521 StringRef name) {
1522 if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1523 if (failed(emitType(loc, type: arrType.getElementType())))
1524 return failure();
1525 os << " " << name;
1526 for (auto dim : arrType.getShape()) {
1527 os << "[" << dim << "]";
1528 }
1529 return success();
1530 }
1531 if (failed(result: emitType(loc, type)))
1532 return failure();
1533 os << " " << name;
1534 return success();
1535}
1536
1537LogicalResult CppEmitter::emitType(Location loc, Type type) {
1538 if (auto iType = dyn_cast<IntegerType>(type)) {
1539 switch (iType.getWidth()) {
1540 case 1:
1541 return (os << "bool"), success();
1542 case 8:
1543 case 16:
1544 case 32:
1545 case 64:
1546 if (shouldMapToUnsigned(iType.getSignedness()))
1547 return (os << "uint" << iType.getWidth() << "_t"), success();
1548 else
1549 return (os << "int" << iType.getWidth() << "_t"), success();
1550 default:
1551 return emitError(loc, message: "cannot emit integer type ") << type;
1552 }
1553 }
1554 if (auto fType = dyn_cast<FloatType>(Val&: type)) {
1555 switch (fType.getWidth()) {
1556 case 32:
1557 return (os << "float"), success();
1558 case 64:
1559 return (os << "double"), success();
1560 default:
1561 return emitError(loc, message: "cannot emit float type ") << type;
1562 }
1563 }
1564 if (auto iType = dyn_cast<IndexType>(type))
1565 return (os << "size_t"), success();
1566 if (auto tType = dyn_cast<TensorType>(Val&: type)) {
1567 if (!tType.hasRank())
1568 return emitError(loc, message: "cannot emit unranked tensor type");
1569 if (!tType.hasStaticShape())
1570 return emitError(loc, message: "cannot emit tensor type with non static shape");
1571 os << "Tensor<";
1572 if (isa<ArrayType>(tType.getElementType()))
1573 return emitError(loc, message: "cannot emit tensor of array type ") << type;
1574 if (failed(result: emitType(loc, type: tType.getElementType())))
1575 return failure();
1576 auto shape = tType.getShape();
1577 for (auto dimSize : shape) {
1578 os << ", ";
1579 os << dimSize;
1580 }
1581 os << ">";
1582 return success();
1583 }
1584 if (auto tType = dyn_cast<TupleType>(type))
1585 return emitTupleType(loc, types: tType.getTypes());
1586 if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1587 os << oType.getValue();
1588 return success();
1589 }
1590 if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1591 if (failed(emitType(loc, type: aType.getElementType())))
1592 return failure();
1593 for (auto dim : aType.getShape())
1594 os << "[" << dim << "]";
1595 return success();
1596 }
1597 if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1598 if (isa<ArrayType>(pType.getPointee()))
1599 return emitError(loc, message: "cannot emit pointer to array type ") << type;
1600 if (failed(emitType(loc, type: pType.getPointee())))
1601 return failure();
1602 os << "*";
1603 return success();
1604 }
1605 return emitError(loc, message: "cannot emit type ") << type;
1606}
1607
1608LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1609 switch (types.size()) {
1610 case 0:
1611 os << "void";
1612 return success();
1613 case 1:
1614 return emitType(loc, type: types.front());
1615 default:
1616 return emitTupleType(loc, types);
1617 }
1618}
1619
1620LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1621 if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
1622 return emitError(loc, message: "cannot emit tuple of array type");
1623 }
1624 os << "std::tuple<";
1625 if (failed(result: interleaveCommaWithError(
1626 c: types, os, eachFn: [&](Type type) { return emitType(loc, type); })))
1627 return failure();
1628 os << ">";
1629 return success();
1630}
1631
1632LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1633 bool declareVariablesAtTop) {
1634 CppEmitter emitter(os, declareVariablesAtTop);
1635 return emitter.emitOperation(op&: *op, /*trailingSemicolon=*/false);
1636}
1637

// Source: mlir/lib/Target/Cpp/TranslateToCpp.cpp