//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;
| 31 | |
| 32 | namespace { |
| 33 | // Define commonly used chipsets versions for convenience. |
| 34 | constexpr Chipset kGfx942 = Chipset(9, 4, 2); |
| 35 | |
| 36 | struct ArithToAMDGPUConversionPass final |
| 37 | : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> { |
| 38 | using impl::ArithToAMDGPUConversionPassBase< |
| 39 | ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase; |
| 40 | |
| 41 | void runOnOperation() override; |
| 42 | }; |
| 43 | |
| 44 | struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> { |
| 45 | using OpRewritePattern::OpRewritePattern; |
| 46 | |
| 47 | Chipset chipset; |
| 48 | ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset) |
| 49 | : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {} |
| 50 | |
| 51 | LogicalResult matchAndRewrite(arith::ExtFOp op, |
| 52 | PatternRewriter &rewriter) const override; |
| 53 | }; |
| 54 | |
| 55 | struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> { |
| 56 | bool saturateFP8 = false; |
| 57 | TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8, |
| 58 | Chipset chipset) |
| 59 | : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8), |
| 60 | chipset(chipset) {} |
| 61 | Chipset chipset; |
| 62 | |
| 63 | LogicalResult matchAndRewrite(arith::TruncFOp op, |
| 64 | PatternRewriter &rewriter) const override; |
| 65 | }; |
| 66 | |
| 67 | struct TruncfToFloat16RewritePattern final |
| 68 | : public OpRewritePattern<arith::TruncFOp> { |
| 69 | |
| 70 | using OpRewritePattern::OpRewritePattern; |
| 71 | |
| 72 | LogicalResult matchAndRewrite(arith::TruncFOp op, |
| 73 | PatternRewriter &rewriter) const override; |
| 74 | }; |
| 75 | |
| 76 | } // end namespace |
| 77 | |
| 78 | static bool isSupportedF8(Type elementType, Chipset chipset) { |
| 79 | if (chipset == kGfx942) |
| 80 | return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType); |
| 81 | if (hasOcpFp8(chipset)) |
| 82 | return isa<Float8E4M3FNType, Float8E5M2Type>(elementType); |
| 83 | return false; |
| 84 | } |
| 85 | |
| 86 | static Value castF32To(Type desType, Value f32, Location loc, |
| 87 | PatternRewriter &rewriter) { |
| 88 | Type elementType = getElementTypeOrSelf(type: desType); |
| 89 | if (elementType.isF32()) |
| 90 | return f32; |
| 91 | if (elementType.getIntOrFloatBitWidth() < 32) |
| 92 | return rewriter.create<arith::TruncFOp>(loc, desType, f32); |
| 93 | if (elementType.getIntOrFloatBitWidth() > 32) |
| 94 | return rewriter.create<arith::ExtFOp>(loc, desType, f32); |
| 95 | llvm_unreachable("The only 32-bit float type is f32" ); |
| 96 | } |
| 97 | |
| 98 | LogicalResult |
| 99 | ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, |
| 100 | PatternRewriter &rewriter) const { |
| 101 | Type inType = op.getIn().getType(); |
| 102 | auto inVecType = dyn_cast<VectorType>(inType); |
| 103 | if (inVecType) { |
| 104 | if (inVecType.isScalable()) |
| 105 | return failure(); |
| 106 | inType = inVecType.getElementType(); |
| 107 | } |
| 108 | if (!isSupportedF8(elementType: inType, chipset)) |
| 109 | return failure(); |
| 110 | |
| 111 | Location loc = op.getLoc(); |
| 112 | Value in = op.getIn(); |
| 113 | Type outElemType = getElementTypeOrSelf(op.getOut().getType()); |
| 114 | VectorType extResType = VectorType::get(2, rewriter.getF32Type()); |
| 115 | if (!inVecType) { |
| 116 | Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( |
| 117 | loc, rewriter.getF32Type(), in, 0); |
| 118 | Value result = castF32To(desType: outElemType, f32: asFloat, loc, rewriter); |
| 119 | rewriter.replaceOp(op, result); |
| 120 | return success(); |
| 121 | } |
| 122 | int64_t numElements = inVecType.getNumElements(); |
| 123 | |
| 124 | Value zero = rewriter.create<arith::ConstantOp>( |
| 125 | loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); |
| 126 | VectorType outType = cast<VectorType>(op.getOut().getType()); |
| 127 | |
| 128 | if (inVecType.getShape().empty()) { |
| 129 | Value zerodSplat = |
| 130 | rewriter.createOrFold<vector::SplatOp>(loc, outType, zero); |
| 131 | Value scalarIn = |
| 132 | rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); |
| 133 | Value scalarExt = |
| 134 | rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn); |
| 135 | Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat, |
| 136 | ArrayRef<int64_t>{}); |
| 137 | rewriter.replaceOp(op, result); |
| 138 | return success(); |
| 139 | } |
| 140 | |
| 141 | VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, |
| 142 | outType.getElementType()); |
| 143 | Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); |
| 144 | |
| 145 | if (inVecType.getRank() > 1) { |
| 146 | inVecType = VectorType::get(SmallVector<int64_t>{numElements}, |
| 147 | inVecType.getElementType()); |
| 148 | in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in); |
| 149 | } |
| 150 | |
| 151 | for (int64_t i = 0; i < numElements; i += 4) { |
| 152 | int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i; |
| 153 | Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>( |
| 154 | loc, in, i, elemsThisOp, 1); |
| 155 | for (int64_t j = 0; j < elemsThisOp; j += 2) { |
| 156 | if (i + j + 1 < numElements) { // Convert two 8-bit elements |
| 157 | Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>( |
| 158 | loc, extResType, inSlice, j / 2); |
| 159 | Type desType = VectorType::get(2, outElemType); |
| 160 | Value asType = castF32To(desType, f32: asFloats, loc, rewriter); |
| 161 | result = rewriter.create<vector::InsertStridedSliceOp>( |
| 162 | loc, asType, result, i + j, 1); |
| 163 | } else { // Convert a 8-bit element |
| 164 | Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( |
| 165 | loc, rewriter.getF32Type(), inSlice, j / 2 * 2); |
| 166 | Value asType = castF32To(desType: outElemType, f32: asFloat, loc, rewriter); |
| 167 | result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j); |
| 168 | } |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | if (inVecType.getRank() != outType.getRank()) { |
| 173 | result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); |
| 174 | } |
| 175 | |
| 176 | rewriter.replaceOp(op, result); |
| 177 | return success(); |
| 178 | } |
| 179 | |
| 180 | static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { |
| 181 | Type type = value.getType(); |
| 182 | if (type.isF32()) |
| 183 | return value; |
| 184 | if (type.getIntOrFloatBitWidth() < 32) |
| 185 | return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value); |
| 186 | if (type.getIntOrFloatBitWidth() > 32) |
| 187 | return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value); |
| 188 | llvm_unreachable("The only 32-bit float type is f32" ); |
| 189 | } |
| 190 | |
| 191 | // If `in` is a finite value, clamp it between the maximum and minimum values |
| 192 | // of `outElemType` so that subsequent conversion instructions don't |
| 193 | // overflow those out-of-range values to NaN. These semantics are commonly |
| 194 | // used in machine-learning contexts where failure to clamp would lead to |
| 195 | // excessive NaN production. |
| 196 | static Value clampInput(PatternRewriter &rewriter, Location loc, |
| 197 | Type outElemType, Value source) { |
| 198 | Type sourceType = source.getType(); |
| 199 | const llvm::fltSemantics &sourceSem = |
| 200 | cast<FloatType>(getElementTypeOrSelf(type: sourceType)).getFloatSemantics(); |
| 201 | const llvm::fltSemantics &targetSem = |
| 202 | cast<FloatType>(outElemType).getFloatSemantics(); |
| 203 | |
| 204 | APFloat min = APFloat::getLargest(Sem: targetSem, /*Negative=*/true); |
| 205 | APFloat max = APFloat::getLargest(Sem: targetSem, /*Negative=*/false); |
| 206 | bool ignoredLosesInfo = false; |
| 207 | // We can ignore conversion failures here because this conversion promotes |
| 208 | // from a smaller type to a larger one - ex. there can be no loss of precision |
| 209 | // when casting fp8 to f16. |
| 210 | (void)min.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo); |
| 211 | (void)max.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo); |
| 212 | |
| 213 | Value minCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: min); |
| 214 | Value maxCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: max); |
| 215 | |
| 216 | Value inf = createScalarOrSplatConstant( |
| 217 | builder&: rewriter, loc, type: sourceType, |
| 218 | value: APFloat::getInf(Sem: sourceSem, /*Negative=*/false)); |
| 219 | Value negInf = createScalarOrSplatConstant( |
| 220 | builder&: rewriter, loc, type: sourceType, value: APFloat::getInf(Sem: sourceSem, /*Negative=*/true)); |
| 221 | Value isInf = rewriter.createOrFold<arith::CmpFOp>( |
| 222 | loc, arith::CmpFPredicate::OEQ, source, inf); |
| 223 | Value isNegInf = rewriter.createOrFold<arith::CmpFOp>( |
| 224 | loc, arith::CmpFPredicate::OEQ, source, negInf); |
| 225 | Value isNan = rewriter.createOrFold<arith::CmpFOp>( |
| 226 | loc, arith::CmpFPredicate::UNO, source, source); |
| 227 | Value isNonFinite = rewriter.create<arith::OrIOp>( |
| 228 | loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan); |
| 229 | |
| 230 | Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst); |
| 231 | Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst); |
| 232 | Value res = |
| 233 | rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped); |
| 234 | return res; |
| 235 | } |
| 236 | |
| 237 | LogicalResult |
| 238 | TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, |
| 239 | PatternRewriter &rewriter) const { |
| 240 | // Only supporting default rounding mode as of now. |
| 241 | if (op.getRoundingmodeAttr()) |
| 242 | return failure(); |
| 243 | Type outType = op.getOut().getType(); |
| 244 | auto outVecType = dyn_cast<VectorType>(outType); |
| 245 | if (outVecType) { |
| 246 | if (outVecType.isScalable()) |
| 247 | return failure(); |
| 248 | outType = outVecType.getElementType(); |
| 249 | } |
| 250 | auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType())); |
| 251 | if (inType && inType.getWidth() <= 8 && saturateFP8) |
| 252 | // Conversion between 8-bit floats is not supported with truncation enabled. |
| 253 | return failure(); |
| 254 | |
| 255 | if (!isSupportedF8(elementType: outType, chipset)) |
| 256 | return failure(); |
| 257 | |
| 258 | Location loc = op.getLoc(); |
| 259 | Value in = op.getIn(); |
| 260 | Type outElemType = getElementTypeOrSelf(op.getOut().getType()); |
| 261 | if (saturateFP8) |
| 262 | in = clampInput(rewriter, loc, outElemType, source: in); |
| 263 | auto inVectorTy = dyn_cast<VectorType>(in.getType()); |
| 264 | VectorType truncResType = VectorType::get(4, outElemType); |
| 265 | if (!inVectorTy) { |
| 266 | Value asFloat = castToF32(value: in, loc, rewriter); |
| 267 | Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( |
| 268 | loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, |
| 269 | /*existing=*/nullptr); |
| 270 | Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0); |
| 271 | rewriter.replaceOp(op, result); |
| 272 | return success(); |
| 273 | } |
| 274 | |
| 275 | int64_t numElements = outVecType.getNumElements(); |
| 276 | Value zero = rewriter.create<arith::ConstantOp>( |
| 277 | loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); |
| 278 | if (outVecType.getShape().empty()) { |
| 279 | Value scalarIn = |
| 280 | rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); |
| 281 | // Recurse to send the 0-D vector case to the 1-D vector case |
| 282 | Value scalarTrunc = |
| 283 | rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn); |
| 284 | Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero, |
| 285 | ArrayRef<int64_t>{}); |
| 286 | rewriter.replaceOp(op, result); |
| 287 | return success(); |
| 288 | } |
| 289 | |
| 290 | VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, |
| 291 | outVecType.getElementType()); |
| 292 | Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); |
| 293 | |
| 294 | if (inVectorTy.getRank() > 1) { |
| 295 | inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, |
| 296 | inVectorTy.getElementType()); |
| 297 | in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); |
| 298 | } |
| 299 | |
| 300 | for (int64_t i = 0; i < numElements; i += 4) { |
| 301 | int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i; |
| 302 | Value thisResult = nullptr; |
| 303 | for (int64_t j = 0; j < elemsThisOp; j += 2) { |
| 304 | Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j); |
| 305 | Value asFloatA = castToF32(value: elemA, loc, rewriter); |
| 306 | Value asFloatB = nullptr; |
| 307 | if (j + 1 < elemsThisOp) { |
| 308 | Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1); |
| 309 | asFloatB = castToF32(value: elemB, loc, rewriter); |
| 310 | } |
| 311 | thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( |
| 312 | loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); |
| 313 | } |
| 314 | if (elemsThisOp < 4) |
| 315 | thisResult = rewriter.create<vector::ExtractStridedSliceOp>( |
| 316 | loc, thisResult, 0, elemsThisOp, 1); |
| 317 | result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, |
| 318 | result, i, 1); |
| 319 | } |
| 320 | |
| 321 | if (inVectorTy.getRank() != outVecType.getRank()) { |
| 322 | result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result); |
| 323 | } |
| 324 | |
| 325 | rewriter.replaceOp(op, result); |
| 326 | return success(); |
| 327 | } |
| 328 | |
| 329 | LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( |
| 330 | arith::TruncFOp op, PatternRewriter &rewriter) const { |
| 331 | Type outType = op.getOut().getType(); |
| 332 | Type inputType = getElementTypeOrSelf(op.getIn()); |
| 333 | auto outVecType = dyn_cast<VectorType>(outType); |
| 334 | if (outVecType) { |
| 335 | if (outVecType.isScalable()) |
| 336 | return failure(); |
| 337 | outType = outVecType.getElementType(); |
| 338 | } |
| 339 | if (!(outType.isF16() && inputType.isF32())) |
| 340 | return failure(); |
| 341 | |
| 342 | Location loc = op.getLoc(); |
| 343 | Value in = op.getIn(); |
| 344 | Type outElemType = getElementTypeOrSelf(op.getOut().getType()); |
| 345 | VectorType truncResType = VectorType::get(2, outElemType); |
| 346 | auto inVectorTy = dyn_cast<VectorType>(in.getType()); |
| 347 | |
| 348 | // Handle the case where input type is not a vector type |
| 349 | if (!inVectorTy) { |
| 350 | auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); |
| 351 | Value asF16s = |
| 352 | rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB); |
| 353 | Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0); |
| 354 | rewriter.replaceOp(op, result); |
| 355 | return success(); |
| 356 | } |
| 357 | int64_t numElements = outVecType.getNumElements(); |
| 358 | Value zero = rewriter.createOrFold<arith::ConstantOp>( |
| 359 | loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); |
| 360 | Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero); |
| 361 | |
| 362 | if (inVectorTy.getRank() > 1) { |
| 363 | inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, |
| 364 | inVectorTy.getElementType()); |
| 365 | in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); |
| 366 | } |
| 367 | |
| 368 | // Handle the vector case. We also handle the (uncommon) case where the vector |
| 369 | // length is odd |
| 370 | for (int64_t i = 0; i < numElements; i += 2) { |
| 371 | int64_t elemsThisOp = std::min(a: numElements, b: i + 2) - i; |
| 372 | Value thisResult = nullptr; |
| 373 | Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i); |
| 374 | Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); |
| 375 | |
| 376 | if (elemsThisOp == 2) { |
| 377 | elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1); |
| 378 | } |
| 379 | |
| 380 | thisResult = |
| 381 | rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB); |
| 382 | // Place back the truncated result into the possibly larger vector. If we |
| 383 | // are operating on a size 2 vector, these operations should be folded away |
| 384 | thisResult = rewriter.create<vector::ExtractStridedSliceOp>( |
| 385 | loc, thisResult, 0, elemsThisOp, 1); |
| 386 | result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, |
| 387 | result, i, 1); |
| 388 | } |
| 389 | |
| 390 | if (inVectorTy.getRank() != outVecType.getRank()) { |
| 391 | result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result); |
| 392 | } |
| 393 | |
| 394 | rewriter.replaceOp(op, result); |
| 395 | return success(); |
| 396 | } |
| 397 | |
| 398 | void mlir::arith::populateArithToAMDGPUConversionPatterns( |
| 399 | RewritePatternSet &patterns, bool convertFP8Arithmetic, |
| 400 | bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) { |
| 401 | |
| 402 | if (convertFP8Arithmetic) { |
| 403 | patterns.add<ExtFOnFloat8RewritePattern>(arg: patterns.getContext(), args&: chipset); |
| 404 | patterns.add<TruncFToFloat8RewritePattern>(arg: patterns.getContext(), |
| 405 | args&: saturateFP8Truncf, args&: chipset); |
| 406 | } |
| 407 | if (allowPackedF16Rtz) |
| 408 | patterns.add<TruncfToFloat16RewritePattern>(arg: patterns.getContext()); |
| 409 | } |
| 410 | |
| 411 | void ArithToAMDGPUConversionPass::runOnOperation() { |
| 412 | Operation *op = getOperation(); |
| 413 | MLIRContext *ctx = &getContext(); |
| 414 | RewritePatternSet patterns(op->getContext()); |
| 415 | FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset); |
| 416 | if (failed(Result: maybeChipset)) { |
| 417 | emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); |
| 418 | return signalPassFailure(); |
| 419 | } |
| 420 | |
| 421 | bool convertFP8Arithmetic = |
| 422 | *maybeChipset == kGfx942 || hasOcpFp8(chipset: *maybeChipset); |
| 423 | arith::populateArithToAMDGPUConversionPatterns( |
| 424 | patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, |
| 425 | *maybeChipset); |
| 426 | if (failed(applyPatternsGreedily(op, std::move(patterns)))) |
| 427 | return signalPassFailure(); |
| 428 | } |
| 429 | |