//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}
namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is its indices; trying to read them as store
    // data would crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);
    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32, use a vector load of (N * bitsize(T)) / 32 x i32
    // and bitcast. Also, the CAS intrinsic requires integer operands, so
    // bitcast any floats to integers. On top of all this, cast bfloat
    // (vectors) to i16 since the backend doesn't currently support bfloat on
    // these operations.
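    // For example, a load of vector<2xf16> (32 bits) lowers to an i32 load
    // plus a bitcast, while a load of vector<4xf16> (64 bits) lowers to a
    // vector<2xi32> load plus a bitcast.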
    Type llvmBufferValType = llvmWantedDataType;
    if (wantedDataType.isBF16())
      llvmBufferValType = rewriter.getI16Type();
    if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
      if (wantedVecType.getElementType().isBF16())
        llvmBufferValType = wantedVecType.clone(rewriter.getI16Type());
    if (atomicCmpData) {
      if (isa<VectorType>(wantedDataType))
        return gpuOp.emitOpError("vector compare-and-swap does not exist");
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * dataVector.getNumElements();
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI16, rewriter.getI16IntegerAttr(0));
    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
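      // For dynamically-shaped memrefs, bound the buffer by the largest
      // reachable byte offset: the maximum over all dimensions of
      // size[i] * stride[i] * elementByteWidth, truncated to the 32-bit
      // num_records field.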
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    // none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
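    // For example, this computes 0x27000 on CDNA (gfx9), and on RDNA
    // (gfx10+) with bounds checking enabled it computes 0x31027000
    // (bit 24 set and oob-select = 3 in bits 28-29).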
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset)
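    // The byte offset is sum_i(indices[i] * strides[i] * elementByteWidth),
    // accumulated below one index at a time.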
    Value voffset = createI32Constant(rewriter, loc, 0);
    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamic(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamic(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm =
        chipset.majorVersion < 9 ||
        (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) ||
        (chipset.majorVersion == 11);

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
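    // The masks below zero out only the lgkmcnt field of the s_waitcnt
    // immediate (bits 8-12 on gfx6-9, bits 8-13 on gfx10, bits 4-9 on gfx11),
    // so the barrier waits on outstanding LDS operations without also
    // stalling on the vector memory or export counters.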
    constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
    constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
    // Left in place in case someone disables the inline ASM path or future
    // chipsets use the same bit pattern.
    constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

    int32_t ldsOnlyBits;
    if (chipset.majorVersion == 11)
      ldsOnlyBits = ldsOnlyBitsGfx11;
    else if (chipset.majorVersion == 10)
      ldsOnlyBits = ldsOnlyBitsGfx10;
    else if (chipset.majorVersion <= 9)
      ldsOnlyBits = ldsOnlyBitsGfx6789;
    else
      return op.emitOpError(
                 "don't know how to lower this for chipset major version")
             << chipset.majorVersion;

    Location loc = op->getLoc();
    rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
    rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    return success();
  }
};
} // namespace

/// If `input` is a vector of bytes, concatenate those bytes in little-endian
/// order to form a single integer of size 8 * [vector length]. This works
/// around a wart in the AMDGPU intrinsics where operations that logically take
/// vectors of bytes instead take integers. Since we do not want to expose this
/// implementation detail to MLIR, we correct for it here.
///
/// In addition, convert vectors of LLVM bfloats to vectors of i16, since
/// AMDGPU MFMA intrinsics pre-date the bfloat type.
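///
/// For example, a vector<4xi8> value [b0, b1, b2, b3] is packed into the i32
/// b0 | (b1 << 8) | (b2 << 16) | (b3 << 24).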
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
                                Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);

    if (!vectorType.getElementType().isInteger(8))
      return input;
    int64_t numBytes = vectorType.getNumElements();
    Type destType = rewriter.getIntegerType(numBytes * 8);
    Value result = rewriter.create<LLVM::ConstantOp>(
        loc, destType, rewriter.getIntegerAttr(destType, 0));
    for (int64_t i = 0; i < numBytes; ++i) {
      Value idxConst = createI32Constant(rewriter, loc, i);
      Value element =
          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
    }
    return result;
  }
  return input;
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the
/// lack of bfloat support in the WMMA intrinsics themselves.
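///
/// For example, a vector<16xi8> input is bitcast to vector<4xi32> and pushed
/// as an i1 sign flag (1 = signed, 0 = unsigned) followed by the packed
/// vector.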
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);

  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);

  // If the element type is explicitly 8-bit signed or unsigned, ignore the
  // isUnsigned flag.
  bool localIsUnsigned = isUnsigned;
  if (elemType.isUnsignedInteger(8)) {
    localIsUnsigned = true;
  } else if (elemType.isSignedInteger(8)) {
    localIsUnsigned = false;
  }
  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
  operands.push_back(sign);
  operands.push_back(result);
}

/// Push the output operand. In most cases this simply pushes the output onto
/// the operand list. But for f16 -> f16 or bf16 -> bf16 intrinsics, since the
/// same number of VGPRs is used, we need to decide whether to store the
/// result in the upper or the lower 16 bits of each VGPR. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

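  // Note: on gfx9 chips, minorVersion >= 0x0a corresponds to gfx90a and newer
  // and minorVersion >= 0x40 to gfx940 and newer; these gate the MFMA
  // variants introduced on those chips.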
  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
      chipset.minorVersion >= 0x40) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset.minorVersion < 0x40)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 11)
      return op->emitOpError("WMMA only supported on gfx11");

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(outType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
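  // Bitcast the (possibly padded) v4i8 to an i32; the conversion intrinsic
  // picks the byte to convert out of that word via the wordSel operand.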
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (sourceElemType.isFloat8E5M2FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns,
                                            *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                   RewritePatternSet &patterns,
                                                   Chipset chipset) {
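  // The ROCDL intrinsics pre-date LLVM's bfloat type, so bf16 scalars and
  // vectors are modelled as i16 throughout this lowering.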
  converter.addConversion([](BFloat16Type t) -> Type {
    return IntegerType::get(t.getContext(), 16);
  });
  converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
    if (!t.getElementType().isBF16())
      return std::nullopt;
    return converter.convertType(
        t.clone(IntegerType::get(t.getContext(), 16)));
  });

  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering,
           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering>(converter, chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}
