//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
28 | |
29 | namespace { |
30 | struct ArithToAMDGPUConversionPass final |
31 | : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> { |
32 | using impl::ArithToAMDGPUConversionPassBase< |
33 | ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase; |
34 | |
35 | void runOnOperation() override; |
36 | }; |
37 | |
38 | struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> { |
39 | using OpRewritePattern::OpRewritePattern; |
40 | |
41 | LogicalResult match(arith::ExtFOp op) const override; |
42 | void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; |
43 | }; |
44 | |
45 | struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> { |
46 | bool saturateFP8 = false; |
47 | TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8) |
48 | : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {} |
49 | |
50 | LogicalResult match(arith::TruncFOp op) const override; |
51 | void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; |
52 | }; |
53 | } // end namespace |
54 | |
55 | static Value castF32To(Type elementType, Value f32, Location loc, |
56 | PatternRewriter &rewriter) { |
57 | if (elementType.isF32()) |
58 | return f32; |
59 | if (elementType.getIntOrFloatBitWidth() < 32) |
60 | return rewriter.create<arith::TruncFOp>(loc, elementType, f32); |
61 | if (elementType.getIntOrFloatBitWidth() > 32) |
62 | return rewriter.create<arith::ExtFOp>(loc, elementType, f32); |
63 | llvm_unreachable("The only 32-bit float type is f32" ); |
64 | } |
65 | |
66 | LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { |
67 | Type inType = op.getIn().getType(); |
68 | if (auto inVecType = dyn_cast<VectorType>(inType)) { |
69 | if (inVecType.isScalable()) |
70 | return failure(); |
71 | if (inVecType.getShape().size() > 1) |
72 | // Multi-dimensional vectors are currently unsupported. |
73 | return failure(); |
74 | inType = inVecType.getElementType(); |
75 | } |
76 | return success(isSuccess: inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ()); |
77 | } |
78 | |
79 | void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, |
80 | PatternRewriter &rewriter) const { |
81 | Location loc = op.getLoc(); |
82 | Value in = op.getIn(); |
83 | Type outElemType = getElementTypeOrSelf(op.getOut().getType()); |
84 | if (!isa<VectorType>(Val: in.getType())) { |
85 | Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( |
86 | loc, rewriter.getF32Type(), in, 0); |
87 | Value result = castF32To(elementType: outElemType, f32: asFloat, loc, rewriter); |
88 | return rewriter.replaceOp(op, result); |
89 | } |
90 | VectorType inType = cast<VectorType>(in.getType()); |
91 | int64_t numElements = inType.getNumElements(); |
92 | Value zero = rewriter.create<arith::ConstantOp>( |
93 | loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); |
94 | Value result = |
95 | rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero); |
96 | if (inType.getShape().empty()) { |
97 | Value scalarIn = |
98 | rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); |
99 | // Recurse to send the 0-D vector case to the 1-D vector case |
100 | Value scalarExt = |
101 | rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn); |
102 | result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero, |
103 | ArrayRef<int64_t>{}); |
104 | return rewriter.replaceOp(op, result); |
105 | } |
106 | for (int64_t i = 0; i < numElements; i += 4) { |
107 | int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i; |
108 | Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>( |
109 | loc, in, i, elemsThisOp, 1); |
110 | for (int64_t j = 0; j < elemsThisOp; ++j) { |
111 | Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( |
112 | loc, rewriter.getF32Type(), inSlice, j); |
113 | Value asType = castF32To(elementType: outElemType, f32: asFloat, loc, rewriter); |
114 | result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j); |
115 | } |
116 | } |
117 | rewriter.replaceOp(op, result); |
118 | } |
119 | |
120 | static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { |
121 | Type type = value.getType(); |
122 | if (type.isF32()) |
123 | return value; |
124 | if (type.getIntOrFloatBitWidth() < 32) |
125 | return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value); |
126 | if (type.getIntOrFloatBitWidth() > 32) |
127 | return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value); |
128 | llvm_unreachable("The only 32-bit float type is f32" ); |
129 | } |
130 | |
131 | // If `in` is a finite value, clamp it between the maximum and minimum values |
132 | // of `outElemType` so that subsequent conversion instructions don't |
133 | // overflow those out-of-range values to NaN. These semantics are commonly |
134 | // used in machine-learning contexts where failure to clamp would lead to |
135 | // excessive NaN production. |
136 | static Value clampInput(PatternRewriter &rewriter, Location loc, |
137 | Type outElemType, Value source) { |
138 | Type sourceType = source.getType(); |
139 | const llvm::fltSemantics &sourceSem = |
140 | cast<FloatType>(Val: getElementTypeOrSelf(type: sourceType)).getFloatSemantics(); |
141 | const llvm::fltSemantics &targetSem = |
142 | cast<FloatType>(Val&: outElemType).getFloatSemantics(); |
143 | |
144 | APFloat min = APFloat::getLargest(Sem: targetSem, /*Negative=*/true); |
145 | APFloat max = APFloat::getLargest(Sem: targetSem, /*Negative=*/false); |
146 | bool ignoredLosesInfo = false; |
147 | // We can ignore conversion failures here because this conversion promotes |
148 | // from a smaller type to a larger one - ex. there can be no loss of precision |
149 | // when casting fp8 to f16. |
150 | (void)min.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo); |
151 | (void)max.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo); |
152 | |
153 | Value minCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: min); |
154 | Value maxCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: max); |
155 | |
156 | Value inf = createScalarOrSplatConstant( |
157 | builder&: rewriter, loc, type: sourceType, |
158 | value: APFloat::getInf(Sem: sourceSem, /*Negative=*/false)); |
159 | Value negInf = createScalarOrSplatConstant( |
160 | builder&: rewriter, loc, type: sourceType, value: APFloat::getInf(Sem: sourceSem, /*Negative=*/true)); |
161 | Value isInf = rewriter.createOrFold<arith::CmpFOp>( |
162 | loc, arith::CmpFPredicate::OEQ, source, inf); |
163 | Value isNegInf = rewriter.createOrFold<arith::CmpFOp>( |
164 | loc, arith::CmpFPredicate::OEQ, source, negInf); |
165 | Value isNan = rewriter.createOrFold<arith::CmpFOp>( |
166 | loc, arith::CmpFPredicate::UNO, source, source); |
167 | Value isNonFinite = rewriter.create<arith::OrIOp>( |
168 | loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan); |
169 | |
170 | Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst); |
171 | Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst); |
172 | Value res = |
173 | rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped); |
174 | return res; |
175 | } |
176 | |
177 | LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { |
178 | // Only supporting default rounding mode as of now. |
179 | if (op.getRoundingmodeAttr()) |
180 | return failure(); |
181 | Type outType = op.getOut().getType(); |
182 | if (auto outVecType = dyn_cast<VectorType>(outType)) { |
183 | if (outVecType.isScalable()) |
184 | return failure(); |
185 | if (outVecType.getShape().size() > 1) |
186 | // Multi-dimensional vectors are currently unsupported. |
187 | return failure(); |
188 | outType = outVecType.getElementType(); |
189 | } |
190 | auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType())); |
191 | if (inType && inType.getWidth() <= 8 && saturateFP8) |
192 | // Conversion between 8-bit floats is not supported with truncation enabled. |
193 | return failure(); |
194 | return success(isSuccess: outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()); |
195 | } |
196 | |
197 | void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, |
198 | PatternRewriter &rewriter) const { |
199 | Location loc = op.getLoc(); |
200 | Value in = op.getIn(); |
201 | Type outElemType = getElementTypeOrSelf(op.getOut().getType()); |
202 | if (saturateFP8) |
203 | in = clampInput(rewriter, loc, outElemType, source: in); |
204 | VectorType truncResType = VectorType::get(4, outElemType); |
205 | if (!isa<VectorType>(Val: in.getType())) { |
206 | Value asFloat = castToF32(value: in, loc, rewriter); |
207 | Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( |
208 | loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, |
209 | /*existing=*/nullptr); |
210 | Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0); |
211 | return rewriter.replaceOp(op, result); |
212 | } |
213 | VectorType outType = cast<VectorType>(op.getOut().getType()); |
214 | int64_t numElements = outType.getNumElements(); |
215 | Value zero = rewriter.create<arith::ConstantOp>( |
216 | loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); |
217 | Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero); |
218 | if (outType.getShape().empty()) { |
219 | Value scalarIn = |
220 | rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); |
221 | // Recurse to send the 0-D vector case to the 1-D vector case |
222 | Value scalarTrunc = |
223 | rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn); |
224 | result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero, |
225 | ArrayRef<int64_t>{}); |
226 | return rewriter.replaceOp(op, result); |
227 | } |
228 | |
229 | for (int64_t i = 0; i < numElements; i += 4) { |
230 | int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i; |
231 | Value thisResult = nullptr; |
232 | for (int64_t j = 0; j < elemsThisOp; j += 2) { |
233 | Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j); |
234 | Value asFloatA = castToF32(value: elemA, loc, rewriter); |
235 | Value asFloatB = nullptr; |
236 | if (j + 1 < elemsThisOp) { |
237 | Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1); |
238 | asFloatB = castToF32(value: elemB, loc, rewriter); |
239 | } |
240 | thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( |
241 | loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); |
242 | } |
243 | if (elemsThisOp < 4) |
244 | thisResult = rewriter.create<vector::ExtractStridedSliceOp>( |
245 | loc, thisResult, 0, elemsThisOp, 1); |
246 | result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, |
247 | result, i, 1); |
248 | } |
249 | rewriter.replaceOp(op, result); |
250 | } |
251 | |
252 | void mlir::arith::populateArithToAMDGPUConversionPatterns( |
253 | RewritePatternSet &patterns, bool saturateFP8TruncF) { |
254 | patterns.add<ExtFOnFloat8RewritePattern>(arg: patterns.getContext()); |
255 | patterns.add<TruncFToFloat8RewritePattern>(arg: patterns.getContext(), |
256 | args&: saturateFP8TruncF); |
257 | } |
258 | |
259 | void ArithToAMDGPUConversionPass::runOnOperation() { |
260 | Operation *op = getOperation(); |
261 | RewritePatternSet patterns(op->getContext()); |
262 | arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf); |
263 | if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) |
264 | return signalPassFailure(); |
265 | } |
266 | |