1//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
10
11#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/Dialect/Vector/IR/VectorOps.h"
19#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
20#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/IR/TypeUtilities.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::amdgpu;
34
35namespace {
// Commonly used chipset versions, named here for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
39
/// Pass that greedily applies the arith-to-AMDGPU rewrite patterns defined
/// below. Pass options (chipset name, fp8 saturation, packed-f16 rtz) are
/// inherited from the tablegen-generated base class.
struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};
47
48struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
49 using OpRewritePattern::OpRewritePattern;
50
51 Chipset chipset;
52 ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
53 : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
54
55 LogicalResult matchAndRewrite(arith::ExtFOp op,
56 PatternRewriter &rewriter) const override;
57};
58
59struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
60 bool saturateFP8 = false;
61 TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
62 Chipset chipset)
63 : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
64 chipset(chipset) {}
65 Chipset chipset;
66
67 LogicalResult matchAndRewrite(arith::TruncFOp op,
68 PatternRewriter &rewriter) const override;
69};
70
/// Rewrites f32 -> f16 `arith.truncf` (scalar or vector) into
/// `rocdl.cvt.pkrtz` ("Rtz" = round toward zero, per the op mnemonic).
struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};
79
80struct ScalingExtFRewritePattern final
81 : OpRewritePattern<arith::ScalingExtFOp> {
82 using OpRewritePattern::OpRewritePattern;
83
84 ScalingExtFRewritePattern(MLIRContext *ctx)
85 : OpRewritePattern::OpRewritePattern(ctx) {}
86
87 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
88 PatternRewriter &rewriter) const override;
89};
90
91struct ScalingTruncFRewritePattern final
92 : OpRewritePattern<arith::ScalingTruncFOp> {
93 using OpRewritePattern::OpRewritePattern;
94
95 ScalingTruncFRewritePattern(MLIRContext *ctx)
96 : OpRewritePattern::OpRewritePattern(ctx) {}
97
98 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
99 PatternRewriter &rewriter) const override;
100};
101
102} // end namespace
103
104static bool isSupportedF8(Type elementType, Chipset chipset) {
105 if (chipset == kGfx942)
106 return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(Val: elementType);
107 if (hasOcpFp8(chipset))
108 return isa<Float8E4M3FNType, Float8E5M2Type>(Val: elementType);
109 return false;
110}
111
112static Value castF32To(Type desType, Value f32, Location loc,
113 PatternRewriter &rewriter) {
114 Type elementType = getElementTypeOrSelf(type: desType);
115 if (elementType.isF32())
116 return f32;
117 if (elementType.getIntOrFloatBitWidth() < 32)
118 return rewriter.create<arith::TruncFOp>(location: loc, args&: desType, args&: f32);
119 if (elementType.getIntOrFloatBitWidth() > 32)
120 return rewriter.create<arith::ExtFOp>(location: loc, args&: desType, args&: f32);
121 llvm_unreachable("The only 32-bit float type is f32");
122}
123
124LogicalResult
125ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
126 PatternRewriter &rewriter) const {
127 Type inType = op.getIn().getType();
128 auto inVecType = dyn_cast<VectorType>(Val&: inType);
129 if (inVecType) {
130 if (inVecType.isScalable())
131 return failure();
132 inType = inVecType.getElementType();
133 }
134 if (!isSupportedF8(elementType: inType, chipset))
135 return failure();
136
137 Location loc = op.getLoc();
138 Value in = op.getIn();
139 Type outElemType = getElementTypeOrSelf(type: op.getOut().getType());
140 VectorType extResType = VectorType::get(shape: 2, elementType: rewriter.getF32Type());
141 if (!inVecType) {
142 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
143 location: loc, args: rewriter.getF32Type(), args&: in, args: 0);
144 Value result = castF32To(desType: outElemType, f32: asFloat, loc, rewriter);
145 rewriter.replaceOp(op, newValues: result);
146 return success();
147 }
148 int64_t numElements = inVecType.getNumElements();
149
150 Value zero = rewriter.create<arith::ConstantOp>(
151 location: loc, args&: outElemType, args: rewriter.getFloatAttr(type: outElemType, value: 0.0));
152 VectorType outType = cast<VectorType>(Val: op.getOut().getType());
153
154 if (inVecType.getShape().empty()) {
155 Value zerodSplat =
156 rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outType, args&: zero);
157 Value scalarIn =
158 rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: ArrayRef<int64_t>{});
159 Value scalarExt =
160 rewriter.create<arith::ExtFOp>(location: loc, args&: outElemType, args&: scalarIn);
161 Value result = rewriter.create<vector::InsertOp>(location: loc, args&: scalarExt, args&: zerodSplat,
162 args: ArrayRef<int64_t>{});
163 rewriter.replaceOp(op, newValues: result);
164 return success();
165 }
166
167 VectorType flatTy = VectorType::get(shape: SmallVector<int64_t>{numElements},
168 elementType: outType.getElementType());
169 Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: flatTy, args&: zero);
170
171 if (inVecType.getRank() > 1) {
172 inVecType = VectorType::get(shape: SmallVector<int64_t>{numElements},
173 elementType: inVecType.getElementType());
174 in = rewriter.create<vector::ShapeCastOp>(location: loc, args&: inVecType, args&: in);
175 }
176
177 for (int64_t i = 0; i < numElements; i += 4) {
178 int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i;
179 Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
180 location: loc, args&: in, args&: i, args&: elemsThisOp, args: 1);
181 for (int64_t j = 0; j < elemsThisOp; j += 2) {
182 if (i + j + 1 < numElements) { // Convert two 8-bit elements
183 Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
184 location: loc, args&: extResType, args&: inSlice, args: j / 2);
185 Type desType = VectorType::get(shape: 2, elementType: outElemType);
186 Value asType = castF32To(desType, f32: asFloats, loc, rewriter);
187 result = rewriter.create<vector::InsertStridedSliceOp>(
188 location: loc, args&: asType, args&: result, args: i + j, args: 1);
189 } else { // Convert a 8-bit element
190 Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
191 location: loc, args: rewriter.getF32Type(), args&: inSlice, args: j / 2 * 2);
192 Value asType = castF32To(desType: outElemType, f32: asFloat, loc, rewriter);
193 result = rewriter.create<vector::InsertOp>(location: loc, args&: asType, args&: result, args: i + j);
194 }
195 }
196 }
197
198 if (inVecType.getRank() != outType.getRank()) {
199 result = rewriter.create<vector::ShapeCastOp>(location: loc, args&: outType, args&: result);
200 }
201
202 rewriter.replaceOp(op, newValues: result);
203 return success();
204}
205
206static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
207 Type type = value.getType();
208 if (type.isF32())
209 return value;
210 if (type.getIntOrFloatBitWidth() < 32)
211 return rewriter.create<arith::ExtFOp>(location: loc, args: rewriter.getF32Type(), args&: value);
212 if (type.getIntOrFloatBitWidth() > 32)
213 return rewriter.create<arith::TruncFOp>(location: loc, args: rewriter.getF32Type(), args&: value);
214 llvm_unreachable("The only 32-bit float type is f32");
215}
216
217// If `in` is a finite value, clamp it between the maximum and minimum values
218// of `outElemType` so that subsequent conversion instructions don't
219// overflow those out-of-range values to NaN. These semantics are commonly
220// used in machine-learning contexts where failure to clamp would lead to
221// excessive NaN production.
222static Value clampInput(PatternRewriter &rewriter, Location loc,
223 Type outElemType, Value source) {
224 Type sourceType = source.getType();
225 const llvm::fltSemantics &sourceSem =
226 cast<FloatType>(Val: getElementTypeOrSelf(type: sourceType)).getFloatSemantics();
227 const llvm::fltSemantics &targetSem =
228 cast<FloatType>(Val&: outElemType).getFloatSemantics();
229
230 APFloat min = APFloat::getLargest(Sem: targetSem, /*Negative=*/true);
231 APFloat max = APFloat::getLargest(Sem: targetSem, /*Negative=*/false);
232 bool ignoredLosesInfo = false;
233 // We can ignore conversion failures here because this conversion promotes
234 // from a smaller type to a larger one - ex. there can be no loss of precision
235 // when casting fp8 to f16.
236 (void)min.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo);
237 (void)max.convert(ToSemantics: sourceSem, RM: APFloat::rmNearestTiesToEven, losesInfo: &ignoredLosesInfo);
238
239 Value minCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: min);
240 Value maxCst = createScalarOrSplatConstant(builder&: rewriter, loc, type: sourceType, value: max);
241
242 Value inf = createScalarOrSplatConstant(
243 builder&: rewriter, loc, type: sourceType,
244 value: APFloat::getInf(Sem: sourceSem, /*Negative=*/false));
245 Value negInf = createScalarOrSplatConstant(
246 builder&: rewriter, loc, type: sourceType, value: APFloat::getInf(Sem: sourceSem, /*Negative=*/true));
247 Value isInf = rewriter.createOrFold<arith::CmpFOp>(
248 location: loc, args: arith::CmpFPredicate::OEQ, args&: source, args&: inf);
249 Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
250 location: loc, args: arith::CmpFPredicate::OEQ, args&: source, args&: negInf);
251 Value isNan = rewriter.createOrFold<arith::CmpFOp>(
252 location: loc, args: arith::CmpFPredicate::UNO, args&: source, args&: source);
253 Value isNonFinite = rewriter.create<arith::OrIOp>(
254 location: loc, args: rewriter.create<arith::OrIOp>(location: loc, args&: isInf, args&: isNegInf), args&: isNan);
255
256 Value clampedBelow = rewriter.create<arith::MaximumFOp>(location: loc, args&: source, args&: minCst);
257 Value clamped = rewriter.create<arith::MinimumFOp>(location: loc, args&: clampedBelow, args&: maxCst);
258 Value res =
259 rewriter.create<arith::SelectOp>(location: loc, args&: isNonFinite, args&: source, args&: clamped);
260 return res;
261}
262
263LogicalResult
264TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
265 PatternRewriter &rewriter) const {
266 // Only supporting default rounding mode as of now.
267 if (op.getRoundingmodeAttr())
268 return failure();
269 Type outType = op.getOut().getType();
270 auto outVecType = dyn_cast<VectorType>(Val&: outType);
271 if (outVecType) {
272 if (outVecType.isScalable())
273 return failure();
274 outType = outVecType.getElementType();
275 }
276 auto inType = dyn_cast<FloatType>(Val: getElementTypeOrSelf(type: op.getIn().getType()));
277 if (inType && inType.getWidth() <= 8 && saturateFP8)
278 // Conversion between 8-bit floats is not supported with truncation enabled.
279 return failure();
280
281 if (!isSupportedF8(elementType: outType, chipset))
282 return failure();
283
284 Location loc = op.getLoc();
285 Value in = op.getIn();
286 Type outElemType = getElementTypeOrSelf(type: op.getOut().getType());
287 if (saturateFP8)
288 in = clampInput(rewriter, loc, outElemType, source: in);
289 auto inVectorTy = dyn_cast<VectorType>(Val: in.getType());
290 VectorType truncResType = VectorType::get(shape: 4, elementType: outElemType);
291 if (!inVectorTy) {
292 Value asFloat = castToF32(value: in, loc, rewriter);
293 Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
294 location: loc, args&: truncResType, args&: asFloat, /*sourceB=*/args: nullptr, args: 0,
295 /*existing=*/args: nullptr);
296 Value result = rewriter.create<vector::ExtractOp>(location: loc, args&: asF8s, args: 0);
297 rewriter.replaceOp(op, newValues: result);
298 return success();
299 }
300
301 int64_t numElements = outVecType.getNumElements();
302 Value zero = rewriter.create<arith::ConstantOp>(
303 location: loc, args&: outElemType, args: rewriter.getFloatAttr(type: outElemType, value: 0.0));
304 if (outVecType.getShape().empty()) {
305 Value scalarIn =
306 rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: ArrayRef<int64_t>{});
307 // Recurse to send the 0-D vector case to the 1-D vector case
308 Value scalarTrunc =
309 rewriter.create<arith::TruncFOp>(location: loc, args&: outElemType, args&: scalarIn);
310 Value result = rewriter.create<vector::InsertOp>(location: loc, args&: scalarTrunc, args&: zero,
311 args: ArrayRef<int64_t>{});
312 rewriter.replaceOp(op, newValues: result);
313 return success();
314 }
315
316 VectorType flatTy = VectorType::get(shape: SmallVector<int64_t>{numElements},
317 elementType: outVecType.getElementType());
318 Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: flatTy, args&: zero);
319
320 if (inVectorTy.getRank() > 1) {
321 inVectorTy = VectorType::get(shape: SmallVector<int64_t>{numElements},
322 elementType: inVectorTy.getElementType());
323 in = rewriter.create<vector::ShapeCastOp>(location: loc, args&: inVectorTy, args&: in);
324 }
325
326 for (int64_t i = 0; i < numElements; i += 4) {
327 int64_t elemsThisOp = std::min(a: numElements, b: i + 4) - i;
328 Value thisResult = nullptr;
329 for (int64_t j = 0; j < elemsThisOp; j += 2) {
330 Value elemA = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: i + j);
331 Value asFloatA = castToF32(value: elemA, loc, rewriter);
332 Value asFloatB = nullptr;
333 if (j + 1 < elemsThisOp) {
334 Value elemB = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: i + j + 1);
335 asFloatB = castToF32(value: elemB, loc, rewriter);
336 }
337 thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
338 location: loc, args&: truncResType, args&: asFloatA, args&: asFloatB, args: j / 2, args&: thisResult);
339 }
340 if (elemsThisOp < 4)
341 thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
342 location: loc, args&: thisResult, args: 0, args&: elemsThisOp, args: 1);
343 result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: thisResult,
344 args&: result, args&: i, args: 1);
345 }
346
347 if (inVectorTy.getRank() != outVecType.getRank()) {
348 result = rewriter.create<vector::ShapeCastOp>(location: loc, args&: outVecType, args&: result);
349 }
350
351 rewriter.replaceOp(op, newValues: result);
352 return success();
353}
354
355LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
356 arith::TruncFOp op, PatternRewriter &rewriter) const {
357 Type outType = op.getOut().getType();
358 Type inputType = getElementTypeOrSelf(val: op.getIn());
359 auto outVecType = dyn_cast<VectorType>(Val&: outType);
360 if (outVecType) {
361 if (outVecType.isScalable())
362 return failure();
363 outType = outVecType.getElementType();
364 }
365 if (!(outType.isF16() && inputType.isF32()))
366 return failure();
367
368 Location loc = op.getLoc();
369 Value in = op.getIn();
370 Type outElemType = getElementTypeOrSelf(type: op.getOut().getType());
371 VectorType truncResType = VectorType::get(shape: 2, elementType: outElemType);
372 auto inVectorTy = dyn_cast<VectorType>(Val: in.getType());
373
374 // Handle the case where input type is not a vector type
375 if (!inVectorTy) {
376 auto sourceB = rewriter.create<LLVM::PoisonOp>(location: loc, args: rewriter.getF32Type());
377 Value asF16s =
378 rewriter.create<ROCDL::CvtPkRtz>(location: loc, args&: truncResType, args&: in, args&: sourceB);
379 Value result = rewriter.create<vector::ExtractOp>(location: loc, args&: asF16s, args: 0);
380 rewriter.replaceOp(op, newValues: result);
381 return success();
382 }
383 int64_t numElements = outVecType.getNumElements();
384 Value zero = rewriter.createOrFold<arith::ConstantOp>(
385 location: loc, args&: outElemType, args: rewriter.getFloatAttr(type: outElemType, value: 0.0));
386 Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outVecType, args&: zero);
387
388 if (inVectorTy.getRank() > 1) {
389 inVectorTy = VectorType::get(shape: SmallVector<int64_t>{numElements},
390 elementType: inVectorTy.getElementType());
391 in = rewriter.create<vector::ShapeCastOp>(location: loc, args&: inVectorTy, args&: in);
392 }
393
394 // Handle the vector case. We also handle the (uncommon) case where the vector
395 // length is odd
396 for (int64_t i = 0; i < numElements; i += 2) {
397 int64_t elemsThisOp = std::min(a: numElements, b: i + 2) - i;
398 Value thisResult = nullptr;
399 Value elemA = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args&: i);
400 Value elemB = rewriter.create<LLVM::PoisonOp>(location: loc, args: rewriter.getF32Type());
401
402 if (elemsThisOp == 2) {
403 elemB = rewriter.create<vector::ExtractOp>(location: loc, args&: in, args: i + 1);
404 }
405
406 thisResult =
407 rewriter.create<ROCDL::CvtPkRtz>(location: loc, args&: truncResType, args&: elemA, args&: elemB);
408 // Place back the truncated result into the possibly larger vector. If we
409 // are operating on a size 2 vector, these operations should be folded away
410 thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
411 location: loc, args&: thisResult, args: 0, args&: elemsThisOp, args: 1);
412 result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: thisResult,
413 args&: result, args&: i, args: 1);
414 }
415
416 if (inVectorTy.getRank() != outVecType.getRank()) {
417 result = rewriter.create<vector::ShapeCastOp>(location: loc, args&: outVecType, args&: result);
418 }
419
420 rewriter.replaceOp(op, newValues: result);
421 return success();
422}
423
424/// Get the broadcasted / splatted value for a chain of ops.
425static Value getOriginalVectorValue(Value value) {
426 Value current = value;
427 while (Operation *definingOp = current.getDefiningOp()) {
428 bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
429 .Case<vector::ShapeCastOp>(caseFn: [&current](auto op) {
430 current = op.getSource();
431 return true;
432 })
433 .Case<vector::BroadcastOp>(caseFn: [&current](auto op) {
434 current = op.getSource();
435 return false;
436 })
437 .Case<vector::SplatOp>(caseFn: [&current](auto op) {
438 current = op.getInput();
439 return false;
440 })
441 .Default(defaultFn: [](Operation *) { return false; });
442
443 if (!skipOp) {
444 break;
445 }
446 }
447 return current;
448}
449
450LogicalResult
451ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
452 PatternRewriter &rewriter) const {
453 Location loc = op.getLoc();
454 constexpr int64_t opWidth = 2;
455
456 Value in = op.getIn();
457 Value scale = op.getScale();
458 Value out = op.getOut();
459
460 Type f32 = rewriter.getF32Type();
461 Type inType = getElementTypeOrSelf(val: in);
462 Type scaleType = getElementTypeOrSelf(val: scale);
463 Type outType = getElementTypeOrSelf(val: out);
464
465 VectorType outVecType = dyn_cast<VectorType>(Val: out.getType());
466 VectorType scaleVecType = dyn_cast<VectorType>(Val: scale.getType());
467
468 if (outVecType && outVecType.isScalable())
469 return failure();
470
471 Type scaleF32Type =
472 scaleVecType ? VectorType::get(shape: scaleVecType.getShape(), elementType: f32) : f32;
473 if (scaleType.getIntOrFloatBitWidth() < 32)
474 scale = rewriter.create<arith::ExtFOp>(location: loc, args&: scaleF32Type, args&: scale);
475 else if (scaleType.getIntOrFloatBitWidth() > 32)
476 scale = rewriter.create<arith::TruncFOp>(location: loc, args&: scaleF32Type, args&: scale);
477
478 VectorType extScaleResultType = VectorType::get(shape: opWidth, elementType: outType);
479
480 if (!outVecType) {
481 Value inCast =
482 rewriter.create<vector::SplatOp>(location: loc, args: VectorType::get(shape: 1, elementType: inType), args&: in);
483 // TODO: replace this with non-packed ScaledExtOp
484 Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
485 location: loc, args&: extScaleResultType, args&: inCast, args&: scale, args: 0);
486 scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, args&: scaleExt, args: 0);
487 return success();
488 }
489
490 VectorType inVecType = cast<VectorType>(Val: in.getType());
491 Value origScale = getOriginalVectorValue(value: op.getScale());
492
493 ArrayRef<int64_t> inShape = inVecType.getShape();
494 SmallVector<int64_t> originalScaleShape;
495 if (auto origScaleVecType = dyn_cast<VectorType>(Val: origScale.getType()))
496 llvm::append_range(C&: originalScaleShape, R: origScaleVecType.getShape());
497
498 originalScaleShape.insert(I: originalScaleShape.end(),
499 NumToInsert: inShape.size() - originalScaleShape.size(), Elt: 1);
500
501 auto maybeRatio = computeShapeRatio(shape: inShape, subShape: originalScaleShape);
502 assert(maybeRatio &&
503 "failed to derive block size from broadcast or splat operation");
504
505 SmallVector<int64_t> ratio =
506 maybeRatio.value_or(u: SmallVector<int64_t>(inShape.size(), 1));
507
508 int64_t blockSize = computeProduct(basis: ratio);
509
510 Value zero = rewriter.create<arith::ConstantOp>(
511 location: loc, args&: outType, args: rewriter.getFloatAttr(type: outType, value: 0.0));
512 Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outVecType, args&: zero);
513
514 for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
515 SmallVector<int64_t> strides(offsets.size(), 1);
516 Value block = rewriter.create<vector::ExtractStridedSliceOp>(
517 location: loc, args&: in, args&: offsets, args&: ratio, args&: strides);
518 VectorType block1DType = VectorType::get(shape: blockSize, elementType: inType);
519 Value block1D =
520 rewriter.create<vector::ShapeCastOp>(location: loc, args&: block1DType, args&: block);
521 Value uniformScale =
522 rewriter.create<vector::ExtractOp>(location: loc, args&: scale, args&: offsets);
523
524 VectorType blockResultType = VectorType::get(shape: blockSize, elementType: outType);
525 Value blockResult =
526 rewriter.createOrFold<vector::SplatOp>(location: loc, args&: blockResultType, args&: zero);
527
528 for (int64_t i = 0, sliceWidth = std::min(a: opWidth, b: blockSize - i);
529 i < blockSize;
530 i += sliceWidth, sliceWidth = std::min(a: opWidth, b: blockSize - i)) {
531 Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
532 location: loc, args&: block1D, args&: i, args&: sliceWidth, args: 1);
533 // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
534 Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
535 location: loc, args&: extScaleResultType, args&: slice, args&: uniformScale, args: 0);
536 if (sliceWidth != opWidth)
537 scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
538 location: loc, args&: scaleExt, args: 0, args&: sliceWidth, args: 1);
539 blockResult = rewriter.create<vector::InsertStridedSliceOp>(
540 location: loc, args&: scaleExt, args&: blockResult, args&: i, args: 1);
541 }
542
543 VectorType resultType = VectorType::get(shape: ratio, elementType: outType);
544 Value cast =
545 rewriter.create<vector::ShapeCastOp>(location: loc, args&: resultType, args&: blockResult);
546 result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: cast, args&: result,
547 args&: offsets, args&: strides);
548 }
549
550 rewriter.replaceOp(op, newValues: result);
551
552 return success();
553}
554
555LogicalResult
556ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
557 PatternRewriter &rewriter) const {
558 Location loc = op.getLoc();
559 constexpr int64_t opWidth = 2;
560
561 Value in = op.getIn();
562 Value scale = op.getScale();
563 Value out = op.getOut();
564
565 Type f32 = rewriter.getF32Type();
566 Type inType = getElementTypeOrSelf(val: in);
567 Type scaleType = getElementTypeOrSelf(val: scale);
568 Type outType = getElementTypeOrSelf(val: out);
569
570 VectorType outVecType = dyn_cast<VectorType>(Val: out.getType());
571 VectorType scaleVecType = dyn_cast<VectorType>(Val: scale.getType());
572
573 if (outVecType && outVecType.isScalable())
574 return failure();
575
576 Type scaleF32Type =
577 scaleVecType ? VectorType::get(shape: scaleVecType.getShape(), elementType: f32) : f32;
578 if (scaleType.getIntOrFloatBitWidth() < 32)
579 scale = rewriter.create<arith::ExtFOp>(location: loc, args&: scaleF32Type, args&: scale);
580 else if (scaleType.getIntOrFloatBitWidth() > 32)
581 scale = rewriter.create<arith::TruncFOp>(location: loc, args&: scaleF32Type, args&: scale);
582
583 Value zero = rewriter.create<arith::ConstantOp>(
584 location: loc, args&: outType, args: rewriter.getFloatAttr(type: outType, value: 0.0));
585 unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
586 VectorType truncScaleResultType = VectorType::get(shape: numPackedElem, elementType: outType);
587
588 if (!outVecType) {
589 Type inVecType = VectorType::get(shape: 1, elementType: inType);
590 Value inCast = rewriter.create<vector::SplatOp>(location: loc, args&: inVecType, args&: in);
591 // TODO: replace this with non-packed ScaledTruncOp
592 Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
593 location: loc, args&: truncScaleResultType, args&: inCast, args&: scale, args: 0, /*existing=*/args: nullptr);
594 scaleTrunc =
595 rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, args&: scaleTrunc, args: 0);
596 return success();
597 }
598
599 VectorType inVecType = cast<VectorType>(Val: in.getType());
600 Value origScale = getOriginalVectorValue(value: op.getScale());
601
602 ArrayRef<int64_t> inShape = inVecType.getShape();
603 SmallVector<int64_t> originalScaleShape;
604 if (auto origScaleVecType = dyn_cast<VectorType>(Val: origScale.getType()))
605 llvm::append_range(C&: originalScaleShape, R: origScaleVecType.getShape());
606
607 originalScaleShape.insert(I: originalScaleShape.end(),
608 NumToInsert: inShape.size() - originalScaleShape.size(), Elt: 1);
609
610 auto maybeRatio = computeShapeRatio(shape: inShape, subShape: originalScaleShape);
611 assert(maybeRatio &&
612 "failed to derive block size from broadcast or splat operation");
613
614 SmallVector<int64_t> ratio =
615 maybeRatio.value_or(u: SmallVector<int64_t>(inShape.size(), 1));
616
617 int64_t blockSize = computeProduct(basis: ratio);
618
619 Value result = rewriter.createOrFold<vector::SplatOp>(location: loc, args&: outVecType, args&: zero);
620
621 for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
622 SmallVector<int64_t> strides(offsets.size(), 1);
623 Value block = rewriter.create<vector::ExtractStridedSliceOp>(
624 location: loc, args&: in, args&: offsets, args&: ratio, args&: strides);
625 VectorType block1DType = VectorType::get(shape: blockSize, elementType: inType);
626 Value block1D =
627 rewriter.create<vector::ShapeCastOp>(location: loc, args&: block1DType, args&: block);
628 Value uniformScale =
629 rewriter.create<vector::ExtractOp>(location: loc, args&: scale, args&: offsets);
630
631 VectorType blockResultType = VectorType::get(shape: blockSize, elementType: outType);
632 Value blockResult =
633 rewriter.createOrFold<vector::SplatOp>(location: loc, args&: blockResultType, args&: zero);
634
635 for (int64_t i = 0, sliceWidth = std::min(a: opWidth, b: blockSize - i);
636 i < blockSize;
637 i += sliceWidth, sliceWidth = std::min(a: opWidth, b: blockSize - i)) {
638 Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
639 location: loc, args&: block1D, args&: i, args&: sliceWidth, args: 1);
640 // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
641 Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
642 location: loc, args&: truncScaleResultType, args&: slice, args&: uniformScale, args: 0,
643 /*existing=*/args: nullptr);
644 int64_t packedWidth =
645 cast<VectorType>(Val: scaleTrunc.getType()).getNumElements();
646 if (packedWidth != opWidth)
647 scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
648 location: loc, args&: scaleTrunc, args: 0, args&: sliceWidth, args: 1);
649 blockResult = rewriter.create<vector::InsertStridedSliceOp>(
650 location: loc, args&: scaleTrunc, args&: blockResult, args&: i, args: 1);
651 }
652
653 VectorType resultType = VectorType::get(shape: ratio, elementType: outType);
654 Value cast =
655 rewriter.create<vector::ShapeCastOp>(location: loc, args&: resultType, args&: blockResult);
656 result = rewriter.create<vector::InsertStridedSliceOp>(location: loc, args&: cast, args&: result,
657 args&: offsets, args&: strides);
658 }
659
660 rewriter.replaceOp(op, newValues: result);
661
662 return success();
663}
664
665void mlir::arith::populateArithToAMDGPUConversionPatterns(
666 RewritePatternSet &patterns, bool convertFP8Arithmetic,
667 bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
668
669 if (convertFP8Arithmetic) {
670 patterns.add<ExtFOnFloat8RewritePattern>(arg: patterns.getContext(), args&: chipset);
671 patterns.add<TruncFToFloat8RewritePattern>(arg: patterns.getContext(),
672 args&: saturateFP8Truncf, args&: chipset);
673 }
674 if (allowPackedF16Rtz)
675 patterns.add<TruncfToFloat16RewritePattern>(arg: patterns.getContext());
676
677 if (chipset >= kGfx950) {
678 patterns.add<ScalingExtFRewritePattern>(arg: patterns.getContext());
679 patterns.add<ScalingTruncFRewritePattern>(arg: patterns.getContext());
680 }
681}
682
683void ArithToAMDGPUConversionPass::runOnOperation() {
684 Operation *op = getOperation();
685 MLIRContext *ctx = &getContext();
686 RewritePatternSet patterns(op->getContext());
687 FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(name: chipset);
688 if (failed(Result: maybeChipset)) {
689 emitError(loc: UnknownLoc::get(context: ctx), message: "Invalid chipset name: " + chipset);
690 return signalPassFailure();
691 }
692
693 bool convertFP8Arithmetic =
694 *maybeChipset == kGfx942 || hasOcpFp8(chipset: *maybeChipset);
695 arith::populateArithToAMDGPUConversionPatterns(
696 patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
697 chipset: *maybeChipset);
698 if (failed(Result: applyPatternsGreedily(op, patterns: std::move(patterns))))
699 return signalPassFailure();
700}
701

source code of mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp