//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23
24namespace mlir {
25#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26#include "mlir/Conversion/Passes.h.inc"
27} // namespace mlir
28
29using namespace mlir;
30using namespace mlir::amdgpu;
31
32namespace {
33// Define commonly used chipsets versions for convenience.
34constexpr Chipset kGfx942 = Chipset(9, 4, 2);
35
36struct ArithToAMDGPUConversionPass final
37 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
38 using impl::ArithToAMDGPUConversionPassBase<
39 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
40
41 void runOnOperation() override;
42};
43
44struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
45 using OpRewritePattern::OpRewritePattern;
46
47 Chipset chipset;
48 ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
49 : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
50
51 LogicalResult matchAndRewrite(arith::ExtFOp op,
52 PatternRewriter &rewriter) const override;
53};
54
55struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
56 bool saturateFP8 = false;
57 TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
58 Chipset chipset)
59 : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
60 chipset(chipset) {}
61 Chipset chipset;
62
63 LogicalResult matchAndRewrite(arith::TruncFOp op,
64 PatternRewriter &rewriter) const override;
65};
66
67struct TruncfToFloat16RewritePattern final
68 : public OpRewritePattern<arith::TruncFOp> {
69
70 using OpRewritePattern::OpRewritePattern;
71
72 LogicalResult matchAndRewrite(arith::TruncFOp op,
73 PatternRewriter &rewriter) const override;
74};
75
76} // end namespace
77
78static bool isSupportedF8(Type elementType, Chipset chipset) {
79 if (chipset == kGfx942)
80 return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
81 if (hasOcpFp8(chipset))
82 return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
83 return false;
84}
85
86static Value castF32To(Type desType, Value f32, Location loc,
87 PatternRewriter &rewriter) {
88 Type elementType = getElementTypeOrSelf(type: desType);
89 if (elementType.isF32())
90 return f32;
91 if (elementType.getIntOrFloatBitWidth() < 32)
92 return rewriter.create<arith::TruncFOp>(loc, desType, f32);
93 if (elementType.getIntOrFloatBitWidth() > 32)
94 return rewriter.create<arith::ExtFOp>(loc, desType, f32);
95 llvm_unreachable("The only 32-bit float type is f32");
96}
97
98LogicalResult
99ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
100 PatternRewriter &rewriter) const {
101 Type inType = op.getIn().getType();
102 auto inVecType = dyn_cast<VectorType>(inType);
103 if (inVecType) {
104 if (inVecType.isScalable())
105 return failure();
106 inType = inVecType.getElementType();
107 }
108 if (!isSupportedF8(elementType: inType, chipset))
109 return failure();
110
111 Location loc = op.getLoc();
112 Value in = op.getIn();
113 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
114 VectorType extResType = VectorType::get(2, rewriter.getF32Type());
115 if (!inVecType) {
116 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
117 loc, rewriter.getF32Type(), in, 0);
118 Value result = castF32To(desType: outElemType, f32: asFloat, loc, rewriter);
119 rewriter.replaceOp(op, result);
120 return success();
121 }
122 int64_t numElements = inVecType.getNumElements();
123
124 Value zero = rewriter.create<arith::ConstantOp>(
125 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
126 VectorType outType = cast<VectorType>(op.getOut().getType());
127
128 if (inVecType.getShape().empty()) {
129 Value zerodSplat =
130 rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
131 Value scalarIn =
132 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
133 Value scalarExt =
134 rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
135 Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
136 ArrayRef<int64_t>{});
137 rewriter.replaceOp(op, result);
138 return success();
139 }
140
141 VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
142 outType.getElementType());
143 Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
144
145 if (inVecType.getRank() > 1) {
146 inVecType = VectorType::get(SmallVector<int64_t>{numElements},
147 inVecType.getElementType());
148 in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
149 }
150
151 for (int64_t i = 0; i < numElements; i += 4) {
152 int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i;
153 Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
154 loc, in, i, elemsThisOp, 1);
155 for (int64_t j = 0; j < elemsThisOp; j += 2) {
156 if (i + j + 1 < numElements) { // Convert two 8-bit elements
157 Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
158 loc, extResType, inSlice, j / 2);
159 Type desType = VectorType::get(2, outElemType);
160 Value asType = castF32To(desType, f32: asFloats, loc, rewriter);
161 result = rewriter.create<vector::InsertStridedSliceOp>(
162 loc, asType, result, i + j, 1);
163 } else { // Convert a 8-bit element
164 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
165 loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
166 Value asType = castF32To(desType: outElemType, f32: asFloat, loc, rewriter);
167 result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
168 }
169 }
170 }
171
172 if (inVecType.getRank() != outType.getRank()) {
173 result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
174 }
175
176 rewriter.replaceOp(op, result);
177 return success();
178}
179
180static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
181 Type type = value.getType();
182 if (type.isF32())
183 return value;
184 if (type.getIntOrFloatBitWidth() < 32)
185 return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
186 if (type.getIntOrFloatBitWidth() > 32)
187 return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
188 llvm_unreachable("The only 32-bit float type is f32");
189}
190
191// If `in` is a finite value, clamp it between the maximum and minimum values
192// of `outElemType` so that subsequent conversion instructions don't
193// overflow those out-of-range values to NaN. These semantics are commonly
194// used in machine-learning contexts where failure to clamp would lead to
195// excessive NaN production.
196static Value clampInput(PatternRewriter &rewriter, Location loc,
197 Type outElemType, Value source) {
198 Type sourceType = source.getType();
199 const llvm::fltSemantics &sourceSem =
200 cast<FloatType>(getElementTypeOrSelf(type: sourceType)).getFloatSemantics();
201 const llvm::fltSemantics &targetSem =
202 cast<FloatType>(outElemType).getFloatSemantics();
203
204 APFloat min = APFloat::getLargest(Sem: targetSem, /*Negative=*/true);
205 APFloat max = APFloat::getLargest(Sem: targetSem, /*Negative=*/false);
206 bool ignoredLosesInfo = false;
207 // We can ignore conversion failures here because this conversion promotes
208 // from a smaller type to a larger one - ex. there can be no loss of precision
209 // when casting fp8 to f16.
210 (void)min.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo);
211 (void)max.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo);
212
213 Value minCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: min);
214 Value maxCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: max);
215
216 Value inf = createScalarOrSplatConstant(
217 builder&: rewriter, loc, type: sourceType,
218 value: APFloat::getInf(Sem: sourceSem, /*Negative=*/false));
219 Value negInf = createScalarOrSplatConstant(
220 builder&: rewriter, loc, type: sourceType, value: APFloat::getInf(Sem: sourceSem, /*Negative=*/true));
221 Value isInf = rewriter.createOrFold<arith::CmpFOp>(
222 loc, arith::CmpFPredicate::OEQ, source, inf);
223 Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
224 loc, arith::CmpFPredicate::OEQ, source, negInf);
225 Value isNan = rewriter.createOrFold<arith::CmpFOp>(
226 loc, arith::CmpFPredicate::UNO, source, source);
227 Value isNonFinite = rewriter.create<arith::OrIOp>(
228 loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
229
230 Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
231 Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
232 Value res =
233 rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
234 return res;
235}
236
237LogicalResult
238TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
239 PatternRewriter &rewriter) const {
240 // Only supporting default rounding mode as of now.
241 if (op.getRoundingmodeAttr())
242 return failure();
243 Type outType = op.getOut().getType();
244 auto outVecType = dyn_cast<VectorType>(outType);
245 if (outVecType) {
246 if (outVecType.isScalable())
247 return failure();
248 outType = outVecType.getElementType();
249 }
250 auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
251 if (inType && inType.getWidth() <= 8 && saturateFP8)
252 // Conversion between 8-bit floats is not supported with truncation enabled.
253 return failure();
254
255 if (!isSupportedF8(elementType: outType, chipset))
256 return failure();
257
258 Location loc = op.getLoc();
259 Value in = op.getIn();
260 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
261 if (saturateFP8)
262 in = clampInput(rewriter, loc, outElemType, source: in);
263 auto inVectorTy = dyn_cast<VectorType>(in.getType());
264 VectorType truncResType = VectorType::get(4, outElemType);
265 if (!inVectorTy) {
266 Value asFloat = castToF32(value: in, loc, rewriter);
267 Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
268 loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
269 /*existing=*/nullptr);
270 Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
271 rewriter.replaceOp(op, result);
272 return success();
273 }
274
275 int64_t numElements = outVecType.getNumElements();
276 Value zero = rewriter.create<arith::ConstantOp>(
277 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
278 if (outVecType.getShape().empty()) {
279 Value scalarIn =
280 rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
281 // Recurse to send the 0-D vector case to the 1-D vector case
282 Value scalarTrunc =
283 rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
284 Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
285 ArrayRef<int64_t>{});
286 rewriter.replaceOp(op, result);
287 return success();
288 }
289
290 VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
291 outVecType.getElementType());
292 Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
293
294 if (inVectorTy.getRank() > 1) {
295 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
296 inVectorTy.getElementType());
297 in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
298 }
299
300 for (int64_t i = 0; i < numElements; i += 4) {
301 int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i;
302 Value thisResult = nullptr;
303 for (int64_t j = 0; j < elemsThisOp; j += 2) {
304 Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
305 Value asFloatA = castToF32(value: elemA, loc, rewriter);
306 Value asFloatB = nullptr;
307 if (j + 1 < elemsThisOp) {
308 Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
309 asFloatB = castToF32(value: elemB, loc, rewriter);
310 }
311 thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
312 loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
313 }
314 if (elemsThisOp < 4)
315 thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
316 loc, thisResult, 0, elemsThisOp, 1);
317 result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
318 result, i, 1);
319 }
320
321 if (inVectorTy.getRank() != outVecType.getRank()) {
322 result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
323 }
324
325 rewriter.replaceOp(op, result);
326 return success();
327}
328
329LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
330 arith::TruncFOp op, PatternRewriter &rewriter) const {
331 Type outType = op.getOut().getType();
332 Type inputType = getElementTypeOrSelf(op.getIn());
333 auto outVecType = dyn_cast<VectorType>(outType);
334 if (outVecType) {
335 if (outVecType.isScalable())
336 return failure();
337 outType = outVecType.getElementType();
338 }
339 if (!(outType.isF16() && inputType.isF32()))
340 return failure();
341
342 Location loc = op.getLoc();
343 Value in = op.getIn();
344 Type outElemType = getElementTypeOrSelf(op.getOut().getType());
345 VectorType truncResType = VectorType::get(2, outElemType);
346 auto inVectorTy = dyn_cast<VectorType>(in.getType());
347
348 // Handle the case where input type is not a vector type
349 if (!inVectorTy) {
350 auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
351 Value asF16s =
352 rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
353 Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
354 rewriter.replaceOp(op, result);
355 return success();
356 }
357 int64_t numElements = outVecType.getNumElements();
358 Value zero = rewriter.createOrFold<arith::ConstantOp>(
359 loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
360 Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
361
362 if (inVectorTy.getRank() > 1) {
363 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
364 inVectorTy.getElementType());
365 in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
366 }
367
368 // Handle the vector case. We also handle the (uncommon) case where the vector
369 // length is odd
370 for (int64_t i = 0; i < numElements; i += 2) {
371 int64_t elemsThisOp = std::min(a: numElements, b: i + 2) - i;
372 Value thisResult = nullptr;
373 Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
374 Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
375
376 if (elemsThisOp == 2) {
377 elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
378 }
379
380 thisResult =
381 rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
382 // Place back the truncated result into the possibly larger vector. If we
383 // are operating on a size 2 vector, these operations should be folded away
384 thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
385 loc, thisResult, 0, elemsThisOp, 1);
386 result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
387 result, i, 1);
388 }
389
390 if (inVectorTy.getRank() != outVecType.getRank()) {
391 result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
392 }
393
394 rewriter.replaceOp(op, result);
395 return success();
396}
397
398void mlir::arith::populateArithToAMDGPUConversionPatterns(
399 RewritePatternSet &patterns, bool convertFP8Arithmetic,
400 bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
401
402 if (convertFP8Arithmetic) {
403 patterns.add<ExtFOnFloat8RewritePattern>(arg: patterns.getContext(), args&: chipset);
404 patterns.add<TruncFToFloat8RewritePattern>(arg: patterns.getContext(),
405 args&: saturateFP8Truncf, args&: chipset);
406 }
407 if (allowPackedF16Rtz)
408 patterns.add<TruncfToFloat16RewritePattern>(arg: patterns.getContext());
409}
410
411void ArithToAMDGPUConversionPass::runOnOperation() {
412 Operation *op = getOperation();
413 MLIRContext *ctx = &getContext();
414 RewritePatternSet patterns(op->getContext());
415 FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
416 if (failed(Result: maybeChipset)) {
417 emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
418 return signalPassFailure();
419 }
420
421 bool convertFP8Arithmetic =
422 *maybeChipset == kGfx942 || hasOcpFp8(chipset: *maybeChipset);
423 arith::populateArithToAMDGPUConversionPatterns(
424 patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
425 *maybeChipset);
426 if (failed(applyPatternsGreedily(op, std::move(patterns))))
427 return signalPassFailure();
428}
429

// (Removed code-browser boilerplate that was appended to this file.)