1//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
10#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
11#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
12#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "llvm/ADT/ArrayRef.h"
16
17using namespace mlir;
18using namespace llvm;
19
20static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
21 StringRef symbolName) {
22 static int counter = 0;
23 std::string uniqueName = std::string(symbolName);
24 while (moduleOp.lookupSymbol(uniqueName)) {
25 uniqueName = std::string(symbolName) + "_" + std::to_string(val: counter++);
26 }
27 return uniqueName;
28}
29
30void mlir::LLVM::createPrintStrCall(
31 OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
32 StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
33 std::optional<StringRef> runtimeFunctionName) {
34 auto ip = builder.saveInsertionPoint();
35 builder.setInsertionPointToStart(moduleOp.getBody());
36 MLIRContext *ctx = builder.getContext();
37
38 // Create a zero-terminated byte representation and allocate global symbol.
39 SmallVector<uint8_t> elementVals;
40 elementVals.append(in_start: string.begin(), in_end: string.end());
41 if (addNewline)
42 elementVals.push_back(Elt: '\n');
43 elementVals.push_back(Elt: '\0');
44 auto dataAttrType = RankedTensorType::get(
45 {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
46 auto dataAttr =
47 DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
48 auto arrayTy =
49 LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
50 auto globalOp = builder.create<LLVM::GlobalOp>(
51 loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
52 ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
53
54 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
55 // Emit call to `printStr` in runtime library.
56 builder.restoreInsertionPoint(ip);
57 auto msgAddr =
58 builder.create<LLVM::AddressOfOp>(loc, ptrTy, globalOp.getName());
59 SmallVector<LLVM::GEPArg> indices(1, 0);
60 Value gep =
61 builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
62 Operation *printer =
63 LLVM::lookupOrCreatePrintStringFn(moduleOp: moduleOp, runtimeFunctionName);
64 builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
65 gep);
66}
67

// Source code of mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp