//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);

struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};

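/// Rewrites `arith.extf` on supported 8-bit float types into
/// `amdgpu.ext_packed_fp8`, which extends up to two packed fp8 values per
/// instruction. A sketch of the scalar case (types are illustrative; the
/// exact assembly is defined by the AMDGPU dialect):
///
///   %w = arith.extf %v : f8E5M2FNUZ to f32
///
/// becomes, roughly,
///
///   %w = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32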
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  Chipset chipset;
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override;
};

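/// Rewrites `arith.truncf` to a supported 8-bit float type into
/// `amdgpu.packed_trunc_2xfp8`, which truncates two f32 values into one half
/// of a packed 4xfp8 result. When `saturateFP8` is set, finite inputs are
/// first clamped to the target type's range so that out-of-range values
/// saturate rather than overflow to NaN.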
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  Chipset chipset;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset)
      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
        chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

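/// Rewrites f32-to-f16 `arith.truncf` into `rocdl.cvt.pkrtz`, which converts
/// a pair of f32 values into packed f16 with round-toward-zero semantics.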
struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

} // end namespace

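// Returns true if `elementType` is an 8-bit float type with hardware
// conversion support on `chipset`: the FNUZ types on gfx942 and the OCP
// types on chipsets that implement OCP fp8.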
static bool isSupportedF8(Type elementType, Chipset chipset) {
  if (chipset == kGfx942)
    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  if (hasOcpFp8(chipset))
    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return false;
}

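// Casts an f32 value (or vector of f32) to the element type of `desType` by
// extending or truncating as needed. When `desType` is a vector type, `f32`
// must already have the matching shape.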
static Value castF32To(Type desType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  Type elementType = getElementTypeOrSelf(desType);
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return rewriter.create<arith::TruncFOp>(loc, desType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return rewriter.create<arith::ExtFOp>(loc, desType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}
97 | |
98 | LogicalResult |
99 | ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, |
100 | PatternRewriter &rewriter) const { |
101 | Type inType = op.getIn().getType(); |
102 | auto inVecType = dyn_cast<VectorType>(inType); |
103 | if (inVecType) { |
104 | if (inVecType.isScalable()) |
105 | return failure(); |
106 | inType = inVecType.getElementType(); |
107 | } |
108 | if (!isSupportedF8(elementType: inType, chipset)) |
109 | return failure(); |
110 | |
111 | Location loc = op.getLoc(); |
112 | Value in = op.getIn(); |
113 | Type outElemType = getElementTypeOrSelf(op.getOut().getType()); |
114 | VectorType extResType = VectorType::get(2, rewriter.getF32Type()); |
115 | if (!inVecType) { |
116 | Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( |
117 | loc, rewriter.getF32Type(), in, 0); |
118 | Value result = castF32To(desType: outElemType, f32: asFloat, loc, rewriter); |
119 | rewriter.replaceOp(op, result); |
120 | return success(); |
121 | } |
122 | int64_t numElements = inVecType.getNumElements(); |
123 | |
124 | Value zero = rewriter.create<arith::ConstantOp>( |
125 | loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); |
126 | VectorType outType = cast<VectorType>(op.getOut().getType()); |
127 | |
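  // 0-D vectors: extract the scalar, convert it, and insert it back. The
  // scalar arith.extf created here is picked up by this pattern again on a
  // later iteration of the greedy rewrite driver.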
  if (inVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    Value scalarExt =
        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
                                                     ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

  if (inVecType.getRank() > 1) {
    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                inVecType.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
  }

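  // Walk the flattened vector in chunks of four 8-bit elements (one packed
  // 32-bit source word). Each ext_packed_fp8 extends two elements, so a full
  // chunk takes two ops; a trailing odd element is extended on its own.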
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, in, i, elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      if (i + j + 1 < numElements) { // Convert two 8-bit elements
        Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
            loc, extResType, inSlice, j / 2);
        Type desType = VectorType::get(2, outElemType);
        Value asType = castF32To(desType, asFloats, loc, rewriter);
        result = rewriter.create<vector::InsertStridedSliceOp>(
            loc, asType, result, i + j, 1);
      } else { // Convert an 8-bit element
        Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
            loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
        result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
      }
    }
  }

  if (inVecType.getRank() != outType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

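// Casts `value` to f32, extending smaller float types and truncating larger
// ones.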
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
  llvm_unreachable("The only 32-bit float type is f32");
}

// If `source` is a finite value, clamp it between the maximum and minimum
// values of `outElemType` so that subsequent conversion instructions don't
// overflow those out-of-range values to NaN. These semantics are commonly
// used in machine-learning contexts, where failing to clamp would lead to
// excessive NaN production.
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  // We can ignore conversion failures here because this conversion promotes
  // from a smaller type to a larger one -- e.g., there can be no loss of
  // precision when casting fp8 to f16.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = rewriter.create<arith::OrIOp>(
      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);

  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
  Value res =
      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
  return res;
}

LogicalResult
TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
                                              PatternRewriter &rewriter) const {
  // Only supporting default rounding mode as of now.
  if (op.getRoundingmodeAttr())
    return failure();
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType =
      dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    // Conversion between 8-bit floats is not supported with saturation
    // enabled.
    return failure();

  if (!isSupportedF8(outType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!inVectorTy) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }

  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outVecType.getShape().empty()) {
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case to the scalar case: the truncf
    // created here is matched again by this pattern.
    Value scalarTrunc =
        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
                                                     ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outVecType.getElementType());
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
  }

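  // Walk the flattened vector in chunks of four elements. Each
  // packed_trunc_2xfp8 truncates two f32 values into one half of a packed
  // 4xfp8 word, threading the partial result through `existing`.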
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
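    // A partial final chunk still yields a full 4-element packed vector;
    // trim it to the number of valid elements before inserting it.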
    if (elemsThisOp < 4)
      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
    arith::TruncFOp op, PatternRewriter &rewriter) const {
  Type outType = op.getOut().getType();
  Type inputType = getElementTypeOrSelf(op.getIn());
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  if (!(outType.isF16() && inputType.isF32()))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Handle the case where the input type is not a vector type.
  if (!inVectorTy) {
    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
    Value asF16s =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
  }

  // Handle the vector case. We also handle the (uncommon) case where the
  // vector length is odd.
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
    }

    thisResult =
        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
    // Place the truncated result back into the possibly larger vector. If we
    // are operating on a size-2 vector, these operations should fold away.
    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

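/// Populates `patterns` with the rewrites above: `convertFP8Arithmetic`
/// gates the fp8 extension/truncation patterns and `allowPackedF16Rtz` gates
/// the round-toward-zero f32-to-f16 pattern. This is normally driven through
/// the conversion pass below, e.g. something like
/// `mlir-opt --convert-arith-to-amdgpu="chipset=gfx942"` (see Passes.td for
/// the authoritative pass and option names).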
void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
                                               saturateFP8Truncf, chipset);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
}

void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  bool convertFP8Arithmetic =
      *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      *maybeChipset);
  if (failed(applyPatternsGreedily(op, std::move(patterns))))
    return signalPassFailure();
}
429 | |