| 1 | //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" |
| 10 | |
| 11 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| 12 | #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" |
| 13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 14 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 16 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| 17 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 18 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 19 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
| 20 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 21 | #include "mlir/IR/BuiltinTypes.h" |
| 22 | #include "mlir/IR/PatternMatch.h" |
| 23 | #include "mlir/IR/TypeUtilities.h" |
| 24 | #include "mlir/Pass/Pass.h" |
| 25 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 26 | |
| 27 | namespace mlir { |
| 28 | #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS |
| 29 | #include "mlir/Conversion/Passes.h.inc" |
| 30 | } // namespace mlir |
| 31 | |
| 32 | using namespace mlir; |
| 33 | using namespace mlir::amdgpu; |
| 34 | |
| 35 | namespace { |
| 36 | // Define commonly used chipsets versions for convenience. |
| 37 | constexpr Chipset kGfx942 = Chipset(9, 4, 2); |
| 38 | constexpr Chipset kGfx950 = Chipset(9, 5, 0); |
| 39 | |
| 40 | struct ArithToAMDGPUConversionPass final |
| 41 | : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> { |
| 42 | using impl::ArithToAMDGPUConversionPassBase< |
| 43 | ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase; |
| 44 | |
| 45 | void runOnOperation() override; |
| 46 | }; |
| 47 | |
| 48 | struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> { |
| 49 | using OpRewritePattern::OpRewritePattern; |
| 50 | |
| 51 | Chipset chipset; |
| 52 | ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset) |
| 53 | : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {} |
| 54 | |
| 55 | LogicalResult matchAndRewrite(arith::ExtFOp op, |
| 56 | PatternRewriter &rewriter) const override; |
| 57 | }; |
| 58 | |
| 59 | struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> { |
| 60 | bool saturateFP8 = false; |
| 61 | TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8, |
| 62 | Chipset chipset) |
| 63 | : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8), |
| 64 | chipset(chipset) {} |
| 65 | Chipset chipset; |
| 66 | |
| 67 | LogicalResult matchAndRewrite(arith::TruncFOp op, |
| 68 | PatternRewriter &rewriter) const override; |
| 69 | }; |
| 70 | |
| 71 | struct TruncfToFloat16RewritePattern final |
| 72 | : public OpRewritePattern<arith::TruncFOp> { |
| 73 | |
| 74 | using OpRewritePattern::OpRewritePattern; |
| 75 | |
| 76 | LogicalResult matchAndRewrite(arith::TruncFOp op, |
| 77 | PatternRewriter &rewriter) const override; |
| 78 | }; |
| 79 | |
| 80 | struct ScalingExtFRewritePattern final |
| 81 | : OpRewritePattern<arith::ScalingExtFOp> { |
| 82 | using OpRewritePattern::OpRewritePattern; |
| 83 | |
| 84 | ScalingExtFRewritePattern(MLIRContext *ctx) |
| 85 | : OpRewritePattern::OpRewritePattern(ctx) {} |
| 86 | |
| 87 | LogicalResult matchAndRewrite(arith::ScalingExtFOp op, |
| 88 | PatternRewriter &rewriter) const override; |
| 89 | }; |
| 90 | |
| 91 | struct ScalingTruncFRewritePattern final |
| 92 | : OpRewritePattern<arith::ScalingTruncFOp> { |
| 93 | using OpRewritePattern::OpRewritePattern; |
| 94 | |
| 95 | ScalingTruncFRewritePattern(MLIRContext *ctx) |
| 96 | : OpRewritePattern::OpRewritePattern(ctx) {} |
| 97 | |
| 98 | LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, |
| 99 | PatternRewriter &rewriter) const override; |
| 100 | }; |
| 101 | |
| 102 | } // end namespace |
| 103 | |
| 104 | static bool isSupportedF8(Type elementType, Chipset chipset) { |
| 105 | if (chipset == kGfx942) |
| 106 | return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(Val: elementType); |
| 107 | if (hasOcpFp8(chipset)) |
| 108 | return isa<Float8E4M3FNType, Float8E5M2Type>(Val: elementType); |
| 109 | return false; |
| 110 | } |
| 111 | |
| 112 | static Value castF32To(Type desType, Value f32, Location loc, |
| 113 | PatternRewriter &rewriter) { |
| 114 | Type elementType = getElementTypeOrSelf(type: desType); |
| 115 | if (elementType.isF32()) |
| 116 | return f32; |
| 117 | if (elementType.getIntOrFloatBitWidth() < 32) |
| 118 | return rewriter.create<arith::TruncFOp>(location: loc, args&: desType, args&: f32); |
| 119 | if (elementType.getIntOrFloatBitWidth() > 32) |
| 120 | return rewriter.create<arith::ExtFOp>(location: loc, args&: desType, args&: f32); |
| 121 | llvm_unreachable("The only 32-bit float type is f32" ); |
| 122 | } |
| 123 | |
| 124 | LogicalResult |
| 125 | ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, |
| 126 | PatternRewriter &rewriter) const { |
| 127 | Type inType = op.getIn().getType(); |
| 128 | auto inVecType = dyn_cast<VectorType>(Val&: inType); |
| 129 | if (inVecType) { |
| 130 | if (inVecType.isScalable()) |
| 131 | return failure(); |
| 132 | inType = inVecType.getElementType(); |
| 133 | } |
| 134 | if (!isSupportedF8(elementType: inType, chipset)) |
| 135 | return failure(); |
| 136 | |
| 137 | Location loc = op.getLoc(); |
| 138 | Value in = op.getIn(); |
| 139 | Type outElemType = getElementTypeOrSelf(type: op.getOut().getType()); |
| 140 | VectorType extResType = VectorType::get(shape: 2, elementType: rewriter.getF32Type()); |
| 141 | if (!inVecType) { |
| 142 | Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( |
| 143 | location: loc, args: rewriter.getF32Type(), args&: in, args: 0); |
| 144 | Value result = castF32To(desType: outElemType, f32: asFloat, loc, rewriter); |
| 145 | rewriter.replaceOp(op, newValues: result); |
| 146 | return success(); |
| 147 | } |
| 148 | int64_t numElements = inVecType.getNumElements(); |
| 149 | |
| 150 | Value zero = rewriter.create<arith::ConstantOp>( |
| 151 | location: loc, args&: outElemType, args: rewriter.getFloatAttr(type: outElemType, value: 0.0)); |
| 152 | VectorType outType = cast<VectorType>(Val: op.getOut().getType()); |
| 153 | |
| 154 | if (inVecType.getShape().empty()) { |
| 155 | Value zerodSplat = |
| 156 | rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outType, args&: zero); |
| 157 | Value scalarIn = |
| 158 | rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: ArrayRef<int64_t>{}); |
| 159 | Value scalarExt = |
| 160 | rewriter.create<arith::ExtFOp>(location: loc, args&: outElemType, args&: scalarIn); |
| 161 | Value result = rewriter.create<vector::InsertOp>(location: loc, args&: scalarExt, args&: zerodSplat, |
| 162 | args: ArrayRef<int64_t>{}); |
| 163 | rewriter.replaceOp(op, newValues: result); |
| 164 | return success(); |
| 165 | } |
| 166 | |
| 167 | VectorType flatTy = VectorType::get(shape: SmallVector<int64_t>{numElements}, |
| 168 | elementType: outType.getElementType()); |
| 169 | Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: flatTy, args&: zero); |
| 170 | |
| 171 | if (inVecType.getRank() > 1) { |
| 172 | inVecType = VectorType::get(shape: SmallVector<int64_t>{numElements}, |
| 173 | elementType: inVecType.getElementType()); |
| 174 | in = rewriter.create<vector::ShapeCastOp>(location: loc, args&: inVecType, args&: in); |
| 175 | } |
| 176 | |
| 177 | for (int64_t i = 0; i < numElements; i += 4) { |
| 178 | int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i; |
| 179 | Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>( |
| 180 | location: loc, args&: in, args&: i, args&: elemsThisOp, args: 1); |
| 181 | for (int64_t j = 0; j < elemsThisOp; j += 2) { |
| 182 | if (i + j + 1 < numElements) { // Convert two 8-bit elements |
| 183 | Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>( |
| 184 | location: loc, args&: extResType, args&: inSlice, args: j / 2); |
| 185 | Type desType = VectorType::get(shape: 2, elementType: outElemType); |
| 186 | Value asType = castF32To(desType, f32: asFloats, loc, rewriter); |
| 187 | result = rewriter.create<vector::InsertStridedSliceOp>( |
| 188 | location: loc, args&: asType, args&: result, args: i + j, args: 1); |
| 189 | } else { // Convert a 8-bit element |
| 190 | Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( |
| 191 | location: loc, args: rewriter.getF32Type(), args&: inSlice, args: j / 2 * 2); |
| 192 | Value asType = castF32To(desType: outElemType, f32: asFloat, loc, rewriter); |
| 193 | result = rewriter.create<vector::InsertOp>(location: loc, args&: asType, args&: result, args: i + j); |
| 194 | } |
| 195 | } |
| 196 | } |
| 197 | |
| 198 | if (inVecType.getRank() != outType.getRank()) { |
| 199 | result = rewriter.create<vector::ShapeCastOp>(location: loc, args&: outType, args&: result); |
| 200 | } |
| 201 | |
| 202 | rewriter.replaceOp(op, newValues: result); |
| 203 | return success(); |
| 204 | } |
| 205 | |
| 206 | static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { |
| 207 | Type type = value.getType(); |
| 208 | if (type.isF32()) |
| 209 | return value; |
| 210 | if (type.getIntOrFloatBitWidth() < 32) |
| 211 | return rewriter.create<arith::ExtFOp>(location: loc, args: rewriter.getF32Type(), args&: value); |
| 212 | if (type.getIntOrFloatBitWidth() > 32) |
| 213 | return rewriter.create<arith::TruncFOp>(location: loc, args: rewriter.getF32Type(), args&: value); |
| 214 | llvm_unreachable("The only 32-bit float type is f32" ); |
| 215 | } |
| 216 | |
| 217 | // If `in` is a finite value, clamp it between the maximum and minimum values |
| 218 | // of `outElemType` so that subsequent conversion instructions don't |
| 219 | // overflow those out-of-range values to NaN. These semantics are commonly |
| 220 | // used in machine-learning contexts where failure to clamp would lead to |
| 221 | // excessive NaN production. |
| 222 | static Value clampInput(PatternRewriter &rewriter, Location loc, |
| 223 | Type outElemType, Value source) { |
| 224 | Type sourceType = source.getType(); |
| 225 | const llvm::fltSemantics &sourceSem = |
| 226 | cast<FloatType>(Val: getElementTypeOrSelf(type: sourceType)).getFloatSemantics(); |
| 227 | const llvm::fltSemantics &targetSem = |
| 228 | cast<FloatType>(Val&: outElemType).getFloatSemantics(); |
| 229 | |
| 230 | APFloat min = APFloat::getLargest(Sem: targetSem, /*Negative=*/true); |
| 231 | APFloat max = APFloat::getLargest(Sem: targetSem, /*Negative=*/false); |
| 232 | bool ignoredLosesInfo = false; |
| 233 | // We can ignore conversion failures here because this conversion promotes |
| 234 | // from a smaller type to a larger one - ex. there can be no loss of precision |
| 235 | // when casting fp8 to f16. |
| 236 | (void)min.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo); |
| 237 | (void)max.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo); |
| 238 | |
| 239 | Value minCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: min); |
| 240 | Value maxCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: max); |
| 241 | |
| 242 | Value inf = createScalarOrSplatConstant( |
| 243 | builder&: rewriter, loc, type: sourceType, |
| 244 | value: APFloat::getInf(Sem: sourceSem, /*Negative=*/false)); |
| 245 | Value negInf = createScalarOrSplatConstant( |
| 246 | builder&: rewriter, loc, type: sourceType, value: APFloat::getInf(Sem: sourceSem, /*Negative=*/true)); |
| 247 | Value isInf = rewriter.createOrFold<arith::CmpFOp>( |
| 248 | location: loc, args: arith::CmpFPredicate::OEQ, args&: source, args&: inf); |
| 249 | Value isNegInf = rewriter.createOrFold<arith::CmpFOp>( |
| 250 | location: loc, args: arith::CmpFPredicate::OEQ, args&: source, args&: negInf); |
| 251 | Value isNan = rewriter.createOrFold<arith::CmpFOp>( |
| 252 | location: loc, args: arith::CmpFPredicate::UNO, args&: source, args&: source); |
| 253 | Value isNonFinite = rewriter.create<arith::OrIOp>( |
| 254 | location: loc, args: rewriter.create<arith::OrIOp>(location: loc, args&: isInf, args&: isNegInf), args&: isNan); |
| 255 | |
| 256 | Value clampedBelow = rewriter.create<arith::MaximumFOp>(location: loc, args&: source, args&: minCst); |
| 257 | Value clamped = rewriter.create<arith::MinimumFOp>(location: loc, args&: clampedBelow, args&: maxCst); |
| 258 | Value res = |
| 259 | rewriter.create<arith::SelectOp>(location: loc, args&: isNonFinite, args&: source, args&: clamped); |
| 260 | return res; |
| 261 | } |
| 262 | |
| 263 | LogicalResult |
| 264 | TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, |
| 265 | PatternRewriter &rewriter) const { |
| 266 | // Only supporting default rounding mode as of now. |
| 267 | if (op.getRoundingmodeAttr()) |
| 268 | return failure(); |
| 269 | Type outType = op.getOut().getType(); |
| 270 | auto outVecType = dyn_cast<VectorType>(Val&: outType); |
| 271 | if (outVecType) { |
| 272 | if (outVecType.isScalable()) |
| 273 | return failure(); |
| 274 | outType = outVecType.getElementType(); |
| 275 | } |
| 276 | auto inType = dyn_cast<FloatType>(Val: getElementTypeOrSelf(type: op.getIn().getType())); |
| 277 | if (inType && inType.getWidth() <= 8 && saturateFP8) |
| 278 | // Conversion between 8-bit floats is not supported with truncation enabled. |
| 279 | return failure(); |
| 280 | |
| 281 | if (!isSupportedF8(elementType: outType, chipset)) |
| 282 | return failure(); |
| 283 | |
| 284 | Location loc = op.getLoc(); |
| 285 | Value in = op.getIn(); |
| 286 | Type outElemType = getElementTypeOrSelf(type: op.getOut().getType()); |
| 287 | if (saturateFP8) |
| 288 | in = clampInput(rewriter, loc, outElemType, source: in); |
| 289 | auto inVectorTy = dyn_cast<VectorType>(Val: in.getType()); |
| 290 | VectorType truncResType = VectorType::get(shape: 4, elementType: outElemType); |
| 291 | if (!inVectorTy) { |
| 292 | Value asFloat = castToF32(value: in, loc, rewriter); |
| 293 | Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( |
| 294 | location: loc, args&: truncResType, args&: asFloat, /*sourceB=*/args: nullptr, args: 0, |
| 295 | /*existing=*/args: nullptr); |
| 296 | Value result = rewriter.create<vector::ExtractOp>(location: loc, args&: asF8s, args: 0); |
| 297 | rewriter.replaceOp(op, newValues: result); |
| 298 | return success(); |
| 299 | } |
| 300 | |
| 301 | int64_t numElements = outVecType.getNumElements(); |
| 302 | Value zero = rewriter.create<arith::ConstantOp>( |
| 303 | location: loc, args&: outElemType, args: rewriter.getFloatAttr(type: outElemType, value: 0.0)); |
| 304 | if (outVecType.getShape().empty()) { |
| 305 | Value scalarIn = |
| 306 | rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: ArrayRef<int64_t>{}); |
| 307 | // Recurse to send the 0-D vector case to the 1-D vector case |
| 308 | Value scalarTrunc = |
| 309 | rewriter.create<arith::TruncFOp>(location: loc, args&: outElemType, args&: scalarIn); |
| 310 | Value result = rewriter.create<vector::InsertOp>(location: loc, args&: scalarTrunc, args&: zero, |
| 311 | args: ArrayRef<int64_t>{}); |
| 312 | rewriter.replaceOp(op, newValues: result); |
| 313 | return success(); |
| 314 | } |
| 315 | |
| 316 | VectorType flatTy = VectorType::get(shape: SmallVector<int64_t>{numElements}, |
| 317 | elementType: outVecType.getElementType()); |
| 318 | Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: flatTy, args&: zero); |
| 319 | |
| 320 | if (inVectorTy.getRank() > 1) { |
| 321 | inVectorTy = VectorType::get(shape: SmallVector<int64_t>{numElements}, |
| 322 | elementType: inVectorTy.getElementType()); |
| 323 | in = rewriter.create<vector::ShapeCastOp>(location: loc, args&: inVectorTy, args&: in); |
| 324 | } |
| 325 | |
| 326 | for (int64_t i = 0; i < numElements; i += 4) { |
| 327 | int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i; |
| 328 | Value thisResult = nullptr; |
| 329 | for (int64_t j = 0; j < elemsThisOp; j += 2) { |
| 330 | Value elemA = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: i + j); |
| 331 | Value asFloatA = castToF32(value: elemA, loc, rewriter); |
| 332 | Value asFloatB = nullptr; |
| 333 | if (j + 1 < elemsThisOp) { |
| 334 | Value elemB = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: i + j + 1); |
| 335 | asFloatB = castToF32(value: elemB, loc, rewriter); |
| 336 | } |
| 337 | thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( |
| 338 | location: loc, args&: truncResType, args&: asFloatA, args&: asFloatB, args: j / 2, args&: thisResult); |
| 339 | } |
| 340 | if (elemsThisOp < 4) |
| 341 | thisResult = rewriter.create<vector::ExtractStridedSliceOp>( |
| 342 | location: loc, args&: thisResult, args: 0, args&: elemsThisOp, args: 1); |
| 343 | result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: thisResult, |
| 344 | args&: result, args&: i, args: 1); |
| 345 | } |
| 346 | |
| 347 | if (inVectorTy.getRank() != outVecType.getRank()) { |
| 348 | result = rewriter.create<vector::ShapeCastOp>(location: loc, args&: outVecType, args&: result); |
| 349 | } |
| 350 | |
| 351 | rewriter.replaceOp(op, newValues: result); |
| 352 | return success(); |
| 353 | } |
| 354 | |
| 355 | LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( |
| 356 | arith::TruncFOp op, PatternRewriter &rewriter) const { |
| 357 | Type outType = op.getOut().getType(); |
| 358 | Type inputType = getElementTypeOrSelf(val: op.getIn()); |
| 359 | auto outVecType = dyn_cast<VectorType>(Val&: outType); |
| 360 | if (outVecType) { |
| 361 | if (outVecType.isScalable()) |
| 362 | return failure(); |
| 363 | outType = outVecType.getElementType(); |
| 364 | } |
| 365 | if (!(outType.isF16() && inputType.isF32())) |
| 366 | return failure(); |
| 367 | |
| 368 | Location loc = op.getLoc(); |
| 369 | Value in = op.getIn(); |
| 370 | Type outElemType = getElementTypeOrSelf(type: op.getOut().getType()); |
| 371 | VectorType truncResType = VectorType::get(shape: 2, elementType: outElemType); |
| 372 | auto inVectorTy = dyn_cast<VectorType>(Val: in.getType()); |
| 373 | |
| 374 | // Handle the case where input type is not a vector type |
| 375 | if (!inVectorTy) { |
| 376 | auto sourceB = rewriter.create<LLVM::PoisonOp>(location: loc, args: rewriter.getF32Type()); |
| 377 | Value asF16s = |
| 378 | rewriter.create<ROCDL::CvtPkRtz>(location: loc, args&: truncResType, args&: in, args&: sourceB); |
| 379 | Value result = rewriter.create<vector::ExtractOp>(location: loc, args&: asF16s, args: 0); |
| 380 | rewriter.replaceOp(op, newValues: result); |
| 381 | return success(); |
| 382 | } |
| 383 | int64_t numElements = outVecType.getNumElements(); |
| 384 | Value zero = rewriter.createOrFold<arith::ConstantOp>( |
| 385 | location: loc, args&: outElemType, args: rewriter.getFloatAttr(type: outElemType, value: 0.0)); |
| 386 | Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outVecType, args&: zero); |
| 387 | |
| 388 | if (inVectorTy.getRank() > 1) { |
| 389 | inVectorTy = VectorType::get(shape: SmallVector<int64_t>{numElements}, |
| 390 | elementType: inVectorTy.getElementType()); |
| 391 | in = rewriter.create<vector::ShapeCastOp>(location: loc, args&: inVectorTy, args&: in); |
| 392 | } |
| 393 | |
| 394 | // Handle the vector case. We also handle the (uncommon) case where the vector |
| 395 | // length is odd |
| 396 | for (int64_t i = 0; i < numElements; i += 2) { |
| 397 | int64_t elemsThisOp = std::min(a: numElements, b: i + 2) - i; |
| 398 | Value thisResult = nullptr; |
| 399 | Value elemA = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args&: i); |
| 400 | Value elemB = rewriter.create<LLVM::PoisonOp>(location: loc, args: rewriter.getF32Type()); |
| 401 | |
| 402 | if (elemsThisOp == 2) { |
| 403 | elemB = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: i + 1); |
| 404 | } |
| 405 | |
| 406 | thisResult = |
| 407 | rewriter.create<ROCDL::CvtPkRtz>(location: loc, args&: truncResType, args&: elemA, args&: elemB); |
| 408 | // Place back the truncated result into the possibly larger vector. If we |
| 409 | // are operating on a size 2 vector, these operations should be folded away |
| 410 | thisResult = rewriter.create<vector::ExtractStridedSliceOp>( |
| 411 | location: loc, args&: thisResult, args: 0, args&: elemsThisOp, args: 1); |
| 412 | result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: thisResult, |
| 413 | args&: result, args&: i, args: 1); |
| 414 | } |
| 415 | |
| 416 | if (inVectorTy.getRank() != outVecType.getRank()) { |
| 417 | result = rewriter.create<vector::ShapeCastOp>(location: loc, args&: outVecType, args&: result); |
| 418 | } |
| 419 | |
| 420 | rewriter.replaceOp(op, newValues: result); |
| 421 | return success(); |
| 422 | } |
| 423 | |
| 424 | /// Get the broadcasted / splatted value for a chain of ops. |
| 425 | static Value getOriginalVectorValue(Value value) { |
| 426 | Value current = value; |
| 427 | while (Operation *definingOp = current.getDefiningOp()) { |
| 428 | bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp) |
| 429 | .Case<vector::ShapeCastOp>(caseFn: [¤t](auto op) { |
| 430 | current = op.getSource(); |
| 431 | return true; |
| 432 | }) |
| 433 | .Case<vector::BroadcastOp>(caseFn: [¤t](auto op) { |
| 434 | current = op.getSource(); |
| 435 | return false; |
| 436 | }) |
| 437 | .Case<vector::SplatOp>(caseFn: [¤t](auto op) { |
| 438 | current = op.getInput(); |
| 439 | return false; |
| 440 | }) |
| 441 | .Default(defaultFn: [](Operation *) { return false; }); |
| 442 | |
| 443 | if (!skipOp) { |
| 444 | break; |
| 445 | } |
| 446 | } |
| 447 | return current; |
| 448 | } |
| 449 | |
| 450 | LogicalResult |
| 451 | ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, |
| 452 | PatternRewriter &rewriter) const { |
| 453 | Location loc = op.getLoc(); |
| 454 | constexpr int64_t opWidth = 2; |
| 455 | |
| 456 | Value in = op.getIn(); |
| 457 | Value scale = op.getScale(); |
| 458 | Value out = op.getOut(); |
| 459 | |
| 460 | Type f32 = rewriter.getF32Type(); |
| 461 | Type inType = getElementTypeOrSelf(val: in); |
| 462 | Type scaleType = getElementTypeOrSelf(val: scale); |
| 463 | Type outType = getElementTypeOrSelf(val: out); |
| 464 | |
| 465 | VectorType outVecType = dyn_cast<VectorType>(Val: out.getType()); |
| 466 | VectorType scaleVecType = dyn_cast<VectorType>(Val: scale.getType()); |
| 467 | |
| 468 | if (outVecType && outVecType.isScalable()) |
| 469 | return failure(); |
| 470 | |
| 471 | Type scaleF32Type = |
| 472 | scaleVecType ? VectorType::get(shape: scaleVecType.getShape(), elementType: f32) : f32; |
| 473 | if (scaleType.getIntOrFloatBitWidth() < 32) |
| 474 | scale = rewriter.create<arith::ExtFOp>(location: loc, args&: scaleF32Type, args&: scale); |
| 475 | else if (scaleType.getIntOrFloatBitWidth() > 32) |
| 476 | scale = rewriter.create<arith::TruncFOp>(location: loc, args&: scaleF32Type, args&: scale); |
| 477 | |
| 478 | VectorType extScaleResultType = VectorType::get(shape: opWidth, elementType: outType); |
| 479 | |
| 480 | if (!outVecType) { |
| 481 | Value inCast = |
| 482 | rewriter.create<vector::SplatOp>(location: loc, args: VectorType::get(shape: 1, elementType: inType), args&: in); |
| 483 | // TODO: replace this with non-packed ScaledExtOp |
| 484 | Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>( |
| 485 | location: loc, args&: extScaleResultType, args&: inCast, args&: scale, args: 0); |
| 486 | scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, args&: scaleExt, args: 0); |
| 487 | return success(); |
| 488 | } |
| 489 | |
| 490 | VectorType inVecType = cast<VectorType>(Val: in.getType()); |
| 491 | Value origScale = getOriginalVectorValue(value: op.getScale()); |
| 492 | |
| 493 | ArrayRef<int64_t> inShape = inVecType.getShape(); |
| 494 | SmallVector<int64_t> originalScaleShape; |
| 495 | if (auto origScaleVecType = dyn_cast<VectorType>(Val: origScale.getType())) |
| 496 | llvm::append_range(C&: originalScaleShape, R: origScaleVecType.getShape()); |
| 497 | |
| 498 | originalScaleShape.insert(I: originalScaleShape.end(), |
| 499 | NumToInsert: inShape.size() - originalScaleShape.size(), Elt: 1); |
| 500 | |
| 501 | auto maybeRatio = computeShapeRatio(shape: inShape, subShape: originalScaleShape); |
| 502 | assert(maybeRatio && |
| 503 | "failed to derive block size from broadcast or splat operation" ); |
| 504 | |
| 505 | SmallVector<int64_t> ratio = |
| 506 | maybeRatio.value_or(u: SmallVector<int64_t>(inShape.size(), 1)); |
| 507 | |
| 508 | int64_t blockSize = computeProduct(basis: ratio); |
| 509 | |
| 510 | Value zero = rewriter.create<arith::ConstantOp>( |
| 511 | location: loc, args&: outType, args: rewriter.getFloatAttr(type: outType, value: 0.0)); |
| 512 | Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outVecType, args&: zero); |
| 513 | |
| 514 | for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) { |
| 515 | SmallVector<int64_t> strides(offsets.size(), 1); |
| 516 | Value block = rewriter.create<vector::ExtractStridedSliceOp>( |
| 517 | location: loc, args&: in, args&: offsets, args&: ratio, args&: strides); |
| 518 | VectorType block1DType = VectorType::get(shape: blockSize, elementType: inType); |
| 519 | Value block1D = |
| 520 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: block1DType, args&: block); |
| 521 | Value uniformScale = |
| 522 | rewriter.create<vector::ExtractOp>(location: loc, args&: scale, args&: offsets); |
| 523 | |
| 524 | VectorType blockResultType = VectorType::get(shape: blockSize, elementType: outType); |
| 525 | Value blockResult = |
| 526 | rewriter.createOrFold<vector::SplatOp>(location: loc, args&: blockResultType, args&: zero); |
| 527 | |
| 528 | for (int64_t i = 0, sliceWidth = std::min(a: opWidth, b: blockSize - i); |
| 529 | i < blockSize; |
| 530 | i += sliceWidth, sliceWidth = std::min(a: opWidth, b: blockSize - i)) { |
| 531 | Value slice = rewriter.create<vector::ExtractStridedSliceOp>( |
| 532 | location: loc, args&: block1D, args&: i, args&: sliceWidth, args: 1); |
| 533 | // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 |
| 534 | Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>( |
| 535 | location: loc, args&: extScaleResultType, args&: slice, args&: uniformScale, args: 0); |
| 536 | if (sliceWidth != opWidth) |
| 537 | scaleExt = rewriter.create<vector::ExtractStridedSliceOp>( |
| 538 | location: loc, args&: scaleExt, args: 0, args&: sliceWidth, args: 1); |
| 539 | blockResult = rewriter.create<vector::InsertStridedSliceOp>( |
| 540 | location: loc, args&: scaleExt, args&: blockResult, args&: i, args: 1); |
| 541 | } |
| 542 | |
| 543 | VectorType resultType = VectorType::get(shape: ratio, elementType: outType); |
| 544 | Value cast = |
| 545 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: resultType, args&: blockResult); |
| 546 | result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: cast, args&: result, |
| 547 | args&: offsets, args&: strides); |
| 548 | } |
| 549 | |
| 550 | rewriter.replaceOp(op, newValues: result); |
| 551 | |
| 552 | return success(); |
| 553 | } |
| 554 | |
| 555 | LogicalResult |
| 556 | ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, |
| 557 | PatternRewriter &rewriter) const { |
| 558 | Location loc = op.getLoc(); |
| 559 | constexpr int64_t opWidth = 2; |
| 560 | |
| 561 | Value in = op.getIn(); |
| 562 | Value scale = op.getScale(); |
| 563 | Value out = op.getOut(); |
| 564 | |
| 565 | Type f32 = rewriter.getF32Type(); |
| 566 | Type inType = getElementTypeOrSelf(val: in); |
| 567 | Type scaleType = getElementTypeOrSelf(val: scale); |
| 568 | Type outType = getElementTypeOrSelf(val: out); |
| 569 | |
| 570 | VectorType outVecType = dyn_cast<VectorType>(Val: out.getType()); |
| 571 | VectorType scaleVecType = dyn_cast<VectorType>(Val: scale.getType()); |
| 572 | |
| 573 | if (outVecType && outVecType.isScalable()) |
| 574 | return failure(); |
| 575 | |
| 576 | Type scaleF32Type = |
| 577 | scaleVecType ? VectorType::get(shape: scaleVecType.getShape(), elementType: f32) : f32; |
| 578 | if (scaleType.getIntOrFloatBitWidth() < 32) |
| 579 | scale = rewriter.create<arith::ExtFOp>(location: loc, args&: scaleF32Type, args&: scale); |
| 580 | else if (scaleType.getIntOrFloatBitWidth() > 32) |
| 581 | scale = rewriter.create<arith::TruncFOp>(location: loc, args&: scaleF32Type, args&: scale); |
| 582 | |
| 583 | Value zero = rewriter.create<arith::ConstantOp>( |
| 584 | location: loc, args&: outType, args: rewriter.getFloatAttr(type: outType, value: 0.0)); |
| 585 | unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth(); |
| 586 | VectorType truncScaleResultType = VectorType::get(shape: numPackedElem, elementType: outType); |
| 587 | |
| 588 | if (!outVecType) { |
| 589 | Type inVecType = VectorType::get(shape: 1, elementType: inType); |
| 590 | Value inCast = rewriter.create<vector::SplatOp>(location: loc, args&: inVecType, args&: in); |
| 591 | // TODO: replace this with non-packed ScaledTruncOp |
| 592 | Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>( |
| 593 | location: loc, args&: truncScaleResultType, args&: inCast, args&: scale, args: 0, /*existing=*/args: nullptr); |
| 594 | scaleTrunc = |
| 595 | rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, args&: scaleTrunc, args: 0); |
| 596 | return success(); |
| 597 | } |
| 598 | |
| 599 | VectorType inVecType = cast<VectorType>(Val: in.getType()); |
| 600 | Value origScale = getOriginalVectorValue(value: op.getScale()); |
| 601 | |
| 602 | ArrayRef<int64_t> inShape = inVecType.getShape(); |
| 603 | SmallVector<int64_t> originalScaleShape; |
| 604 | if (auto origScaleVecType = dyn_cast<VectorType>(Val: origScale.getType())) |
| 605 | llvm::append_range(C&: originalScaleShape, R: origScaleVecType.getShape()); |
| 606 | |
| 607 | originalScaleShape.insert(I: originalScaleShape.end(), |
| 608 | NumToInsert: inShape.size() - originalScaleShape.size(), Elt: 1); |
| 609 | |
| 610 | auto maybeRatio = computeShapeRatio(shape: inShape, subShape: originalScaleShape); |
| 611 | assert(maybeRatio && |
| 612 | "failed to derive block size from broadcast or splat operation" ); |
| 613 | |
| 614 | SmallVector<int64_t> ratio = |
| 615 | maybeRatio.value_or(u: SmallVector<int64_t>(inShape.size(), 1)); |
| 616 | |
| 617 | int64_t blockSize = computeProduct(basis: ratio); |
| 618 | |
| 619 | Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outVecType, args&: zero); |
| 620 | |
| 621 | for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) { |
| 622 | SmallVector<int64_t> strides(offsets.size(), 1); |
| 623 | Value block = rewriter.create<vector::ExtractStridedSliceOp>( |
| 624 | location: loc, args&: in, args&: offsets, args&: ratio, args&: strides); |
| 625 | VectorType block1DType = VectorType::get(shape: blockSize, elementType: inType); |
| 626 | Value block1D = |
| 627 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: block1DType, args&: block); |
| 628 | Value uniformScale = |
| 629 | rewriter.create<vector::ExtractOp>(location: loc, args&: scale, args&: offsets); |
| 630 | |
| 631 | VectorType blockResultType = VectorType::get(shape: blockSize, elementType: outType); |
| 632 | Value blockResult = |
| 633 | rewriter.createOrFold<vector::SplatOp>(location: loc, args&: blockResultType, args&: zero); |
| 634 | |
| 635 | for (int64_t i = 0, sliceWidth = std::min(a: opWidth, b: blockSize - i); |
| 636 | i < blockSize; |
| 637 | i += sliceWidth, sliceWidth = std::min(a: opWidth, b: blockSize - i)) { |
| 638 | Value slice = rewriter.create<vector::ExtractStridedSliceOp>( |
| 639 | location: loc, args&: block1D, args&: i, args&: sliceWidth, args: 1); |
| 640 | // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 |
| 641 | Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>( |
| 642 | location: loc, args&: truncScaleResultType, args&: slice, args&: uniformScale, args: 0, |
| 643 | /*existing=*/args: nullptr); |
| 644 | int64_t packedWidth = |
| 645 | cast<VectorType>(Val: scaleTrunc.getType()).getNumElements(); |
| 646 | if (packedWidth != opWidth) |
| 647 | scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>( |
| 648 | location: loc, args&: scaleTrunc, args: 0, args&: sliceWidth, args: 1); |
| 649 | blockResult = rewriter.create<vector::InsertStridedSliceOp>( |
| 650 | location: loc, args&: scaleTrunc, args&: blockResult, args&: i, args: 1); |
| 651 | } |
| 652 | |
| 653 | VectorType resultType = VectorType::get(shape: ratio, elementType: outType); |
| 654 | Value cast = |
| 655 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: resultType, args&: blockResult); |
| 656 | result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: cast, args&: result, |
| 657 | args&: offsets, args&: strides); |
| 658 | } |
| 659 | |
| 660 | rewriter.replaceOp(op, newValues: result); |
| 661 | |
| 662 | return success(); |
| 663 | } |
| 664 | |
| 665 | void mlir::arith::populateArithToAMDGPUConversionPatterns( |
| 666 | RewritePatternSet &patterns, bool convertFP8Arithmetic, |
| 667 | bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) { |
| 668 | |
| 669 | if (convertFP8Arithmetic) { |
| 670 | patterns.add<ExtFOnFloat8RewritePattern>(arg: patterns.getContext(), args&: chipset); |
| 671 | patterns.add<TruncFToFloat8RewritePattern>(arg: patterns.getContext(), |
| 672 | args&: saturateFP8Truncf, args&: chipset); |
| 673 | } |
| 674 | if (allowPackedF16Rtz) |
| 675 | patterns.add<TruncfToFloat16RewritePattern>(arg: patterns.getContext()); |
| 676 | |
| 677 | if (chipset >= kGfx950) { |
| 678 | patterns.add<ScalingExtFRewritePattern>(arg: patterns.getContext()); |
| 679 | patterns.add<ScalingTruncFRewritePattern>(arg: patterns.getContext()); |
| 680 | } |
| 681 | } |
| 682 | |
| 683 | void ArithToAMDGPUConversionPass::runOnOperation() { |
| 684 | Operation *op = getOperation(); |
| 685 | MLIRContext *ctx = &getContext(); |
| 686 | RewritePatternSet patterns(op->getContext()); |
| 687 | FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(name: chipset); |
| 688 | if (failed(Result: maybeChipset)) { |
| 689 | emitError(loc: UnknownLoc::get(context: ctx), message: "Invalid chipset name: " + chipset); |
| 690 | return signalPassFailure(); |
| 691 | } |
| 692 | |
| 693 | bool convertFP8Arithmetic = |
| 694 | *maybeChipset == kGfx942 || hasOcpFp8(chipset: *maybeChipset); |
| 695 | arith::populateArithToAMDGPUConversionPatterns( |
| 696 | patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, |
| 697 | chipset: *maybeChipset); |
| 698 | if (failed(Result: applyPatternsGreedily(op, patterns: std::move(patterns)))) |
| 699 | return signalPassFailure(); |
| 700 | } |
| 701 | |