//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
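  // Accumulate sum_d(index_d * stride_d) as an i32, skipping unit strides.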
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the byte offset one element past the greatest
/// possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
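    // A rank-0 memref holds exactly one element; otherwise the number of
    // addressable elements is the max over all dimensions of size * stride.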
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    assert(size < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
  }
  Value maxIndex;
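  // For dynamically shaped memrefs, compute max_d(size_d * stride_d) at
  // runtime instead.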
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
    maxIndex = maxIndex
                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
}

static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
                                         /*isDisjoint=*/true);
  } else {
    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
                                               rewriter.getI16IntegerAttr(0));
  }
  // Get the number of elements.
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: num format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  //  none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}

namespace {
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? rewriter.create<LLVM::ConstantOp>(
                             loc, getIndexType(), rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
                                 loc, descriptor, kSizePosInMemRefDescriptor)
                           : Value{};
    Value strides = hasSizes
                        ? rewriter.create<LLVM::ExtractValueOp>(
                              loc, descriptor, kStridePosInMemRefDescriptor)
                        : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
                                                  kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
                                                    kSizePosInMemRefDescriptor);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, result, strides, kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them as data
    // would crash, so only look for a compare operand when there is store
    // data.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors; cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
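    // The buffer intrinsics take voffset in bytes, so scale the linear element
    // index by the element byte width.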
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
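    // The SGPR offset is likewise expressed in bytes.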
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
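      // gfx12+: wait for outstanding LDS (DS) operations to complete, then use
      // the split barrier: signal barrier -1 and wait on it.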
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return rewriter.create<LLVM::BitcastOp>(
          loc, VectorType::get(numWords, rewriter.getI32Type()), input);
    }
  }
  return input;
}

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
  return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // If the element type is explicitly signed or unsigned, that overrides the
    // isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

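  // Repack the sub-word elements into the integer or <k x i32> form that the
  // intrinsics expect.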
  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
  operands.push_back(castInput);
}

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part. The subwordOffset must not be set for gfx12,
/// as the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
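      // gfx942+ carries the negate flags in the low bits of the blgp field.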
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    };
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    size_t loadWidth = [&]() -> size_t {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return transferVectorType.getNumElements() *
               (transferVectorType.getElementTypeBitWidth() / 8);
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, and 4 byte loads are supported.
    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
      return op.emitOpError("chipset unsupported element size");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/rewriter.getI32IntegerAttr(0),
        /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  if (resultVecType) {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
    }
  } else {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
    }
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, op.getWordIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
        loc, i32, source, stoch, existing, op.getStoreIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
        loc, i32, source, stoch, existing, op.getStoreIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
1302
// Lowers the amdgpu.dpp operation to the corresponding ROCDL DPP update
// operation.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
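    // `llvmType` is the 32- or 64-bit type the DPP intrinsic operates on;
    // sub-32-bit sources are widened to i32 below and narrowed back after the
    // DPP update.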
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the operand is narrower than 32 bits, widen it into an i32 by
    // inserting it into a vector and bitcasting.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // DPP control encodings, taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

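    // Encode the permutation into the 32-bit dpp_ctrl field. For quad_perm the
    // four 2-bit lane selectors are packed into bits [7:0]; row shifts/rotates
    // add the shift amount to the corresponding base encoding above.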
    switch (kind) {

    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Read the row_mask, bank_mask, and bound_ctrl attributes.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL::DPPUpdateOp (rocdl.update.dpp) with the computed control
    // word and masks.
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

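    // Narrow the 32-bit DPP result back to the original element type if the
    // source was widened above.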
    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // Replace the amdgpu.dpp op with the result of the rocdl.update.dpp op.
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
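    // decomposeValue breaks the source into i32 pieces so that each piece can
    // be fed to ds_swizzle, which operates on 32-bit values per lane.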
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();

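    // Pack the masks into the ds_swizzle offset field: and_mask in bits [4:0],
    // or_mask in bits [9:5], and xor_mask in bits [14:10].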
    // Bit 15 is 0 for the BitMode swizzle.
    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
    Value maskValue = createI32Constant(rewriter, loc, mask);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }

    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

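    // Set up the dialect conversion: every remaining amdgpu op is illegal and
    // must be rewritten into LLVM/ROCDL ops by the patterns registered below.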
    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUMemorySpaceAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
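        // Map the amdgpu address spaces onto the numeric AMDGPU LLVM address
        // spaces used for buffer accesses: 7 (buffer fat pointer), 8 (buffer
        // resource), and 9 (buffer strided pointer).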
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });
}

void mlir::populateAMDGPUToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  populateAMDGPUMemorySpaceAttributeConversions(converter);
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
                                                                 chipset);
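  // AMDGPUSwizzleBitModeLowering does not depend on the chipset, so it is
  // added with the plain ConvertOpToLLVMPattern constructor.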
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}