//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // The cast asserts that `val` is of integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
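/// For example, indices (i, j) into a memref with strides (s0, 1) produce
/// i * s0 + j as an i32; for a rank-0 memref this returns the constant 0.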
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
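/// For example, a statically shaped memref<16x4xf32> with strides [4, 1]
/// yields max(16 * 4, 4 * 1) * 4 bytes = 256.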
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    assert(size < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
    maxIndex = maxIndex
                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
}

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
                                         /*isDisjoint=*/true);
  } else {
    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
                                               rewriter.getI16IntegerAttr(0));
  }
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  // none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
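  // For example, the base flag word is (7 << 12) | (4 << 15) = 0x27000; on
  // gfx10+ with bounds checking enabled this becomes
  // 0x27000 | (1 << 24) | (3u << 28) = 0x31027000.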
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}

173namespace {
174struct FatRawBufferCastLowering
175 : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
176 FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
177 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
178 chipset(chipset) {}
179
180 Chipset chipset;
181
182 LogicalResult
183 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter) const override {
185 Location loc = op.getLoc();
186 Value memRef = adaptor.getSource();
187 Value unconvertedMemref = op.getSource();
188 MemRefType memrefType = cast<MemRefType>(Val: unconvertedMemref.getType());
189 MemRefDescriptor descriptor(memRef);
190
191 DataLayout dataLayout = DataLayout::closest(op);
192 int64_t elementByteWidth =
193 dataLayout.getTypeSizeInBits(t: memrefType.getElementType()) / 8;
194
195 int64_t unusedOffset = 0;
196 SmallVector<int64_t, 5> strideVals;
197 if (failed(Result: memrefType.getStridesAndOffset(strides&: strideVals, offset&: unusedOffset)))
198 return op.emitOpError(message: "Can't lower non-stride-offset memrefs");
199
200 Value numRecords = adaptor.getValidBytes();
201 if (!numRecords)
202 numRecords = getNumRecords(rewriter, loc, memrefType, memrefDescriptor&: descriptor,
203 strides: strideVals, elementByteWidth);
204
205 Value basePointer =
206 adaptor.getResetOffset()
207 ? descriptor.bufferPtr(builder&: rewriter, loc, converter: *getTypeConverter(),
208 type: memrefType)
209 : descriptor.alignedPtr(builder&: rewriter, loc);
210
211 Value offset = adaptor.getResetOffset()
212 ? rewriter.create<LLVM::ConstantOp>(
213 location: loc, args: getIndexType(), args: rewriter.getIndexAttr(value: 0))
214 : descriptor.offset(builder&: rewriter, loc);
215
216 bool hasSizes = memrefType.getRank() > 0;
217 // No need to unpack() and pack() all the individual sizes and strides,
218 // so we'll just extract the arrays.
219 Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
220 location: loc, args&: descriptor, args: kSizePosInMemRefDescriptor)
221 : Value{};
222 Value strides = hasSizes
223 ? rewriter.create<LLVM::ExtractValueOp>(
224 location: loc, args&: descriptor, args: kStridePosInMemRefDescriptor)
225 : Value{};
226
227 Value fatPtr = makeBufferRsrc(
228 rewriter, loc, basePointer, numRecords, boundsCheck: adaptor.getBoundsCheck(),
229 chipset, cacheSwizzleStride: adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
230
231 Value result = MemRefDescriptor::poison(
232 builder&: rewriter, loc,
233 descriptorType: getTypeConverter()->convertType(t: op.getResult().getType()));
234 result = rewriter.create<LLVM::InsertValueOp>(
235 location: loc, args&: result, args&: fatPtr, args: kAllocatedPtrPosInMemRefDescriptor);
236 result = rewriter.create<LLVM::InsertValueOp>(
237 location: loc, args&: result, args&: fatPtr, args: kAlignedPtrPosInMemRefDescriptor);
238 result = rewriter.create<LLVM::InsertValueOp>(location: loc, args&: result, args&: offset,
239 args: kOffsetPosInMemRefDescriptor);
240 if (hasSizes) {
241 result = rewriter.create<LLVM::InsertValueOp>(location: loc, args&: result, args&: sizes,
242 args: kSizePosInMemRefDescriptor);
243 result = rewriter.create<LLVM::InsertValueOp>(
244 location: loc, args&: result, args&: strides, args: kStridePosInMemRefDescriptor);
245 }
246 rewriter.replaceOp(op, newValues: result);
247 return success();
248 }
249};
250
251/// Define lowering patterns for raw buffer ops
252template <typename GpuOp, typename Intrinsic>
253struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
254 RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
255 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
256
257 Chipset chipset;
258 static constexpr uint32_t maxVectorOpWidth = 128;
259
260 LogicalResult
261 matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override {
263 Location loc = gpuOp.getLoc();
264 Value memref = adaptor.getMemref();
265 Value unconvertedMemref = gpuOp.getMemref();
266 MemRefType memrefType = cast<MemRefType>(Val: unconvertedMemref.getType());
267
268 if (chipset.majorVersion < 9)
269 return gpuOp.emitOpError("raw buffer ops require GCN or higher");
270
271 Value storeData = adaptor.getODSOperands(0)[0];
272 if (storeData == memref) // no write component to this op
273 storeData = Value();
274 Type wantedDataType;
275 if (storeData)
276 wantedDataType = storeData.getType();
277 else
278 wantedDataType = gpuOp.getODSResults(0)[0].getType();
279
280 Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
282 if (storeData) {
283 Value maybeCmpData = adaptor.getODSOperands(1)[0];
284 if (maybeCmpData != memref)
285 atomicCmpData = maybeCmpData;
286 }
287
288 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
289
290 Type i32 = rewriter.getI32Type();
291
292 // Get the type size in bytes.
293 DataLayout dataLayout = DataLayout::closest(op: gpuOp);
294 int64_t elementByteWidth =
295 dataLayout.getTypeSizeInBits(t: memrefType.getElementType()) / 8;
296 Value byteWidthConst = createI32Constant(rewriter, loc, value: elementByteWidth);
297
    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic
    // requires integer operands, so bitcast any floats to integers.
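    // For example, a vector<6xf16> value (96 bits) is handled as a
    // vector<3xi32>, while a 32-bit-or-smaller quantity such as vector<4xi8>
    // is handled as a single i32.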
303 Type llvmBufferValType = llvmWantedDataType;
304 if (atomicCmpData) {
305 if (auto floatType = dyn_cast<FloatType>(Val&: wantedDataType))
306 llvmBufferValType = this->getTypeConverter()->convertType(
307 rewriter.getIntegerType(width: floatType.getWidth()));
308 }
309 if (auto dataVector = dyn_cast<VectorType>(Val&: wantedDataType)) {
310 uint32_t vecLen = dataVector.getNumElements();
311 uint32_t elemBits =
312 dataLayout.getTypeSizeInBits(t: dataVector.getElementType());
313 uint32_t totalBits = elemBits * vecLen;
314 bool usePackedFp16 =
315 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
316 if (totalBits > maxVectorOpWidth)
317 return gpuOp.emitOpError(
318 "Total width of loads or stores must be no more than " +
319 Twine(maxVectorOpWidth) + " bits, but we call for " +
320 Twine(totalBits) +
321 " bits. This should've been caught in validation");
322 if (!usePackedFp16 && elemBits < 32) {
323 if (totalBits > 32) {
324 if (totalBits % 32 != 0)
325 return gpuOp.emitOpError("Load or store of more than 32-bits that "
326 "doesn't fit into words. Can't happen\n");
327 llvmBufferValType = this->typeConverter->convertType(
328 VectorType::get(shape: totalBits / 32, elementType: i32));
329 } else {
330 llvmBufferValType = this->typeConverter->convertType(
331 rewriter.getIntegerType(width: totalBits));
332 }
333 }
334 }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }
341
342 SmallVector<Value, 6> args;
343 if (storeData) {
344 if (llvmBufferValType != llvmWantedDataType) {
345 Value castForStore =
346 rewriter.create<LLVM::BitcastOp>(location: loc, args&: llvmBufferValType, args&: storeData);
347 args.push_back(Elt: castForStore);
348 } else {
349 args.push_back(Elt: storeData);
350 }
351 }
352
353 if (atomicCmpData) {
354 if (llvmBufferValType != llvmWantedDataType) {
355 Value castForCmp = rewriter.create<LLVM::BitcastOp>(
356 location: loc, args&: llvmBufferValType, args&: atomicCmpData);
357 args.push_back(Elt: castForCmp);
358 } else {
359 args.push_back(Elt: atomicCmpData);
360 }
361 }
362
363 // Construct buffer descriptor from memref, attributes
364 int64_t offset = 0;
365 SmallVector<int64_t, 5> strides;
366 if (failed(Result: memrefType.getStridesAndOffset(strides, offset)))
367 return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
368
369 MemRefDescriptor memrefDescriptor(memref);
370
371 Value ptr = memrefDescriptor.bufferPtr(
372 builder&: rewriter, loc, converter: *this->getTypeConverter(), type: memrefType);
373 Value numRecords = getNumRecords(
374 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
375 Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
376 adaptor.getBoundsCheck(), chipset);
377 args.push_back(Elt: resource);
378
379 // Indexing (voffset)
380 Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
381 adaptor.getIndices(), strides);
382 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
383 indexOffset && *indexOffset > 0) {
384 Value extraOffsetConst = createI32Constant(rewriter, loc, value: *indexOffset);
385 voffset =
386 voffset ? rewriter.create<LLVM::AddOp>(location: loc, args&: voffset, args&: extraOffsetConst)
387 : extraOffsetConst;
388 }
389 voffset = rewriter.create<LLVM::MulOp>(location: loc, args&: voffset, args&: byteWidthConst);
390 args.push_back(Elt: voffset);
391
392 // SGPR offset.
393 Value sgprOffset = adaptor.getSgprOffset();
394 if (!sgprOffset)
395 sgprOffset = createI32Constant(rewriter, loc, value: 0);
396 sgprOffset = rewriter.create<LLVM::MulOp>(location: loc, args&: sgprOffset, args&: byteWidthConst);
397 args.push_back(Elt: sgprOffset);
398
399 // bit 0: GLC = 0 (atomics drop value, less coherency)
400 // bits 1-2: SLC, DLC = 0 (similarly)
401 // bit 3: swizzled (0 for raw)
402 args.push_back(Elt: createI32Constant(rewriter, loc, value: 0));
403
404 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
405 llvmBufferValType);
406 Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
407 ArrayRef<NamedAttribute>());
408 if (lowered->getNumResults() == 1) {
409 Value replacement = lowered->getResult(idx: 0);
410 if (llvmBufferValType != llvmWantedDataType) {
411 replacement = rewriter.create<LLVM::BitcastOp>(location: loc, args&: llvmWantedDataType,
412 args&: replacement);
413 }
414 rewriter.replaceOp(gpuOp, replacement);
415 } else {
416 rewriter.eraseOp(op: gpuOp);
417 }
418 return success();
419 }
420};
421
422struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
423 LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
424 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
425
426 Chipset chipset;
427
428 LogicalResult
429 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
430 ConversionPatternRewriter &rewriter) const override {
431 bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;
432
433 if (requiresInlineAsm) {
434 auto asmDialectAttr = LLVM::AsmDialectAttr::get(context: rewriter.getContext(),
435 val: LLVM::AsmDialect::AD_ATT);
436 const char *asmStr =
437 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
438 const char *constraints = "";
439 rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
440 op,
441 /*resultTypes=*/args: TypeRange(), /*operands=*/args: ValueRange(),
442 /*asm_string=*/args&: asmStr, args&: constraints, /*has_side_effects=*/args: true,
443 /*is_align_stack=*/args: false, args: LLVM::TailCallKind::None,
444 /*asm_dialect=*/args&: asmDialectAttr,
445 /*operand_attrs=*/args: ArrayAttr());
446 return success();
447 }
448 if (chipset.majorVersion < 12) {
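      // These masks encode an s_waitcnt immediate whose lgkmcnt field is
      // cleared (wait for all LDS operations to complete) while every other
      // counter is left at its maximum "don't wait" value.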
449 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
450 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
451 // Left in place in case someone disables the inline ASM path or future
452 // chipsets use the same bit pattern.
453 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
454
455 int32_t ldsOnlyBits;
456 if (chipset.majorVersion == 11)
457 ldsOnlyBits = ldsOnlyBitsGfx11;
458 else if (chipset.majorVersion == 10)
459 ldsOnlyBits = ldsOnlyBitsGfx10;
460 else if (chipset.majorVersion <= 9)
461 ldsOnlyBits = ldsOnlyBitsGfx6789;
462 else
463 return op.emitOpError(
464 message: "don't know how to lower this for chipset major version")
465 << chipset.majorVersion;
466
467 Location loc = op->getLoc();
468 rewriter.create<ROCDL::SWaitcntOp>(location: loc, args&: ldsOnlyBits);
469 rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
470 } else {
471 Location loc = op->getLoc();
472 rewriter.create<ROCDL::WaitDscntOp>(location: loc, args: 0);
473 rewriter.create<ROCDL::BarrierSignalOp>(location: loc, args: -1);
474 rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, args: -1);
475 }
476
477 return success();
478 }
479};
480
481struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
482 SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
483 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
484
485 Chipset chipset;
486
487 LogicalResult
488 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
491 args: (uint32_t)op.getOpts());
492 return success();
493 }
494};
495
496} // namespace
497
/// Converts an MFMA vector operand from MLIR AMDGPU dialect convention to
/// ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless the rocdl
/// intrinsic allows bf16. Newer MFMAs support bf16 as an operand type; check
/// the IntrinsicsAMDGPU.td file for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
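/// For example, a vector<8xi8> operand becomes an i64, while a vector<32xi8>
/// operand (e.g. an fp8 operand of the f8f6f4 intrinsics after type
/// conversion) becomes a vector<8xi32>.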
513static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
514 Location loc, Value input,
515 bool allowBf16 = true) {
516 Type inputType = input.getType();
517 if (auto vectorType = dyn_cast<VectorType>(Val&: inputType)) {
518 if (vectorType.getElementType().isBF16() && !allowBf16)
519 return rewriter.create<LLVM::BitcastOp>(
520 location: loc, args: vectorType.clone(elementType: rewriter.getI16Type()), args&: input);
521 if (vectorType.getElementType().isInteger(width: 8) &&
522 vectorType.getNumElements() <= 8)
523 return rewriter.create<LLVM::BitcastOp>(
524 location: loc, args: rewriter.getIntegerType(width: vectorType.getNumElements() * 8), args&: input);
525 if (isa<IntegerType>(Val: vectorType.getElementType()) &&
526 vectorType.getElementTypeBitWidth() <= 8) {
527 int64_t numWords = llvm::divideCeil(
528 Numerator: vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
529 Denominator: 32);
530 return rewriter.create<LLVM::BitcastOp>(
531 location: loc, args: VectorType::get(shape: numWords, elementType: rewriter.getI32Type()), args&: input);
532 }
533 }
534 return input;
535}
536
/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
547static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
548 Location loc, Value input) {
549 Type inputType = input.getType();
550 Type outputType = rewriter.getI32Type();
551 if (auto intType = dyn_cast<IntegerType>(Val&: inputType))
552 return rewriter.create<LLVM::ZExtOp>(location: loc, args&: outputType, args&: input);
553 return rewriter.create<LLVM::BitcastOp>(location: loc, args&: outputType, args&: input);
554}
555
/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
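/// For example, a vector<16xi8> signed-integer input pushes an i1 "signed"
/// flag followed by the input bitcast to vector<4xi32>.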
564static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
565 Location loc,
566 const TypeConverter *typeConverter,
567 bool isUnsigned, Value llvmInput,
568 Value mlirInput,
569 SmallVector<Value, 4> &operands) {
570 Type inputType = llvmInput.getType();
571 auto vectorType = dyn_cast<VectorType>(Val&: inputType);
572 if (!vectorType) {
573 operands.push_back(Elt: llvmInput);
574 return;
575 }
576 Type elemType = vectorType.getElementType();
577
578 if (elemType.isBF16())
579 llvmInput = rewriter.create<LLVM::BitcastOp>(
580 location: loc, args: vectorType.clone(elementType: rewriter.getI16Type()), args&: llvmInput);
581 if (elemType.getIntOrFloatBitWidth() > 8) {
582 operands.push_back(Elt: llvmInput);
583 return;
584 }
585
586 // We need to check the type of the input before conversion to properly test
587 // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
588 // fp8/int8 information is lost during the conversion process.
589 auto mlirInputType = cast<VectorType>(Val: mlirInput.getType());
590 bool isInputInteger = mlirInputType.getElementType().isInteger();
591 if (isInputInteger) {
592 // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
593 bool localIsUnsigned = isUnsigned;
594 if (elemType.isUnsignedInteger()) {
595 localIsUnsigned = true;
596 } else if (elemType.isSignedInteger()) {
597 localIsUnsigned = false;
598 }
599 Value sign = createI1Constant(rewriter, loc, value: !localIsUnsigned);
600 operands.push_back(Elt: sign);
601 }
602
603 int64_t numBits =
604 vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
605 Type i32 = rewriter.getI32Type();
606 Type intrinsicInType = numBits <= 32
607 ? (Type)rewriter.getIntegerType(width: numBits)
608 : (Type)VectorType::get(shape: numBits / 32, elementType: i32);
609 auto llvmIntrinsicInType = typeConverter->convertType(t: intrinsicInType);
610 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
611 location: loc, args&: llvmIntrinsicInType, args&: llvmInput);
612 // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
613 // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
614 // Add in the zeros here.
615 if (numBits < 32)
616 castInput = rewriter.create<LLVM::ZExtOp>(location: loc, args&: i32, args&: castInput);
617 operands.push_back(Elt: castInput);
618}
619
/// Push the output operand. In many cases this only pushes the output onto
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or the lower 16 bits of the VGPRs. To store the
/// result in the lower 16 bits, set subwordOffset to 1; otherwise the result
/// will be stored in the upper part. The subwordOffset must not be set for
/// gfx12, as the instructions have been changed to return fewer registers
/// instead.
627static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
628 Location loc,
629 const TypeConverter *typeConverter,
630 Value output, int32_t subwordOffset,
631 bool clamp, SmallVector<Value, 4> &operands) {
632 Type inputType = output.getType();
633 auto vectorType = dyn_cast<VectorType>(Val&: inputType);
634 Type elemType = vectorType.getElementType();
635 if (elemType.isBF16())
636 output = rewriter.create<LLVM::BitcastOp>(
637 location: loc, args: vectorType.clone(elementType: rewriter.getI16Type()), args&: output);
638 operands.push_back(Elt: output);
639 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(width: 16)) {
640 operands.push_back(Elt: createI1Constant(rewriter, loc, value: subwordOffset));
641 } else if (elemType.isInteger(width: 32)) {
642 operands.push_back(Elt: createI1Constant(rewriter, loc, value: clamp));
643 }
644}
645
646/// Return true if `type` is the E5M2 variant of an 8-bit float that is
647/// supported by the `_bf8` instructions on the given `chipset`.
648static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
649 return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(Val: type)) ||
650 (hasOcpFp8(chipset) && isa<Float8E5M2Type>(Val: type));
651}
652
653/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
654/// supported by the `_fp8` instructions on the given `chipset`.
655static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
656 return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(Val: type)) ||
657 (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(Val: type));
658}
659
660/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
661/// if one exists. This includes checking to ensure the intrinsic is supported
662/// on the architecture you are compiling for.
663static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
664 Chipset chipset) {
665 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
666 b = mfma.getBlocks();
667 Type sourceElem = getElementTypeOrSelf(type: mfma.getSourceA().getType());
668 Type destElem = getElementTypeOrSelf(type: mfma.getDestC().getType());
669
670 if (sourceElem.isF32() && destElem.isF32()) {
671 if (mfma.getReducePrecision() && chipset >= kGfx942) {
672 if (m == 32 && n == 32 && k == 4 && b == 1)
673 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
674 if (m == 16 && n == 16 && k == 8 && b == 1)
675 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
676 }
677 if (m == 32 && n == 32 && k == 1 && b == 2)
678 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
679 if (m == 16 && n == 16 && k == 1 && b == 4)
680 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
681 if (m == 4 && n == 4 && k == 1 && b == 16)
682 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
683 if (m == 32 && n == 32 && k == 2 && b == 1)
684 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
685 if (m == 16 && n == 16 && k == 4 && b == 1)
686 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
687 }
688
689 if (sourceElem.isF16() && destElem.isF32()) {
690 if (chipset >= kGfx950) {
691 if (m == 32 && n == 32 && k == 16 && b == 1)
692 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
693 if (m == 16 && n == 16 && k == 32 && b == 1)
694 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
695 }
696 if (m == 32 && n == 32 && k == 4 && b == 2)
697 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
698 if (m == 16 && n == 16 && k == 4 && b == 4)
699 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
700 if (m == 4 && n == 4 && k == 4 && b == 16)
701 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
702 if (m == 32 && n == 32 && k == 8 && b == 1)
703 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
704 if (m == 16 && n == 16 && k == 16 && b == 1)
705 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
706 }
707
708 if (sourceElem.isBF16() && destElem.isF32()) {
709 if (chipset >= kGfx950) {
710 if (m == 32 && n == 32 && k == 16 && b == 1)
711 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
712 if (m == 16 && n == 16 && k == 32 && b == 1)
713 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
714 }
715 if (chipset >= kGfx90a) {
716 if (m == 32 && n == 32 && k == 4 && b == 2)
717 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
718 if (m == 16 && n == 16 && k == 4 && b == 4)
719 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
720 if (m == 4 && n == 4 && k == 4 && b == 16)
721 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
722 if (m == 32 && n == 32 && k == 8 && b == 1)
723 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
724 if (m == 16 && n == 16 && k == 16 && b == 1)
725 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
726 }
727 if (m == 32 && n == 32 && k == 2 && b == 2)
728 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
729 if (m == 16 && n == 16 && k == 2 && b == 4)
730 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
731 if (m == 4 && n == 4 && k == 2 && b == 16)
732 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
733 if (m == 32 && n == 32 && k == 4 && b == 1)
734 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
735 if (m == 16 && n == 16 && k == 8 && b == 1)
736 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
737 }
738
739 if (sourceElem.isInteger(width: 8) && destElem.isInteger(width: 32)) {
740 if (chipset >= kGfx950) {
741 if (m == 32 && n == 32 && k == 32 && b == 1)
742 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
743 if (m == 16 && n == 16 && k == 64 && b == 1)
744 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
745 }
746 if (m == 32 && n == 32 && k == 4 && b == 2)
747 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
748 if (m == 16 && n == 16 && k == 4 && b == 4)
749 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
750 if (m == 4 && n == 4 && k == 4 && b == 16)
751 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
752 if (m == 32 && n == 32 && k == 8 && b == 1)
753 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
754 if (m == 16 && n == 16 && k == 16 && b == 1)
755 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
756 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
757 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
758 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
759 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
760 }
761
762 if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
763 if (m == 16 && n == 16 && k == 4 && b == 1)
764 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
765 if (m == 4 && n == 4 && k == 4 && b == 4)
766 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
767 }
768
769 if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, type: sourceElem)) {
770 // Known to be correct because there are no scalar f8 instructions and
771 // because a length mismatch will have been caught by the verifier.
772 Type sourceBElem =
773 cast<VectorType>(Val: mfma.getSourceB().getType()).getElementType();
774 if (m == 16 && n == 16 && k == 32 && b == 1) {
775 if (typeIsExpectedBf8ForChipset(chipset, type: sourceBElem))
776 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
777 if (typeIsExpectedFp8ForChipset(chipset, type: sourceBElem))
778 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
779 }
780 if (m == 32 && n == 32 && k == 16 && b == 1) {
781 if (typeIsExpectedBf8ForChipset(chipset, type: sourceBElem))
782 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
783 if (typeIsExpectedFp8ForChipset(chipset, type: sourceBElem))
784 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
785 }
786 }
787
788 if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, type: sourceElem)) {
789 Type sourceBElem =
790 cast<VectorType>(Val: mfma.getSourceB().getType()).getElementType();
791 if (m == 16 && n == 16 && k == 32 && b == 1) {
792 if (typeIsExpectedBf8ForChipset(chipset, type: sourceBElem))
793 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
794 if (typeIsExpectedFp8ForChipset(chipset, type: sourceBElem))
795 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
796 }
797 if (m == 32 && n == 32 && k == 16 && b == 1) {
798 if (typeIsExpectedBf8ForChipset(chipset, type: sourceBElem))
799 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
800 if (typeIsExpectedFp8ForChipset(chipset, type: sourceBElem))
801 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
802 }
803 }
804
805 return std::nullopt;
806}
807
808static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
809 return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
810 .Case(caseFn: [](Float8E4M3FNType) { return 0u; })
811 .Case(caseFn: [](Float8E5M2Type) { return 1u; })
812 .Case(caseFn: [](Float6E2M3FNType) { return 2u; })
813 .Case(caseFn: [](Float6E3M2FNType) { return 3u; })
814 .Case(caseFn: [](Float4E2M1FNType) { return 4u; })
815 .Default(defaultFn: [](Type) { return std::nullopt; });
816}
817
818/// If there is a scaled MFMA instruction for the input element types `aType`
819/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
820/// blocks) on the given `chipset`, return a tuple consisting of the
821/// OperationName of the intrinsic and the type codes that need to be passed to
822/// that intrinsic. Note that this is also used to implement some un-scaled
823/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
824/// MFMA with a scale of 0.
825static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
826mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
827 uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
828 aType = getElementTypeOrSelf(type: aType);
829 bType = getElementTypeOrSelf(type: bType);
830 destType = getElementTypeOrSelf(type: destType);
831
832 if (chipset < kGfx950)
833 return std::nullopt;
834 if (!isa<Float32Type>(Val: destType))
835 return std::nullopt;
836
837 std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(mlirElemType: aType);
838 std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(mlirElemType: bType);
839 if (!aTypeCode || !bTypeCode)
840 return std::nullopt;
841
842 if (m == 32 && n == 32 && k == 64 && b == 1)
843 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
844 *aTypeCode, *bTypeCode};
845 if (m == 16 && n == 16 && k == 128 && b == 1)
846 return std::tuple{
847 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
848 *bTypeCode};
849
850 return std::nullopt;
851}
852
853static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
854mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
855 return mfmaOpToScaledIntrinsic(
856 aType: mfma.getSourceA().getType(), bType: mfma.getSourceB().getType(),
857 destType: mfma.getDestC().getType(), m: mfma.getM(), n: mfma.getN(), k: mfma.getK(),
858 b: mfma.getBlocks(), chipset);
859}
860
861static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
862mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
863 return mfmaOpToScaledIntrinsic(aType: smfma.getSourceA().getType(),
864 bType: smfma.getSourceB().getType(),
865 destType: smfma.getDestC().getType(), m: smfma.getM(),
866 n: smfma.getN(), k: smfma.getK(), b: 1u, chipset);
867}
868
869/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
870/// if one exists. This includes checking to ensure the intrinsic is supported
871/// on the architecture you are compiling for.
872static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
873 Chipset chipset) {
874 auto sourceVectorType = dyn_cast<VectorType>(Val: wmma.getSourceA().getType());
875 auto sourceBVectorType = dyn_cast<VectorType>(Val: wmma.getSourceB().getType());
876 auto destVectorType = dyn_cast<VectorType>(Val: wmma.getDestC().getType());
877 auto elemSourceType = sourceVectorType.getElementType();
878 auto elemBSourceType = sourceBVectorType.getElementType();
879 auto elemDestType = destVectorType.getElementType();
880
881 if (elemSourceType.isF16() && elemDestType.isF32())
882 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
883 if (elemSourceType.isBF16() && elemDestType.isF32())
884 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
885 if (elemSourceType.isF16() && elemDestType.isF16())
886 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
887 if (elemSourceType.isBF16() && elemDestType.isBF16())
888 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
889 if (elemSourceType.isInteger(width: 8) && elemDestType.isInteger(width: 32))
890 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
891 if (chipset.majorVersion == 11) {
892 if (elemSourceType.isInteger(width: 4) && elemDestType.isInteger(width: 32))
893 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
894 }
895 if (chipset.majorVersion >= 12) {
896 if (isa<Float8E4M3FNType>(Val: elemSourceType) &&
897 isa<Float8E4M3FNType>(Val: elemBSourceType) && elemDestType.isF32())
898 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
899 if (isa<Float8E4M3FNType>(Val: elemSourceType) &&
900 isa<Float8E5M2Type>(Val: elemBSourceType) && elemDestType.isF32())
901 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
902 if (isa<Float8E5M2Type>(Val: elemSourceType) &&
903 isa<Float8E5M2Type>(Val: elemBSourceType) && elemDestType.isF32())
904 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
905 if (isa<Float8E5M2Type>(Val: elemSourceType) &&
906 isa<Float8E4M3FNType>(Val: elemBSourceType) && elemDestType.isF32())
907 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
908 if (elemSourceType.isInteger(width: 4) && elemDestType.isInteger(width: 32)) {
909 bool isWave64 = destVectorType.getNumElements() == 4;
910 // This is the ambiguous case. 8 inputs to the wave64 version means that
911 // we want the 16x16x32 version, but for wave32 they mean the short form.
912 bool has8Inputs = sourceVectorType.getNumElements() == 8;
913 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
914 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
915 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
916 }
917 }
918 return std::nullopt;
919}
920
921namespace {
922struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
923 MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
924 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
925
926 Chipset chipset;
927
928 LogicalResult
929 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
930 ConversionPatternRewriter &rewriter) const override {
931 Location loc = op.getLoc();
932 Type outType = typeConverter->convertType(t: op.getDestD().getType());
933 Type intrinsicOutType = outType;
934 if (auto outVecType = dyn_cast<VectorType>(Val&: outType))
935 if (outVecType.getElementType().isBF16())
936 intrinsicOutType = outVecType.clone(elementType: rewriter.getI16Type());
937
938 if (chipset.majorVersion != 9 || chipset < kGfx908)
939 return op->emitOpError(message: "MFMA only supported on gfx908+");
940 uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
941 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
942 if (chipset < kGfx942)
943 return op.emitOpError(message: "negation unsupported on older than gfx942");
944 getBlgpField |=
945 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
946 }
947 std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(mfma: op, chipset);
948 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
949 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(mfma: op, chipset);
950 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
951 return op.emitOpError(message: "no intrinsic matching MFMA size on given chipset");
952
953 bool isScaled =
954 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
955 if (isScaled &&
956 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
957 return op.emitOpError(
958 message: "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
959 "be scaled as those fields are used for type information");
960 }
961
962 StringRef intrinsicName =
963 isScaled ? std::get<0>(t&: *maybeScaledIntrinsic) : *maybeIntrinsic;
    // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
    // allow bf16 as the input. For reference, check the IntrinsicsAMDGPU.td
    // file.
966 bool allowBf16 = [&]() {
967 if (chipset < kGfx950)
968 return false;
969 if (isScaled)
970 return true;
971 return intrinsicName.contains(Other: "16x16x32.bf16") ||
972 intrinsicName.contains(Other: "32x32x16.bf16");
973 }();
974 OperationState loweredOp(loc, intrinsicName);
975 loweredOp.addTypes(newTypes: intrinsicOutType);
976 loweredOp.addOperands(newOperands: {convertMFMAVectorOperand(
977 rewriter, loc, input: adaptor.getSourceA(), allowBf16),
978 convertMFMAVectorOperand(
979 rewriter, loc, input: adaptor.getSourceB(), allowBf16),
980 adaptor.getDestC()});
981 if (isScaled) {
982 Value zero = createI32Constant(rewriter, loc, value: 0);
983 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
984 loweredOp.addOperands(newOperands: {createI32Constant(rewriter, loc, value: aTypeCode),
985 createI32Constant(rewriter, loc, value: bTypeCode),
986 /*scale A byte=*/zero, /*scale A=*/zero,
987 /*scale B byte=*/zero, /*scale B=*/zero});
988 } else {
989 loweredOp.addOperands(newOperands: {createI32Constant(rewriter, loc, value: op.getCbsz()),
990 createI32Constant(rewriter, loc, value: op.getAbid()),
991 createI32Constant(rewriter, loc, value: getBlgpField)});
992 };
993 Value lowered = rewriter.create(state: loweredOp)->getResult(idx: 0);
994 if (outType != intrinsicOutType)
995 lowered = rewriter.create<LLVM::BitcastOp>(location: loc, args&: outType, args&: lowered);
996 rewriter.replaceOp(op, newValues: lowered);
997 return success();
998 }
999};
1000
1001struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
1002 ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1003 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1004
1005 Chipset chipset;
1006
1007 LogicalResult
1008 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1009 ConversionPatternRewriter &rewriter) const override {
1010 Location loc = op.getLoc();
1011 Type intrinsicOutType = typeConverter->convertType(t: op.getDestD().getType());
1012
    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
1015 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1016 maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(smfma: op, chipset);
1017 if (!maybeScaledIntrinsic.has_value())
1018 return op.emitOpError(
1019 message: "no intrinsic matching scaled MFMA size on given chipset");
1020
1021 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1022 OperationState loweredOp(loc, intrinsicName);
1023 loweredOp.addTypes(newTypes: intrinsicOutType);
1024 loweredOp.addOperands(
1025 newOperands: {convertMFMAVectorOperand(rewriter, loc, input: adaptor.getSourceA()),
1026 convertMFMAVectorOperand(rewriter, loc, input: adaptor.getSourceB()),
1027 adaptor.getDestC()});
1028 Value scalesIdxA =
1029 createI32Constant(rewriter, loc, value: adaptor.getScalesIdxA());
1030 Value scalesIdxB =
1031 createI32Constant(rewriter, loc, value: adaptor.getScalesIdxB());
1032 loweredOp.addOperands(
1033 newOperands: {createI32Constant(rewriter, loc, value: aTypeCode),
1034 createI32Constant(rewriter, loc, value: bTypeCode),
1035 /*scales idx A=*/scalesIdxA,
1036 /*scales A*/
1037 castMFMAScaleOperand(rewriter, loc, input: adaptor.getScalesA()),
1038 /*scales idx B=*/scalesIdxB,
1039 /*scales B*/
1040 castMFMAScaleOperand(rewriter, loc, input: adaptor.getScalesB())});
1041 Value lowered = rewriter.create(state: loweredOp)->getResult(idx: 0);
1042 rewriter.replaceOp(op, newValues: lowered);
1043 return success();
1044 }
1045};
1046
1047struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
1048 WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1049 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1050
1051 Chipset chipset;
1052
1053 LogicalResult
1054 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1055 ConversionPatternRewriter &rewriter) const override {
1056 Location loc = op.getLoc();
1057 auto outType =
1058 typeConverter->convertType<VectorType>(t: op.getDestD().getType());
1059 if (!outType)
1060 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
1061
1062 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1063 return op->emitOpError(message: "WMMA only supported on gfx11 and gfx12");
1064
1065 // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
1066 // need to bitcast bfloats to i16 and then bitcast them back.
1067 VectorType rawOutType = outType;
1068 if (outType.getElementType().isBF16())
1069 rawOutType = outType.clone(elementType: rewriter.getI16Type());
1070
1071 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(wmma: op, chipset);
1072
1073 if (!maybeIntrinsic.has_value())
1074 return op.emitOpError(message: "no intrinsic matching WMMA on the given chipset");
1075
1076 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1077 return op.emitOpError(message: "subwordOffset not supported on gfx12+");
1078
1079 OperationState loweredOp(loc, *maybeIntrinsic);
1080 loweredOp.addTypes(newTypes: rawOutType);
1081
1082 SmallVector<Value, 4> operands;
1083 wmmaPushInputOperand(rewriter, loc, typeConverter, isUnsigned: op.getUnsignedA(),
1084 llvmInput: adaptor.getSourceA(), mlirInput: op.getSourceA(), operands);
1085 wmmaPushInputOperand(rewriter, loc, typeConverter, isUnsigned: op.getUnsignedB(),
1086 llvmInput: adaptor.getSourceB(), mlirInput: op.getSourceB(), operands);
1087 wmmaPushOutputOperand(rewriter, loc, typeConverter, output: adaptor.getDestC(),
1088 subwordOffset: op.getSubwordOffset(), clamp: op.getClamp(), operands);
1089
1090 loweredOp.addOperands(newOperands: operands);
1091 Operation *lowered = rewriter.create(state: loweredOp);
1092
1093 Operation *maybeCastBack = lowered;
1094 if (rawOutType != outType)
1095 maybeCastBack =
1096 rewriter.create<LLVM::BitcastOp>(location: loc, args&: outType, args: lowered->getResult(idx: 0));
1097 rewriter.replaceOp(op, newValues: maybeCastBack->getResults());
1098
1099 return success();
1100 }
1101};
1102
1103struct TransposeLoadOpLowering
1104 : public ConvertOpToLLVMPattern<TransposeLoadOp> {
1105 TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1106 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1107
1108 Chipset chipset;
1109
1110 LogicalResult
1111 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1112 ConversionPatternRewriter &rewriter) const override {
1113 if (chipset != kGfx950)
1114 return op.emitOpError(message: "Non-gfx950 chipset not supported");
1115
1116 Location loc = op.getLoc();
1117 auto srcMemRefType = cast<MemRefType>(Val: op.getSrc().getType());
1118
    // Elements in sub-byte memrefs are stored non-contiguously, so reject
    // sub-byte source memrefs; use emulated memrefs instead.
    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << srcElementSize;
1127
1128 auto resultType = cast<VectorType>(Val: op.getResult().getType());
1129 Value srcPtr =
1130 getStridedElementPtr(rewriter, loc, type: srcMemRefType, memRefDesc: adaptor.getSrc(),
1131 indices: (adaptor.getSrcIndices()));
1132
1133 size_t numElements = resultType.getNumElements();
1134 size_t elementTypeSize =
1135 resultType.getElementType().getIntOrFloatBitWidth();
1136
    // ROCDL transpose load intrinsics return vectors of 32-bit integers when
    // the element size is smaller than 16 bits.
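    // For example, a 16-element vector of 4-bit values yields a vector<2xi32>
    // from ds_read_tr4_b64, which is then bitcast to the converted result
    // type.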
    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);
1142
1143 switch (elementTypeSize) {
1144 case 4: {
1145 assert(numElements == 16);
1146 auto rocdlOp =
1147 rewriter.create<ROCDL::ds_read_tr4_b64>(location: loc, args&: rocdlResultType, args&: srcPtr);
1148 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, args&: llvmResultType, args&: rocdlOp);
1149 break;
1150 }
1151 case 6: {
1152 assert(numElements == 16);
1153 auto rocdlOp =
1154 rewriter.create<ROCDL::ds_read_tr6_b96>(location: loc, args&: rocdlResultType, args&: srcPtr);
1155 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, args&: llvmResultType, args&: rocdlOp);
1156 break;
1157 }
1158 case 8: {
1159 assert(numElements == 8);
1160 auto rocdlOp =
1161 rewriter.create<ROCDL::ds_read_tr8_b64>(location: loc, args&: rocdlResultType, args&: srcPtr);
1162 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, args&: llvmResultType, args&: rocdlOp);
1163 break;
1164 }
1165 case 16: {
1166 assert(numElements == 4);
1167 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, args&: llvmResultType,
1168 args&: srcPtr);
1169 break;
1170 }
1171 default:
1172 return op.emitOpError(message: "Unsupported element size for transpose load");
1173 }
1174 return success();
1175 }
1176};
1177
1178struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
1179 GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1180 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1181
1182 Chipset chipset;
1183
1184 LogicalResult
1185 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1186 ConversionPatternRewriter &rewriter) const override {
1187 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1188 return op.emitOpError(message: "pre-gfx9 and post-gfx10 not supported");
1189
1190 Location loc = op.getLoc();
1191
1192 auto srcMemRefType = cast<MemRefType>(Val: op.getSrc().getType());
1193 auto dstMemRefType = cast<MemRefType>(Val: op.getDst().getType());
1194
    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
1198 Type transferType = op.getTransferType();
1199 int loadWidth = [&]() -> int {
1200 if (auto transferVectorType = dyn_cast<VectorType>(Val&: transferType)) {
1201 return (transferVectorType.getNumElements() *
1202 transferVectorType.getElementTypeBitWidth()) /
1203 8;
1204 }
1205 return transferType.getIntOrFloatBitWidth() / 8;
1206 }();
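    // For example, a transfer type of vector<4xf32> gives a 16-byte load,
    // while a scalar f16 transfer gives a 2-byte load.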
1207
1208 // Currently only 1, 2, 4, 12 and 16 byte loads are supported.
1209 if (!llvm::is_contained(Set: {1, 2, 4, 12, 16}, Element: loadWidth))
1210 return op.emitOpError(message: "chipset unsupported element size");
1211
1212 if (chipset != kGfx950 && llvm::is_contained(Set: {12, 16}, Element: loadWidth))
1213 return op.emitOpError(message: "Gather to LDS instructions with 12-byte and "
1214 "16-byte load widths are only supported on gfx950");
1215
1216 Value srcPtr =
1217 getStridedElementPtr(rewriter, loc, type: srcMemRefType, memRefDesc: adaptor.getSrc(),
1218 indices: (adaptor.getSrcIndices()));
1219 Value dstPtr =
1220 getStridedElementPtr(rewriter, loc, type: dstMemRefType, memRefDesc: adaptor.getDst(),
1221 indices: (adaptor.getDstIndices()));
1222
1223 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1224 op, args&: srcPtr, args&: dstPtr, args: rewriter.getI32IntegerAttr(value: loadWidth),
1225 /*offset=*/args: rewriter.getI32IntegerAttr(value: 0),
1226 /*aux=*/args: rewriter.getI32IntegerAttr(value: 0), args: ArrayAttr{}, args: ArrayAttr{},
1227 args: ArrayAttr{});
1228
1229 return success();
1230 }
1231};
1232
1233namespace {
1234struct ExtPackedFp8OpLowering final
1235 : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
1236 ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1237 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1238 chipset(chipset) {}
1239 Chipset chipset;
1240
1241 LogicalResult
1242 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1243 ConversionPatternRewriter &rewriter) const override;
1244};
1245
1246struct PackedTrunc2xFp8OpLowering final
1247 : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1248 PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
1249 Chipset chipset)
1250 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1251 chipset(chipset) {}
1252 Chipset chipset;
1253
1254 LogicalResult
1255 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1256 ConversionPatternRewriter &rewriter) const override;
1257};
1258
1259struct PackedStochRoundFp8OpLowering final
1260 : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1261 PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
1262 Chipset chipset)
1263 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1264 chipset(chipset) {}
1265 Chipset chipset;
1266
1267 LogicalResult
1268 matchAndRewrite(PackedStochRoundFp8Op op,
1269 PackedStochRoundFp8OpAdaptor adaptor,
1270 ConversionPatternRewriter &rewriter) const override;
1271};
1272
1273struct ScaledExtPackedOpLowering final
1274 : public ConvertOpToLLVMPattern<ScaledExtPackedOp> {
1275 ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
1276 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1277 chipset(chipset) {}
1278 Chipset chipset;
1279
1280 LogicalResult
1281 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1282 ConversionPatternRewriter &rewriter) const override;
1283};
1284
1285struct PackedScaledTruncOpLowering final
1286 : public ConvertOpToLLVMPattern<PackedScaledTruncOp> {
1287 PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
1288 Chipset chipset)
1289 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1290 chipset(chipset) {}
1291 Chipset chipset;
1292
1293 LogicalResult
1294 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1295 ConversionPatternRewriter &rewriter) const override;
1296};
1297
1298} // namespace
1299
1300LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1301 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1302 ConversionPatternRewriter &rewriter) const {
1303 Location loc = op.getLoc();
1304 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1305 return rewriter.notifyMatchFailure(
1306 arg&: loc, msg: "Fp8 conversion instructions are not available on target "
1307 "architecture and their emulation is not implemented");
1308 Type v4i8 =
1309 getTypeConverter()->convertType(t: VectorType::get(shape: 4, elementType: rewriter.getI8Type()));
1310 Type i32 = getTypeConverter()->convertType(t: rewriter.getI32Type());
1311 Type f32 = getTypeConverter()->convertType(t: op.getResult().getType());
1312
1313 Value source = adaptor.getSource();
1314 auto sourceVecType = dyn_cast<VectorType>(Val: op.getSource().getType());
1315 auto resultVecType = dyn_cast<VectorType>(Val: op.getResult().getType());
1316 Type sourceElemType = getElementTypeOrSelf(val: op.getSource());
1317 // Widen the source to a vector<4xi8>.
1318 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1319 Value longVec = rewriter.create<LLVM::UndefOp>(location: loc, args&: v4i8);
1320 if (!sourceVecType) {
1321 longVec = rewriter.create<LLVM::InsertElementOp>(
1322 location: loc, args&: longVec, args&: source, args: createI32Constant(rewriter, loc, value: 0));
1323 } else {
1324 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1325 Value idx = createI32Constant(rewriter, loc, value: i);
1326 Value elem = rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: source, args&: idx);
1327 longVec =
1328 rewriter.create<LLVM::InsertElementOp>(location: loc, args&: longVec, args&: elem, args&: idx);
1329 }
1330 }
1331 source = longVec;
1332 }
1333 Value i32Source = rewriter.create<LLVM::BitcastOp>(location: loc, args&: i32, args&: source);
1334 if (resultVecType) {
1335 if (typeIsExpectedBf8ForChipset(chipset, type: sourceElemType)) {
1336 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, args&: f32, args&: i32Source,
1337 args: op.getIndex());
1338 } else if (typeIsExpectedFp8ForChipset(chipset, type: sourceElemType)) {
1339 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, args&: f32, args&: i32Source,
1340 args: op.getIndex());
1341 }
1342 } else {
1343 if (typeIsExpectedBf8ForChipset(chipset, type: sourceElemType)) {
1344 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, args&: f32, args&: i32Source,
1345 args: op.getIndex());
1346 } else if (typeIsExpectedFp8ForChipset(chipset, type: sourceElemType)) {
1347 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, args&: f32, args&: i32Source,
1348 args: op.getIndex());
1349 }
1350 }
1351 return success();
1352}
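// Sketch of the lowering above (for exposition only): a scalar or short-vector
// fp8/bf8 source is first widened to a vector<4xi8>, bitcast to i32, and then handed
// to the packed (CvtPkF32*) or scalar (CvtF32*) ROCDL conversion op, with
// op.getIndex() selecting which byte (or byte pair) of the i32 is converted.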
1353
1354LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1355 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1356 ConversionPatternRewriter &rewriter) const {
1357 Location loc = op.getLoc();
1358 if (chipset != kGfx950)
1359 return rewriter.notifyMatchFailure(
1360 arg&: loc, msg: "Scaled fp conversion instructions are not available on target "
1361 "architecture and their emulation is not implemented");
1362 Type i32 = getTypeConverter()->convertType(t: rewriter.getI32Type());
1363
1364 Value source = adaptor.getSource();
1365 Value scale = adaptor.getScale();
1366
1367 VectorType sourceVecType = cast<VectorType>(Val: op.getSource().getType());
1368 Type sourceElemType = sourceVecType.getElementType();
1369 VectorType destVecType = cast<VectorType>(Val: op.getResult().getType());
1370 Type destElemType = destVecType.getElementType();
1371
1372 VectorType packedVecType;
1373 if (isa<Float8E5M2Type, Float8E4M3FNType>(Val: sourceElemType)) {
1374 VectorType v4i8 = VectorType::get(shape: 4, elementType: rewriter.getI8Type());
1375 packedVecType = cast<VectorType>(Val: getTypeConverter()->convertType(t: v4i8));
1376 } else if (isa<Float4E2M1FNType>(Val: sourceElemType)) {
1377 VectorType v8i4 = VectorType::get(shape: 8, elementType: rewriter.getI4Type());
1378 packedVecType = cast<VectorType>(Val: getTypeConverter()->convertType(t: v8i4));
1379 } else {
1380 llvm_unreachable("invalid element type for scaled ext");
1381 }
1382
1383 // Widen the source to the packed vector type.
1384 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1385 Value longVec = rewriter.create<LLVM::ZeroOp>(location: loc, args&: packedVecType);
1386 if (!sourceVecType) {
1387 longVec = rewriter.create<LLVM::InsertElementOp>(
1388 location: loc, args&: longVec, args&: source, args: createI32Constant(rewriter, loc, value: 0));
1389 } else {
1390 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1391 Value idx = createI32Constant(rewriter, loc, value: i);
1392 Value elem = rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: source, args&: idx);
1393 longVec =
1394 rewriter.create<LLVM::InsertElementOp>(location: loc, args&: longVec, args&: elem, args&: idx);
1395 }
1396 }
1397 source = longVec;
1398 }
1399 Value i32Source = rewriter.create<LLVM::BitcastOp>(location: loc, args&: i32, args&: source);
1400
1401 if (isa<Float8E5M2Type>(Val: sourceElemType) && destElemType.isF32())
1402 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1403 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1404 else if (isa<Float8E5M2Type>(Val: sourceElemType) && destElemType.isF16())
1405 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1406 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1407 else if (isa<Float8E5M2Type>(Val: sourceElemType) && destElemType.isBF16())
1408 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1409 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1410 else if (isa<Float8E4M3FNType>(Val: sourceElemType) && destElemType.isF32())
1411 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1412 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1413 else if (isa<Float8E4M3FNType>(Val: sourceElemType) && destElemType.isF16())
1414 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1415 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1416 else if (isa<Float8E4M3FNType>(Val: sourceElemType) && destElemType.isBF16())
1417 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1418 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1419 else if (isa<Float4E2M1FNType>(Val: sourceElemType) && destElemType.isF32())
1420 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1421 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1422 else if (isa<Float4E2M1FNType>(Val: sourceElemType) && destElemType.isF16())
1423 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1424 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1425 else if (isa<Float4E2M1FNType>(Val: sourceElemType) && destElemType.isBF16())
1426 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1427 op, args&: destVecType, args&: i32Source, args&: scale, args: op.getIndex());
1428 else
1429 return failure();
1430
1431 return success();
1432}
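// The chain of comparisons above is effectively a 3x3 dispatch table: the source
// element type (fp8 e5m2, fp8 e4m3fn, or fp4 e2m1fn) and the destination element
// type (f32, f16, or bf16) jointly pick one of nine CvtScaleF32Pk* ROCDL ops; any
// other combination is rejected with failure().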
1433
1434LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1435 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1436 ConversionPatternRewriter &rewriter) const {
1437 Location loc = op.getLoc();
1438 if (chipset != kGfx950)
1439 return rewriter.notifyMatchFailure(
1440 arg&: loc, msg: "Scaled fp conversion instructions are not available on target "
1441 "architecture and their emulation is not implemented");
1442 Type v2i16 = getTypeConverter()->convertType(
1443 t: VectorType::get(shape: 2, elementType: rewriter.getI16Type()));
1444 Type i32 = getTypeConverter()->convertType(t: rewriter.getI32Type());
1445
1446 Type resultType = op.getResult().getType();
1447 Type resultElemType = getElementTypeOrSelf(type: resultType);
1448 VectorType sourceVecType = cast<VectorType>(Val: op.getSource().getType());
1449 Type sourceElemType = sourceVecType.getElementType();
1450
1451 Type intResultType = isa<Float4E2M1FNType>(Val: resultElemType) ? i32 : v2i16;
1452
1453 Value source = adaptor.getSource();
1454 Value scale = adaptor.getScale();
1455 Value existing = adaptor.getExisting();
1456 if (existing)
1457 existing = rewriter.create<LLVM::BitcastOp>(location: loc, args&: intResultType, args&: existing);
1458 else
1459 existing = rewriter.create<LLVM::ZeroOp>(location: loc, args&: intResultType);
1460
1461 if (sourceVecType.getNumElements() < 2) {
1462 Value c0 = createI32Constant(rewriter, loc, value: 0);
1463 Value elem0 = rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: source, args&: c0);
1464 VectorType v2 = VectorType::get(shape: 2, elementType: sourceElemType);
1465 source = rewriter.create<LLVM::ZeroOp>(location: loc, args&: v2);
1466 source = rewriter.create<LLVM::InsertElementOp>(location: loc, args&: source, args&: elem0, args&: c0);
1467 }
1468
1469 Value sourceA, sourceB;
1470 if (sourceElemType.isF32()) {
1471 Value c0 = createI32Constant(rewriter, loc, value: 0);
1472 Value c1 = createI32Constant(rewriter, loc, value: 1);
1473 sourceA = rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: source, args&: c0);
1474 sourceB = rewriter.create<LLVM::ExtractElementOp>(location: loc, args&: source, args&: c1);
1475 }
1476
1477 Value result;
1478 if (sourceElemType.isF32() && isa<Float8E5M2Type>(Val: resultElemType))
1479 result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
1480 location: loc, args&: intResultType, args&: existing, args&: sourceA, args&: sourceB, args&: scale, args: op.getIndex());
1481 else if (sourceElemType.isF16() && isa<Float8E5M2Type>(Val: resultElemType))
1482 result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
1483 location: loc, args&: intResultType, args&: existing, args&: source, args&: scale, args: op.getIndex());
1484 else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(Val: resultElemType))
1485 result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
1486 location: loc, args&: intResultType, args&: existing, args&: source, args&: scale, args: op.getIndex());
1487 else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(Val: resultElemType))
1488 result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
1489 location: loc, args&: intResultType, args&: existing, args&: sourceA, args&: sourceB, args&: scale, args: op.getIndex());
1490 else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(Val: resultElemType))
1491 result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
1492 location: loc, args&: intResultType, args&: existing, args&: source, args&: scale, args: op.getIndex());
1493 else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(Val: resultElemType))
1494 result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
1495 location: loc, args&: intResultType, args&: existing, args&: source, args&: scale, args: op.getIndex());
1496 else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(Val: resultElemType))
1497 result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
1498 location: loc, args&: intResultType, args&: existing, args&: sourceA, args&: sourceB, args&: scale, args: op.getIndex());
1499 else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(Val: resultElemType))
1500 result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
1501 location: loc, args&: intResultType, args&: existing, args&: source, args&: scale, args: op.getIndex());
1502 else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(Val: resultElemType))
1503 result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
1504 location: loc, args&: intResultType, args&: existing, args&: source, args&: scale, args: op.getIndex());
1505 else
1506 return failure();
1507
1508 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1509 op, args: getTypeConverter()->convertType(t: resultType), args&: result);
1510 return success();
1511}
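// Note on the lowering above: f32 sources are split into two scalar operands
// (sourceA/sourceB) because the CvtScaleF32Pk*F32 ops take two separate floats,
// whereas f16/bf16 sources are passed as a packed vector. The result is produced in
// an i32 (fp4) or vector<2xi16> (fp8) register and bitcast back to the op's result
// type.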
1512
1513LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1514 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1515 ConversionPatternRewriter &rewriter) const {
1516 Location loc = op.getLoc();
1517 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1518 return rewriter.notifyMatchFailure(
1519 arg&: loc, msg: "Fp8 conversion instructions are not available on target "
1520 "architecture and their emulation is not implemented");
1521 Type i32 = getTypeConverter()->convertType(t: rewriter.getI32Type());
1522
1523 Type resultType = op.getResult().getType();
1524 Type resultElemType = getElementTypeOrSelf(type: resultType);
1525
1526 Value sourceA = adaptor.getSourceA();
1527 Value sourceB = adaptor.getSourceB();
1528 if (!sourceB)
1529 sourceB = rewriter.create<LLVM::UndefOp>(location: loc, args: sourceA.getType());
1530 Value existing = adaptor.getExisting();
1531 if (existing)
1532 existing = rewriter.create<LLVM::BitcastOp>(location: loc, args&: i32, args&: existing);
1533 else
1534 existing = rewriter.create<LLVM::UndefOp>(location: loc, args&: i32);
1535
1536 Value result;
1537 if (typeIsExpectedBf8ForChipset(chipset, type: resultElemType))
1538 result = rewriter.create<ROCDL::CvtPkBf8F32Op>(location: loc, args&: i32, args&: sourceA, args&: sourceB,
1539 args&: existing, args: op.getWordIndex());
1540 else if (typeIsExpectedFp8ForChipset(chipset, type: resultElemType))
1541 result = rewriter.create<ROCDL::CvtPkFp8F32Op>(location: loc, args&: i32, args&: sourceA, args&: sourceB,
1542 args&: existing, args: op.getWordIndex());
1543
1544 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1545 op, args: getTypeConverter()->convertType(t: resultType), args&: result);
1546 return success();
1547}
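// Illustrative behavior of the op above: the two f32 sources are truncated to a pair
// of fp8/bf8 bytes written into the 16-bit half of the i32 result selected by
// getWordIndex(), while the other half is taken from `existing` (or left undefined
// when no `existing` value was provided).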
1548
1549LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1550 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1551 ConversionPatternRewriter &rewriter) const {
1552 Location loc = op.getLoc();
1553 if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
1554 return rewriter.notifyMatchFailure(
1555 arg&: loc, msg: "Fp8 conversion instructions are not available on target "
1556 "architecture and their emulation is not implemented");
1557 Type i32 = getTypeConverter()->convertType(t: rewriter.getI32Type());
1558
1559 Type resultType = op.getResult().getType();
1560 Type resultElemType = getElementTypeOrSelf(type: resultType);
1561
1562 Value source = adaptor.getSource();
1563 Value stoch = adaptor.getStochiasticParam();
1564 Value existing = adaptor.getExisting();
1565 if (existing)
1566 existing = rewriter.create<LLVM::BitcastOp>(location: loc, args&: i32, args&: existing);
1567 else
1568 existing = rewriter.create<LLVM::UndefOp>(location: loc, args&: i32);
1569
1570 Value result;
1571 if (typeIsExpectedBf8ForChipset(chipset, type: resultElemType))
1572 result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
1573 location: loc, args&: i32, args&: source, args&: stoch, args&: existing, args: op.getStoreIndex());
1574 else if (typeIsExpectedFp8ForChipset(chipset, type: resultElemType))
1575 result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
1576 location: loc, args&: i32, args&: source, args&: stoch, args&: existing, args: op.getStoreIndex());
1577
1578 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1579 op, args: getTypeConverter()->convertType(t: resultType), args&: result);
1580 return success();
1581}
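// Similar to the packed truncation above, but the single f32 source is rounded
// stochastically using the extra `stoch` operand, and getStoreIndex() selects which
// byte of the i32 result receives the converted value.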
1582
1583// Lowers the amdgpu.dpp operation to the corresponding ROCDL DPP update
1584// instruction.
1585struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
1586 AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
1587 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1588 Chipset chipset;
1589
1590 LogicalResult
1591 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1592 ConversionPatternRewriter &rewriter) const override {
1593
1594 // Convert the source operand to the corresponding LLVM type
1595 Location loc = DppOp.getLoc();
1596 Value src = adaptor.getSrc();
1597 Value old = adaptor.getOld();
1598 Type srcType = src.getType();
1599 Type oldType = old.getType();
1600 Type llvmType = nullptr;
1601 if (srcType.getIntOrFloatBitWidth() < 32) {
1602 llvmType = rewriter.getI32Type();
1603 } else if (isa<FloatType>(Val: srcType)) {
1604 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1605 ? rewriter.getF32Type()
1606 : rewriter.getF64Type();
1607 } else if (isa<IntegerType>(Val: srcType)) {
1608 llvmType = (srcType.getIntOrFloatBitWidth() == 32)
1609 ? rewriter.getI32Type()
1610 : rewriter.getI64Type();
1611 }
1612 auto llvmSrcIntType = typeConverter->convertType(
1613 t: rewriter.getIntegerType(width: srcType.getIntOrFloatBitWidth()));
1614
1615 // Operands narrower than 32 bits are packed into a 32-bit value via an insert into a vector followed by a bitcast.
1616 auto convertOperand = [&](Value operand, Type operandType) {
1617 if (operandType.getIntOrFloatBitWidth() <= 16) {
1618 if (llvm::isa<FloatType>(Val: operandType)) {
1619 operand =
1620 rewriter.create<LLVM::BitcastOp>(location: loc, args&: llvmSrcIntType, args&: operand);
1621 }
1622 auto llvmVecType = typeConverter->convertType(t: mlir::VectorType::get(
1623 shape: 32 / operandType.getIntOrFloatBitWidth(), elementType: llvmSrcIntType));
1624 Value undefVec = rewriter.create<LLVM::UndefOp>(location: loc, args&: llvmVecType);
1625 operand = rewriter.create<LLVM::InsertElementOp>(
1626 location: loc, args&: undefVec, args&: operand, args: createI32Constant(rewriter, loc, value: 0));
1627 operand = rewriter.create<LLVM::BitcastOp>(location: loc, args&: llvmType, args&: operand);
1628 }
1629 return operand;
1630 };
1631
1632 src = convertOperand(src, srcType);
1633 old = convertOperand(old, oldType);
1634
1635 // DPP control values, taken from llvm/lib/Target/AMDGPU/SIDefines.h.
1636 enum DppCtrl : unsigned {
1637 ROW_SHL0 = 0x100,
1638 ROW_SHR0 = 0x110,
1639 ROW_ROR0 = 0x120,
1640 WAVE_SHL1 = 0x130,
1641 WAVE_ROL1 = 0x134,
1642 WAVE_SHR1 = 0x138,
1643 WAVE_ROR1 = 0x13C,
1644 ROW_MIRROR = 0x140,
1645 ROW_HALF_MIRROR = 0x141,
1646 BCAST15 = 0x142,
1647 BCAST31 = 0x143,
1648 };
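// Worked example of the encoding below (for exposition): quad_perm [1, 0, 3, 2]
// packs four 2-bit lane selectors into bits [7:0], giving DppCtrl = 0xB1, while
// row_shl with amount N encodes as ROW_SHL0 + N (e.g. row_shl 4 -> 0x104).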
1649
1650 auto kind = DppOp.getKind();
1651 auto permArgument = DppOp.getPermArgument();
1652 uint32_t DppCtrl = 0;
1653
1654 switch (kind) {
1655
1656 case DPPPerm::quad_perm:
1657 if (auto quadPermAttr = cast<ArrayAttr>(Val&: *permArgument)) {
1658 int32_t i = 0;
1659 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1660 uint32_t num = elem.getInt();
1661 DppCtrl |= num << (i * 2);
1662 i++;
1663 }
1664 }
1665 break;
1666 case DPPPerm::row_shl:
1667 if (auto intAttr = cast<IntegerAttr>(Val&: *permArgument)) {
1668 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1669 }
1670 break;
1671 case DPPPerm::row_shr:
1672 if (auto intAttr = cast<IntegerAttr>(Val&: *permArgument)) {
1673 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1674 }
1675 break;
1676 case DPPPerm::row_ror:
1677 if (auto intAttr = cast<IntegerAttr>(Val&: *permArgument)) {
1678 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1679 }
1680 break;
1681 case DPPPerm::wave_shl:
1682 DppCtrl = DppCtrl::WAVE_SHL1;
1683 break;
1684 case DPPPerm::wave_shr:
1685 DppCtrl = DppCtrl::WAVE_SHR1;
1686 break;
1687 case DPPPerm::wave_rol:
1688 DppCtrl = DppCtrl::WAVE_ROL1;
1689 break;
1690 case DPPPerm::wave_ror:
1691 DppCtrl = DppCtrl::WAVE_ROR1;
1692 break;
1693 case DPPPerm::row_mirror:
1694 DppCtrl = DppCtrl::ROW_MIRROR;
1695 break;
1696 case DPPPerm::row_half_mirror:
1697 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1698 break;
1699 case DPPPerm::row_bcast_15:
1700 DppCtrl = DppCtrl::BCAST15;
1701 break;
1702 case DPPPerm::row_bcast_31:
1703 DppCtrl = DppCtrl::BCAST31;
1704 break;
1705 }
1706
1707 // Read the row_mask, bank_mask, and bound_ctrl attributes; they gate which
1708 // rows/banks are written and how disabled lanes are handled.
1709 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(name: "row_mask").getInt();
1710 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(name: "bank_mask").getInt();
1711 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(name: "bound_ctrl").getValue();
1712
1713 // Create a ROCDL DPPUpdateOp with the computed control value and masks.
1714 auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
1715 location: loc, args&: llvmType, args&: old, args&: src, args&: DppCtrl, args&: rowMask, args&: bankMask, args&: boundCtrl);
1716
1717 Value result = dppMovOp.getRes();
1718 if (srcType.getIntOrFloatBitWidth() < 32) {
1719 result = rewriter.create<LLVM::TruncOp>(location: loc, args&: llvmSrcIntType, args&: result);
1720 if (!llvm::isa<IntegerType>(Val: srcType)) {
1721 result = rewriter.create<LLVM::BitcastOp>(location: loc, args&: srcType, args&: result);
1722 }
1723 }
1724
1725 // Replace the original amdgpu.dpp op with the (possibly truncated and
1726 // bitcast) result of the DPP update.
1727 rewriter.replaceOp(op: DppOp, newValues: ValueRange(result));
1728 return success();
1729 }
1730};
1731
1732struct AMDGPUSwizzleBitModeLowering
1733 : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1734 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1735
1736 LogicalResult
1737 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1738 ConversionPatternRewriter &rewriter) const override {
1739 Location loc = op.getLoc();
1740 Type i32 = rewriter.getI32Type();
1741 Value src = adaptor.getSrc();
1742 SmallVector<Value> decomposed =
1743 LLVM::decomposeValue(builder&: rewriter, loc, src, dstType: i32);
1744 unsigned andMask = op.getAndMask();
1745 unsigned orMask = op.getOrMask();
1746 unsigned xorMask = op.getXorMask();
1747
1748 // bit 15 is 0 for the BitMode swizzle.
1749 // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
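// For example, andMask = 0x1f, orMask = 0, xorMask = 1 yields mask = 0x41F, which
// swizzles data between adjacent even/odd lanes (lane ^ 1) within each group of 32.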
1750 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
1751 Value maskValue = createI32Constant(rewriter, loc, value: mask);
1752 SmallVector<Value> swizzled;
1753 for (Value v : decomposed) {
1754 Value res =
1755 rewriter.create<ROCDL::DsSwizzleOp>(location: loc, args: v.getType(), args&: v, args&: maskValue);
1756 swizzled.emplace_back(Args&: res);
1757 }
1758
1759 Value result = LLVM::composeValue(builder&: rewriter, loc, src: swizzled, dstType: src.getType());
1760 rewriter.replaceOp(op, newValues: result);
1761 return success();
1762 }
1763};
1764
1765struct ConvertAMDGPUToROCDLPass
1766 : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1767 using Base::Base;
1768
1769 void runOnOperation() override {
1770 MLIRContext *ctx = &getContext();
1771 FailureOr<Chipset> maybeChipset = Chipset::parse(name: chipset);
1772 if (failed(Result: maybeChipset)) {
1773 emitError(loc: UnknownLoc::get(context: ctx), message: "invalid chipset name: " + chipset);
1774 return signalPassFailure();
1775 }
1776
1777 RewritePatternSet patterns(ctx);
1778 LLVMTypeConverter converter(ctx);
1779 populateAMDGPUToROCDLConversionPatterns(converter, patterns, chipset: *maybeChipset);
1780 LLVMConversionTarget target(getContext());
1781 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1782 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1783 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1784 if (failed(Result: applyPartialConversion(op: getOperation(), target,
1785 patterns: std::move(patterns))))
1786 signalPassFailure();
1787 }
1788};
1789} // namespace
1790
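// The integer attributes produced below are the AMDGPU backend's LLVM address-space
// numbers: 7 (buffer fat pointer), 8 (buffer resource), and 9 (buffer strided
// pointer).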
1791void mlir::populateAMDGPUMemorySpaceAttributeConversions(
1792 TypeConverter &typeConverter) {
1793 typeConverter.addTypeAttributeConversion(
1794 callback: [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
1795 -> TypeConverter::AttributeConversionResult {
1796 MLIRContext *ctx = as.getContext();
1797 Type i64 = IntegerType::get(context: ctx, width: 64);
1798 switch (as.getValue()) {
1799 case amdgpu::AddressSpace::FatRawBuffer:
1800 return IntegerAttr::get(type: i64, value: 7);
1801 case amdgpu::AddressSpace::BufferRsrc:
1802 return IntegerAttr::get(type: i64, value: 8);
1803 case amdgpu::AddressSpace::FatStructuredBuffer:
1804 return IntegerAttr::get(type: i64, value: 9);
1805 }
1806 return TypeConverter::AttributeConversionResult::abort();
1807 });
1808}
1809
1810void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
1811 RewritePatternSet &patterns,
1812 Chipset chipset) {
1813 populateAMDGPUMemorySpaceAttributeConversions(typeConverter&: converter);
1814 patterns
1815 .add<FatRawBufferCastLowering,
1816 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1817 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1818 RawBufferOpLowering<RawBufferAtomicFaddOp,
1819 ROCDL::RawPtrBufferAtomicFaddOp>,
1820 RawBufferOpLowering<RawBufferAtomicFmaxOp,
1821 ROCDL::RawPtrBufferAtomicFmaxOp>,
1822 RawBufferOpLowering<RawBufferAtomicSmaxOp,
1823 ROCDL::RawPtrBufferAtomicSmaxOp>,
1824 RawBufferOpLowering<RawBufferAtomicUminOp,
1825 ROCDL::RawPtrBufferAtomicUminOp>,
1826 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1827 ROCDL::RawPtrBufferAtomicCmpSwap>,
1828 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1829 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1830 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1831 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1832 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
1833 TransposeLoadOpLowering>(arg&: converter, args&: chipset);
1834 patterns.add<AMDGPUSwizzleBitModeLowering>(arg&: converter);
1835}
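// Typical driver usage (illustrative; the exact option spelling comes from the
// generated pass options):
//   mlir-opt --convert-amdgpu-to-rocdl='chipset=gfx942' input.mlir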
1836
