//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
28
29namespace {
30struct ArithToAMDGPUConversionPass final
31 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
32 using impl::ArithToAMDGPUConversionPassBase<
33 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
34
35 void runOnOperation() override;
36};
37
38struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
39 using OpRewritePattern::OpRewritePattern;
40
41 LogicalResult match(arith::ExtFOp op) const override;
42 void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
43};
44
45struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
46 bool saturateFP8 = false;
47 TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
48 : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
49
50 LogicalResult match(arith::TruncFOp op) const override;
51 void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
52};
53} // end namespace
54
55static Value castF32To(Type elementType, Value f32, Location loc,
56 PatternRewriter &rewriter) {
57 if (elementType.isF32())
58 return f32;
59 if (elementType.getIntOrFloatBitWidth() < 32)
60 return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
61 if (elementType.getIntOrFloatBitWidth() > 32)
62 return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
63 llvm_unreachable("The only 32-bit float type is f32");
64}
65
66LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
67 Type inType = op.getIn().getType();
68 if (auto inVecType = dyn_cast<VectorType>(inType)) {
69 if (inVecType.isScalable())
70 return failure();
71 if (inVecType.getShape().size() > 1)
72 // Multi-dimensional vectors are currently unsupported.
73 return failure();
74 inType = inVecType.getElementType();
75 }
76 return success(isSuccess: inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
77}
78
79void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
80 PatternRewriter &rewriter) const {
81 Location loc = op.getLoc();
82 Value in = op.getIn();
83 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
84 if (!isa<VectorType>(Val: in.getType())) {
85 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
86 loc, rewriter.getF32Type(), in, 0);
87 Value result = castF32To(elementType: outElemType, f32: asFloat, loc, rewriter);
88 return rewriter.replaceOp(op, result);
89 }
90 VectorType inType = cast<VectorType>(in.getType());
91 int64_t numElements = inType.getNumElements();
92 Value zero = rewriter.create<arith::ConstantOp>(
93 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
94 Value result =
95 rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
96 if (inType.getShape().empty()) {
97 Value scalarIn =
98 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
99 // Recurse to send the 0-D vector case to the 1-D vector case
100 Value scalarExt =
101 rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
102 result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
103 ArrayRef<int64_t>{});
104 return rewriter.replaceOp(op, result);
105 }
106 for (int64_t i = 0; i < numElements; i += 4) {
107 int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i;
108 Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
109 loc, in, i, elemsThisOp, 1);
110 for (int64_t j = 0; j < elemsThisOp; ++j) {
111 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
112 loc, rewriter.getF32Type(), inSlice, j);
113 Value asType = castF32To(elementType: outElemType, f32: asFloat, loc, rewriter);
114 result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
115 }
116 }
117 rewriter.replaceOp(op, result);
118}
119
120static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
121 Type type = value.getType();
122 if (type.isF32())
123 return value;
124 if (type.getIntOrFloatBitWidth() < 32)
125 return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
126 if (type.getIntOrFloatBitWidth() > 32)
127 return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
128 llvm_unreachable("The only 32-bit float type is f32");
129}
130
131// If `in` is a finite value, clamp it between the maximum and minimum values
132// of `outElemType` so that subsequent conversion instructions don't
133// overflow those out-of-range values to NaN. These semantics are commonly
134// used in machine-learning contexts where failure to clamp would lead to
135// excessive NaN production.
136static Value clampInput(PatternRewriter &rewriter, Location loc,
137 Type outElemType, Value source) {
138 Type sourceType = source.getType();
139 const llvm::fltSemantics &sourceSem =
140 cast<FloatType>(Val: getElementTypeOrSelf(type: sourceType)).getFloatSemantics();
141 const llvm::fltSemantics &targetSem =
142 cast<FloatType>(Val&: outElemType).getFloatSemantics();
143
144 APFloat min = APFloat::getLargest(Sem: targetSem, /*Negative=*/true);
145 APFloat max = APFloat::getLargest(Sem: targetSem, /*Negative=*/false);
146 bool ignoredLosesInfo = false;
147 // We can ignore conversion failures here because this conversion promotes
148 // from a smaller type to a larger one - ex. there can be no loss of precision
149 // when casting fp8 to f16.
150 (void)min.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo);
151 (void)max.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo);
152
153 Value minCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: min);
154 Value maxCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: max);
155
156 Value inf = createScalarOrSplatConstant(
157 builder&: rewriter, loc, type: sourceType,
158 value: APFloat::getInf(Sem: sourceSem, /*Negative=*/false));
159 Value negInf = createScalarOrSplatConstant(
160 builder&: rewriter, loc, type: sourceType, value: APFloat::getInf(Sem: sourceSem, /*Negative=*/true));
161 Value isInf = rewriter.createOrFold<arith::CmpFOp>(
162 loc, arith::CmpFPredicate::OEQ, source, inf);
163 Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
164 loc, arith::CmpFPredicate::OEQ, source, negInf);
165 Value isNan = rewriter.createOrFold<arith::CmpFOp>(
166 loc, arith::CmpFPredicate::UNO, source, source);
167 Value isNonFinite = rewriter.create<arith::OrIOp>(
168 loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
169
170 Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
171 Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
172 Value res =
173 rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
174 return res;
175}
176
177LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
178 // Only supporting default rounding mode as of now.
179 if (op.getRoundingmodeAttr())
180 return failure();
181 Type outType = op.getOut().getType();
182 if (auto outVecType = dyn_cast<VectorType>(outType)) {
183 if (outVecType.isScalable())
184 return failure();
185 if (outVecType.getShape().size() > 1)
186 // Multi-dimensional vectors are currently unsupported.
187 return failure();
188 outType = outVecType.getElementType();
189 }
190 auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
191 if (inType && inType.getWidth() <= 8 && saturateFP8)
192 // Conversion between 8-bit floats is not supported with truncation enabled.
193 return failure();
194 return success(isSuccess: outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
195}
196
197void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
198 PatternRewriter &rewriter) const {
199 Location loc = op.getLoc();
200 Value in = op.getIn();
201 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
202 if (saturateFP8)
203 in = clampInput(rewriter, loc, outElemType, source: in);
204 VectorType truncResType = VectorType::get(4, outElemType);
205 if (!isa<VectorType>(Val: in.getType())) {
206 Value asFloat = castToF32(value: in, loc, rewriter);
207 Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
208 loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
209 /*existing=*/nullptr);
210 Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
211 return rewriter.replaceOp(op, result);
212 }
213 VectorType outType = cast<VectorType>(op.getOut().getType());
214 int64_t numElements = outType.getNumElements();
215 Value zero = rewriter.create<arith::ConstantOp>(
216 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
217 Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
218 if (outType.getShape().empty()) {
219 Value scalarIn =
220 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
221 // Recurse to send the 0-D vector case to the 1-D vector case
222 Value scalarTrunc =
223 rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
224 result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
225 ArrayRef<int64_t>{});
226 return rewriter.replaceOp(op, result);
227 }
228
229 for (int64_t i = 0; i < numElements; i += 4) {
230 int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i;
231 Value thisResult = nullptr;
232 for (int64_t j = 0; j < elemsThisOp; j += 2) {
233 Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
234 Value asFloatA = castToF32(value: elemA, loc, rewriter);
235 Value asFloatB = nullptr;
236 if (j + 1 < elemsThisOp) {
237 Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
238 asFloatB = castToF32(value: elemB, loc, rewriter);
239 }
240 thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
241 loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
242 }
243 if (elemsThisOp < 4)
244 thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
245 loc, thisResult, 0, elemsThisOp, 1);
246 result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
247 result, i, 1);
248 }
249 rewriter.replaceOp(op, result);
250}
251
252void mlir::arith::populateArithToAMDGPUConversionPatterns(
253 RewritePatternSet &patterns, bool saturateFP8TruncF) {
254 patterns.add<ExtFOnFloat8RewritePattern>(arg: patterns.getContext());
255 patterns.add<TruncFToFloat8RewritePattern>(arg: patterns.getContext(),
256 args&: saturateFP8TruncF);
257}
258
259void ArithToAMDGPUConversionPass::runOnOperation() {
260 Operation *op = getOperation();
261 RewritePatternSet patterns(op->getContext());
262 arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
263 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
264 return signalPassFailure();
265}