1//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10
11#include "mlir/Conversion/LLVMCommon/Pattern.h"
12#include "mlir/Dialect/GPU/IR/GPUDialect.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/IR/Builders.h"
15
16namespace mlir {
17
18/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func`
19/// depending on the element type that Op operates upon. The function
20/// declaration is added in case it was not added before.
21///
22/// If the input values are of f16 type, the value is first casted to f32, the
23/// function called and then the result casted back.
24///
25/// Example with NVVM:
26/// %exp_f32 = math.exp %arg_f32 : f32
27///
28/// will be transformed into
29/// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
30template <typename SourceOp>
31struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
32public:
33 explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
34 StringRef f64Func)
35 : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
36 f64Func(f64Func) {}
37
38 LogicalResult
39 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
40 ConversionPatternRewriter &rewriter) const override {
41 using LLVM::LLVMFuncOp;
42
43 static_assert(
44 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
45 "expected single result op");
46
47 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
48 SourceOp>::value,
49 "expected op with same operand and result types");
50
51 SmallVector<Value, 1> castedOperands;
52 for (Value operand : adaptor.getOperands())
53 castedOperands.push_back(Elt: maybeCast(operand, rewriter));
54
55 Type resultType = castedOperands.front().getType();
56 Type funcType = getFunctionType(resultType, operands: castedOperands);
57 StringRef funcName =
58 getFunctionName(type: cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
59 if (funcName.empty())
60 return failure();
61
62 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
63 auto callOp =
64 rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands);
65
66 if (resultType == adaptor.getOperands().front().getType()) {
67 rewriter.replaceOp(op, {callOp.getResult()});
68 return success();
69 }
70
71 Value truncated = rewriter.create<LLVM::FPTruncOp>(
72 op->getLoc(), adaptor.getOperands().front().getType(),
73 callOp.getResult());
74 rewriter.replaceOp(op, {truncated});
75 return success();
76 }
77
78private:
79 Value maybeCast(Value operand, PatternRewriter &rewriter) const {
80 Type type = operand.getType();
81 if (!isa<Float16Type>(type))
82 return operand;
83
84 return rewriter.create<LLVM::FPExtOp>(
85 operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
86 }
87
88 Type getFunctionType(Type resultType, ValueRange operands) const {
89 SmallVector<Type> operandTypes(operands.getTypes());
90 return LLVM::LLVMFunctionType::get(resultType, operandTypes);
91 }
92
93 StringRef getFunctionName(Type type) const {
94 if (isa<Float32Type>(type))
95 return f32Func;
96 if (isa<Float64Type>(type))
97 return f64Func;
98 return "";
99 }
100
101 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
102 Operation *op) const {
103 using LLVM::LLVMFuncOp;
104
105 auto funcAttr = StringAttr::get(op->getContext(), funcName);
106 Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
107 if (funcOp)
108 return cast<LLVMFuncOp>(*funcOp);
109
110 mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
111 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
112 }
113
114 const std::string f32Func;
115 const std::string f64Func;
116};
117
118} // namespace mlir
119
120#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
121

source code of mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h