1//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/ArrayRef.h"

#include <atomic>
#include <string>
16
17using namespace mlir;
18using namespace llvm;
19
20/// Check if a given symbol name is already in use within the module operation.
21/// If no symbol with such name is present, then the same identifier is
22/// returned. Otherwise, a unique and yet unused identifier is computed starting
23/// from the requested one.
24static std::string
25ensureSymbolNameIsUnique(ModuleOp moduleOp, StringRef symbolName,
26 SymbolTableCollection *symbolTables = nullptr) {
27 if (symbolTables) {
28 SymbolTable &symbolTable = symbolTables->getSymbolTable(op: moduleOp);
29 unsigned counter = 0;
30 SmallString<128> uniqueName = symbolTable.generateSymbolName<128>(
31 name: symbolName,
32 uniqueChecker: [&](const SmallString<128> &tentativeName) {
33 return symbolTable.lookupSymbolIn(op: moduleOp, symbol: tentativeName) != nullptr;
34 },
35 uniquingCounter&: counter);
36
37 return static_cast<std::string>(uniqueName);
38 }
39
40 static int counter = 0;
41 std::string uniqueName = std::string(symbolName);
42 while (moduleOp.lookupSymbol(name: uniqueName)) {
43 uniqueName = std::string(symbolName) + "_" + std::to_string(val: counter++);
44 }
45 return uniqueName;
46}
47
48LogicalResult mlir::LLVM::createPrintStrCall(
49 OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
50 StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
51 std::optional<StringRef> runtimeFunctionName,
52 SymbolTableCollection *symbolTables) {
53 auto ip = builder.saveInsertionPoint();
54 builder.setInsertionPointToStart(moduleOp.getBody());
55 MLIRContext *ctx = builder.getContext();
56
57 // Create a zero-terminated byte representation and allocate global symbol.
58 SmallVector<uint8_t> elementVals;
59 elementVals.append(in_start: string.begin(), in_end: string.end());
60 if (addNewline)
61 elementVals.push_back(Elt: '\n');
62 elementVals.push_back(Elt: '\0');
63 auto dataAttrType = RankedTensorType::get(
64 shape: {static_cast<int64_t>(elementVals.size())}, elementType: builder.getI8Type());
65 auto dataAttr =
66 DenseElementsAttr::get(type: dataAttrType, values: llvm::ArrayRef(elementVals));
67 auto arrayTy =
68 LLVM::LLVMArrayType::get(elementType: IntegerType::get(context: ctx, width: 8), numElements: elementVals.size());
69 auto globalOp = builder.create<LLVM::GlobalOp>(
70 location: loc, args&: arrayTy, /*constant=*/args: true, args: LLVM::Linkage::Private,
71 args: ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), args&: dataAttr);
72
73 auto ptrTy = LLVM::LLVMPointerType::get(context: builder.getContext());
74 // Emit call to `printStr` in runtime library.
75 builder.restoreInsertionPoint(ip);
76 auto msgAddr =
77 builder.create<LLVM::AddressOfOp>(location: loc, args&: ptrTy, args: globalOp.getName());
78 SmallVector<LLVM::GEPArg> indices(1, 0);
79 Value gep =
80 builder.create<LLVM::GEPOp>(location: loc, args&: ptrTy, args&: arrayTy, args&: msgAddr, args&: indices);
81 FailureOr<LLVM::LLVMFuncOp> printer =
82 LLVM::lookupOrCreatePrintStringFn(b&: builder, moduleOp, runtimeFunctionName);
83 if (failed(Result: printer))
84 return failure();
85 builder.create<LLVM::CallOp>(location: loc, args: TypeRange(),
86 args: SymbolRefAttr::get(symbol: printer.value()), args&: gep);
87 return success();
88}
89

source code of mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp