//===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace NVVM {
#define GEN_PASS_DEF_NVVMOPTIMIZEFORTARGET
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
} // namespace NVVM
} // namespace mlir

using namespace mlir;

namespace {
// Replaces fdiv on fp16 with fp32 multiplication by the reciprocal plus one
// (conditional) Newton iteration.
//
// This is as accurate as promoting the division to fp32 in the NVPTX backend,
// but faster because it performs fewer Newton iterations, avoids the slow path
// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
// by the same divisor.
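//
// With rcp ~= 1/rhs, the rewrite computes
//   approx  = lhs * rcp
//   err     = lhs - approx * rhs   (residual of the approximation)
//   refined = approx + err * rcp   (one Newton-Raphson correction)
// and keeps approx instead of refined whenever approx is not a normal fp32
// value.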
struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
  using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;

private:
  LogicalResult matchAndRewrite(LLVM::FDivOp op,
                                PatternRewriter &rewriter) const override;
};

struct NVVMOptimizeForTarget
    : public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
  void runOnOperation() override;

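  // ExpandDivF16 creates NVVM ops (RcpApproxFtzF32Op), so the NVVM dialect
  // must be declared as a dependency to make sure it is loaded before the
  // pass runs.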
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<NVVM::NVVMDialect>();
  }
};
} // namespace

LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
                                            PatternRewriter &rewriter) const {
  if (!op.getType().isF16())
    return rewriter.notifyMatchFailure(op, "not f16");
  Location loc = op.getLoc();

  Type f32Type = rewriter.getF32Type();
  Type i32Type = rewriter.getI32Type();

  // Extend lhs and rhs to fp32.
  Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
  Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());

  // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
  Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
  Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);

  // Refine the approximation with one Newton iteration:
  // float refined = approx + (lhs - approx * rhs) * rcp;
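  // The residual (lhs - approx * rhs) is formed as fma(approx, -rhs, lhs),
  // i.e. with a single rounding, and multiplying it by rcp (~1/rhs) turns the
  // residual into a correction in units of the quotient.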
  Value err = rewriter.create<LLVM::FMAOp>(
      loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
  Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);

  // Use the refined value if approx is normal (exponent is neither all zeros
  // nor all ones).
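  // 0x7f800000 selects the exponent field of an IEEE-754 binary32 value:
  // exp == 0 means zero or denormal, exp == mask means inf or NaN. In those
  // cases the refined value is discarded and the plain approximation is kept.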
  Value mask = rewriter.create<LLVM::ConstantOp>(
      loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
  Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
  Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
  Value zero = rewriter.create<LLVM::ConstantOp>(
      loc, i32Type, rewriter.getUI32IntegerAttr(0));
  Value pred = rewriter.create<LLVM::OrOp>(
      loc,
      rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
      rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
  Value result =
      rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);

  // Replace the original op with a truncation back to fp16.
  rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);

  return success();
}

void NVVMOptimizeForTarget::runOnOperation() {
  MLIRContext *ctx = getOperation()->getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExpandDivF16>(ctx);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    return signalPassFailure();
}

std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
  return std::make_unique<NVVMOptimizeForTarget>();
}