1//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
10#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
11#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
15#include "mlir/Dialect/Math/IR/Math.h"
16#include "mlir/Dialect/Vector/IR/VectorOps.h"
17#include "mlir/IR/BuiltinDialect.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/DialectConversion.h"
21
22#include "../GPUCommon/GPUOpsLowering.h"
23#include "../GPUCommon/OpToFuncCallLowering.h"
24#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_CONVERTMATHTOROCDL
28#include "mlir/Conversion/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32
// Debug-output category for this file (used by LLVM_DEBUG / -debug-only).
// The former DBGS() helper macro was never referenced and has been removed.
#define DEBUG_TYPE "math-to-rocdl"
35
36template <typename OpTy>
37static void populateOpPatterns(const LLVMTypeConverter &converter,
38 RewritePatternSet &patterns, StringRef f32Func,
39 StringRef f64Func, StringRef f16Func,
40 StringRef f32ApproxFunc = "") {
41 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
42 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
43 f32ApproxFunc, f16Func);
44}
45
46void mlir::populateMathToROCDLConversionPatterns(
47 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
48 // Handled by mathToLLVM: math::AbsIOp
49 // Handled by mathToLLVM: math::AbsFOp
50 // Handled by mathToLLVM: math::CopySignOp
51 // Handled by mathToLLVM: math::CountLeadingZerosOp
52 // Handled by mathToLLVM: math::CountTrailingZerosOp
53 // Handled by mathToLLVM: math::CgPopOp
54 // Handled by mathToLLVM: math::ExpOp (32-bit only)
55 // Handled by mathToLLVM: math::FmaOp
56 // Handled by mathToLLVM: math::LogOp (32-bit only)
57 // FIXME: math::IPowIOp
58 // Handled by mathToLLVM: math::RoundEvenOp
59 // Handled by mathToLLVM: math::RoundOp
60 // Handled by mathToLLVM: math::SqrtOp
61 // Handled by mathToLLVM: math::TruncOp
62 populateOpPatterns<math::AcosOp>(converter, patterns, f32Func: "__ocml_acos_f32",
63 f64Func: "__ocml_acos_f64", f16Func: "__ocml_acos_f16");
64 populateOpPatterns<math::AcoshOp>(converter, patterns, f32Func: "__ocml_acosh_f32",
65 f64Func: "__ocml_acosh_f64", f16Func: "__ocml_acosh_f16");
66 populateOpPatterns<math::AsinOp>(converter, patterns, f32Func: "__ocml_asin_f32",
67 f64Func: "__ocml_asin_f64", f16Func: "__ocml_asin_f16");
68 populateOpPatterns<math::AsinhOp>(converter, patterns, f32Func: "__ocml_asinh_f32",
69 f64Func: "__ocml_asinh_f64", f16Func: "__ocml_asinh_f16");
70 populateOpPatterns<math::AtanOp>(converter, patterns, f32Func: "__ocml_atan_f32",
71 f64Func: "__ocml_atan_f64", f16Func: "__ocml_atan_f16");
72 populateOpPatterns<math::AtanhOp>(converter, patterns, f32Func: "__ocml_atanh_f32",
73 f64Func: "__ocml_atanh_f64", f16Func: "__ocml_atanh_f16");
74 populateOpPatterns<math::Atan2Op>(converter, patterns, f32Func: "__ocml_atan2_f32",
75 f64Func: "__ocml_atan2_f64", f16Func: "__ocml_atan2_f16");
76 populateOpPatterns<math::CbrtOp>(converter, patterns, f32Func: "__ocml_cbrt_f32",
77 f64Func: "__ocml_cbrt_f64", f16Func: "__ocml_cbrt_f16");
78 populateOpPatterns<math::CeilOp>(converter, patterns, f32Func: "__ocml_ceil_f32",
79 f64Func: "__ocml_ceil_f64", f16Func: "__ocml_ceil_f16");
80 populateOpPatterns<math::CosOp>(converter, patterns, f32Func: "__ocml_cos_f32",
81 f64Func: "__ocml_cos_f64", f16Func: "__ocml_cos_f16");
82 populateOpPatterns<math::CoshOp>(converter, patterns, f32Func: "__ocml_cosh_f32",
83 f64Func: "__ocml_cosh_f64", f16Func: "__ocml_cosh_f16");
84 populateOpPatterns<math::SinhOp>(converter, patterns, f32Func: "__ocml_sinh_f32",
85 f64Func: "__ocml_sinh_f64", f16Func: "__ocml_sinh_f16");
86 populateOpPatterns<math::ExpOp>(converter, patterns, f32Func: "", f64Func: "__ocml_exp_f64",
87 f16Func: "__ocml_exp_f16");
88 populateOpPatterns<math::Exp2Op>(converter, patterns, f32Func: "__ocml_exp2_f32",
89 f64Func: "__ocml_exp2_f64", f16Func: "__ocml_exp2_f16");
90 populateOpPatterns<math::ExpM1Op>(converter, patterns, f32Func: "__ocml_expm1_f32",
91 f64Func: "__ocml_expm1_f64", f16Func: "__ocml_expm1_f16");
92 populateOpPatterns<math::FloorOp>(converter, patterns, f32Func: "__ocml_floor_f32",
93 f64Func: "__ocml_floor_f64", f16Func: "__ocml_floor_f16");
94 populateOpPatterns<math::LogOp>(converter, patterns, f32Func: "", f64Func: "__ocml_log_f64",
95 f16Func: "__ocml_log_f16");
96 populateOpPatterns<math::Log10Op>(converter, patterns, f32Func: "__ocml_log10_f32",
97 f64Func: "__ocml_log10_f64", f16Func: "__ocml_log10_f16");
98 populateOpPatterns<math::Log1pOp>(converter, patterns, f32Func: "__ocml_log1p_f32",
99 f64Func: "__ocml_log1p_f64", f16Func: "__ocml_log1p_f16");
100 populateOpPatterns<math::Log2Op>(converter, patterns, f32Func: "__ocml_log2_f32",
101 f64Func: "__ocml_log2_f64", f16Func: "__ocml_log2_f16");
102 populateOpPatterns<math::PowFOp>(converter, patterns, f32Func: "__ocml_pow_f32",
103 f64Func: "__ocml_pow_f64", f16Func: "__ocml_pow_f16");
104 populateOpPatterns<math::RsqrtOp>(converter, patterns, f32Func: "__ocml_rsqrt_f32",
105 f64Func: "__ocml_rsqrt_f64", f16Func: "__ocml_rsqrt_f16");
106 populateOpPatterns<math::SinOp>(converter, patterns, f32Func: "__ocml_sin_f32",
107 f64Func: "__ocml_sin_f64", f16Func: "__ocml_sin_f16");
108 populateOpPatterns<math::TanhOp>(converter, patterns, f32Func: "__ocml_tanh_f32",
109 f64Func: "__ocml_tanh_f64", f16Func: "__ocml_tanh_f16");
110 populateOpPatterns<math::TanOp>(converter, patterns, f32Func: "__ocml_tan_f32",
111 f64Func: "__ocml_tan_f64", f16Func: "__ocml_tan_f16");
112 populateOpPatterns<math::ErfOp>(converter, patterns, f32Func: "__ocml_erf_f32",
113 f64Func: "__ocml_erf_f64", f16Func: "__ocml_erf_f16");
114 populateOpPatterns<math::ErfcOp>(converter, patterns, f32Func: "__ocml_erfc_f32",
115 f64Func: "__ocml_erfc_f64", f16Func: "__ocml_erfc_f16");
116 populateOpPatterns<math::FPowIOp>(converter, patterns, f32Func: "__ocml_pown_f32",
117 f64Func: "__ocml_pown_f64", f16Func: "__ocml_pown_f16");
118 // Single arith pattern that needs a ROCDL call, probably not
119 // worth creating a separate pass for it.
120 populateOpPatterns<arith::RemFOp>(converter, patterns, f32Func: "__ocml_fmod_f32",
121 f64Func: "__ocml_fmod_f64", f16Func: "__ocml_fmod_f16");
122}
123
124namespace {
125struct ConvertMathToROCDLPass
126 : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
127 ConvertMathToROCDLPass() = default;
128 void runOnOperation() override;
129};
130} // namespace
131
132void ConvertMathToROCDLPass::runOnOperation() {
133 auto m = getOperation();
134 MLIRContext *ctx = m.getContext();
135
136 RewritePatternSet patterns(&getContext());
137 LowerToLLVMOptions options(ctx, DataLayout(m));
138 LLVMTypeConverter converter(ctx, options);
139 populateMathToROCDLConversionPatterns(converter, patterns);
140 ConversionTarget target(getContext());
141 target.addLegalDialect<BuiltinDialect, func::FuncDialect,
142 vector::VectorDialect, LLVM::LLVMDialect>();
143 target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
144 LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
145 LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
146 LLVM::SqrtOp>();
147 if (failed(Result: applyPartialConversion(op: m, target, patterns: std::move(patterns))))
148 signalPassFailure();
149}
150

source code of mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp