//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "../LLVMCommon/MemRefDescriptor.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}

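/// Create an LLVM i32 constant holding `value`.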
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

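/// Create an LLVM i1 constant holding `value`.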
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

/// Compute the contents of the `num_records` field for a given memref
/// descriptor - that is, the number of bytes that's one element past the
/// greatest possible valid index into the memref.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    assert(size < std::numeric_limits<uint32_t>::max() &&
           "the memref buffer is too large");
    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
    maxIndex = maxIndex
                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
  Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
}

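/// Construct a buffer resource descriptor for `numRecords` bytes starting at
/// `basePointer`, choosing the stride (including the gfx942+ cache-swizzle
/// bit) and the flag word according to `chipset` and `boundsCheck`.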
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
                            Value basePointer, Value numRecords,
                            bool boundsCheck, amdgpu::Chipset chipset,
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  // The stride value is generally 0. However, on MI-300 and onward, you can
  // enable a cache swizzling mode by setting bit 14 of the stride field
  // and setting that stride to a cache stride.
  Type i16 = rewriter.getI16Type();
  Value stride;
  if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
    Value cacheStrideZext =
        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
                                         /*isDisjoint=*/true);
  } else {
    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
                                               rewriter.getI16IntegerAttr(0));
  }
  // Get the number of elements.
  // Flag word:
  // bits 0-11: dst sel, ignored by these intrinsics
  // bits 12-14: data format (ignored, must be nonzero, 7=float)
  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
  // bit 19: In nested heap (0 here)
  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
  // bits 21-22: Index stride for swizzles (N/A)
  // bit 23: Add thread ID (0)
  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
  // bits 25-26: Reserved (0)
  // bit 27: Buffer is non-volatile (CDNA only)
  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
  //  none, 3 = either swizzles or testing against offset field) RDNA only
  // bits 30-31: Type (must be 0)
  uint32_t flags = (7 << 12) | (4 << 15);
  if (chipset.majorVersion >= 10) {
    flags |= (1 << 24);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
  }
  Value flagsConst = createI32Constant(rewriter, loc, flags);
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
  return resource;
}

namespace {
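/// Lowers amdgpu.fat_raw_buffer_cast by building a buffer resource in the fat
/// pointer address space (7) and substituting it for both pointers of the
/// source memref descriptor.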
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? rewriter.create<LLVM::ConstantOp>(
                             loc, getIndexType(), rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    // No need to unpack() and pack() all the individual sizes and strides,
    // so we'll just extract the arrays.
    Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
                                 loc, descriptor, kSizePosInMemRefDescriptor)
                           : Value{};
    Value strides = hasSizes
                        ? rewriter.create<LLVM::ExtractValueOp>(
                              loc, descriptor, kStridePosInMemRefDescriptor)
                        : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(
        loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
    result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
                                                  kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
                                                    kSizePosInMemRefDescriptor);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, result, strides, kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors, cast them to
      // scalars.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

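/// Lowers amdgpu.lds_barrier either to an inline-assembly
/// "s_waitcnt lgkmcnt(0); s_barrier" sequence or to the equivalent ROCDL
/// wait/barrier ops, depending on what the target chipset supports.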
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

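/// Lowers amdgpu.sched_barrier directly to rocdl.sched.barrier, forwarding the
/// op's option mask.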
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts an MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
/// integer.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return rewriter.create<LLVM::BitcastOp>(
          loc, VectorType::get(numWords, rewriter.getI32Type()), input);
    }
  }
  return input;
}

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero extend it to i32.
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32.
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
                                  Location loc, Value input) {
  Type inputType = input.getType();
  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
  return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
/// We also need to convert bfloat inputs to i16 to account for the bfloat
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger()) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger()) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
  // Add in the zeros here.
  if (numBits < 32)
    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
  operands.push_back(castInput);
}

/// Push the output operand. In many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part. The subwordOffset must not be set for gfx12,
/// as the instructions have been changed to return fewer registers instead.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return true if `type` is the E5M2 variant of an 8-bit float that is
/// supported by the `_bf8` instructions on the given `chipset`.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
}

/// Return true if `type` is the E4M3FN variant of an 8-bit float that is
/// supported by the `_fp8` instructions on the given `chipset`.
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type) {
  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
}

/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = getElementTypeOrSelf(mfma.getSourceA().getType());
  Type destElem = getElementTypeOrSelf(mfma.getDestC().getType());

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx942) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 16 && b == 1)
        return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
      if (m == 16 && n == 16 && k == 32 && b == 1)
        return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
    }
    if (chipset >= kGfx90a) {
      if (m == 32 && n == 32 && k == 4 && b == 2)
        return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 4 && b == 4)
        return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
      if (m == 4 && n == 4 && k == 4 && b == 16)
        return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
      if (m == 32 && n == 32 && k == 8 && b == 1)
        return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
      if (m == 16 && n == 16 && k == 16 && b == 1)
        return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
    }
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (sourceElem.isInteger(8) && destElem.isInteger(32)) {
    if (chipset >= kGfx950) {
      if (m == 32 && n == 32 && k == 32 && b == 1)
        return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
      if (m == 16 && n == 16 && k == 64 && b == 1)
        return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (destElem.isF32() && typeIsExpectedBf8ForChipset(chipset, sourceElem)) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (destElem.isF32() && typeIsExpectedFp8ForChipset(chipset, sourceElem)) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

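/// Map an MLIR element type to the type-select code used by the scaled
/// (f8f6f4) MFMA intrinsics, or std::nullopt if the type has no such code.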
static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
  return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default([](Type) { return std::nullopt; });
}

/// If there is a scaled MFMA instruction for the input element types `aType`
/// and `bType`, output type `destType`, problem size M, N, K, and B (number of
/// blocks) on the given `chipset`, return a tuple consisting of the
/// OperationName of the intrinsic and the type codes that need to be passed to
/// that intrinsic. Note that this is also used to implement some un-scaled
/// MFMAs, since the compiler represents the ordinary instruction as a "scaled"
/// MFMA with a scale of 0.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  aType = getElementTypeOrSelf(aType);
  bType = getElementTypeOrSelf(bType);
  destType = getElementTypeOrSelf(destType);

  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemBSourceType = sourceBVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (chipset.majorVersion == 11) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
  }
  if (chipset.majorVersion >= 12) {
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<Float8E4M3FNType>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<Float8E5M2Type>(elemSourceType) &&
        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
      bool isWave64 = destVectorType.getNumElements() == 4;
      // This is the ambiguous case. 8 inputs to the wave64 version means that
      // we want the 16x16x32 version, but for wave32 they mean the short form.
      bool has8Inputs = sourceVectorType.getNumElements() == 8;
      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    }
  }
  return std::nullopt;
}

namespace {
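/// Lowers amdgpu.mfma to the ROCDL intrinsic selected by mfmaOpToIntrinsic or,
/// for shapes only expressible as scaled MFMAs, by mfmaOpToScaledIntrinsic.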
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    if (isScaled) {
      Value zero = createI32Constant(rewriter, loc, 0);
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
      loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
                             createI32Constant(rewriter, loc, bTypeCode),
                             /*scale A byte=*/zero, /*scale A=*/zero,
                             /*scale B byte=*/zero, /*scale B=*/zero});
    } else {
      loweredOp.addOperands({createI32Constant(rewriter, loc, op.getCbsz()),
                             createI32Constant(rewriter, loc, op.getAbid()),
                             createI32Constant(rewriter, loc, getBlgpField)});
    }
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

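/// Lowers amdgpu.scaled_mfma to the gfx950 scaled (f8f6f4) MFMA intrinsics,
/// passing the operand type codes, scales, and scale indices.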
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales idx A=*/scalesIdxA,
         /*scales A*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales idx B=*/scalesIdxB,
         /*scales B*/
         castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

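/// Lowers amdgpu.wmma to the ROCDL intrinsic selected by wmmaOpToIntrinsic,
/// packing the inputs and handling the i16 encoding of bf16 values.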
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

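/// Lowers amdgpu.gather_to_lds to rocdl.load.to.lds, transferring one element
/// (of at most 4 bytes) per thread from global memory into LDS.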
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    // TODO: instead of only transferring one element per thread, we could
    // augment it to transfer multiple elements per thread by issuing multiple
    // `global_load_lds` instructions.
    Type transferType = op.getTransferType();
    size_t loadWidth = [&]() -> size_t {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return transferVectorType.getNumElements() *
               (transferVectorType.getElementTypeBitWidth() / 8);
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    // Currently only 1, 2, and 4 byte loads are supported.
    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
      return op.emitOpError("unsupported element size for this chipset");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        /*offset=*/rewriter.getI32IntegerAttr(0),
        /*aux=*/rewriter.getI32IntegerAttr(0), ArrayAttr{}, ArrayAttr{},
        ArrayAttr{});

    return success();
  }
};

namespace {
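// Declarations of the packed 8-bit float conversion patterns; their
// matchAndRewrite implementations are defined out of line below.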
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8 so the value can be bitcast to an i32 below.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  if (resultVecType) {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        op.getIndex());
    }
  } else {
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      op.getIndex());
    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      op.getIndex());
    }
  }
  return success();
}

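// Note on the packed-truncation lowerings below: the packed ROCDL conversion
// intrinsics convert two f32 operands into two 8-bit floats and write them
// into one 16-bit half of a 32-bit word selected by the word index, leaving
// the other half of `existing` unchanged. This is a summary of the intended
// behaviour for readers of this file, not a normative ISA description.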
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, op.getWordIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
        loc, i32, source, stoch, existing, op.getStoreIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
        loc, i32, source, stoch, existing, op.getStoreIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

// Implements AMDGPUDPPLowering, which converts the amdgpu.dpp operation into
// the corresponding ROCDL DPP update intrinsic.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the operand type is narrower than 32 bits, widen it to an i32 value
    // before feeding it to the DPP intrinsic.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };
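
    // For illustration: an f16 operand is bitcast to i16, inserted into lane 0
    // of an otherwise-undef vector<2xi16>, and that vector is bitcast to i32,
    // since the DPP intrinsic operates on 32-bit values. (Worked example of
    // the lambda above; the exact integer type follows llvmSrcIntType.)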

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // These control values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };
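
    // For example, given the encodings above, `row_shl` with an immediate of 1
    // is emitted as ROW_SHL0 + 1 = 0x101, and a `quad_perm` of [1, 0, 3, 2]
    // packs each 2-bit lane selector into the low byte: 0b10110001 = 0xB1.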

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {

    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Read the row_mask, bank_mask, and bound_ctrl attributes and materialize
    // them as constants.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL DPP update (mov) op with the computed control value.
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // Replace the amdgpu.dpp op with the lowered ROCDL DPP update result.
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();

    // Bit 15 is 0 for the BitMode swizzle.
    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
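    // Worked example: andMask = 0x1F, orMask = 0, and xorMask = 1 give
    // mask = 0x1F | (1 << 10) = 0x41F, which swaps adjacent lanes
    // (new_lane = ((lane & 0x1F) | 0) ^ 1). Illustrative only; see the GCN
    // swizzle documentation linked above for the authoritative encoding.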
    Value maskValue = createI32Constant(rewriter, loc, mask);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }

    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};

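// Usage sketch (invocation details may vary between MLIR versions): this pass
// is normally driven from the command line, e.g.
//   mlir-opt input.mlir --convert-amdgpu-to-rocdl=chipset=gfx942
// where the `chipset` option must parse as an amdgcn target such as gfx90a.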
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns,
                                            *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

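// Illustrative mapping (assuming the AMDGPU backend's address-space
// numbering): a memref such as
//   memref<16xf32, #amdgpu.address_space<fat_raw_buffer>>
// has its memory space rewritten to the integer attribute 7, so it later
// lowers to an LLVM buffer fat pointer (`ptr addrspace(7)`).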
void mlir::populateAMDGPUMemorySpaceAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });
}

void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
  populateAMDGPUMemorySpaceAttributeConversions(converter);
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
                                                                 chipset);
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}