//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" |
10 | |
11 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
12 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
13 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
14 | #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
16 | #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/TypeUtilities.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | |
21 | #include "llvm/ADT/STLExtras.h" |
22 | #include <optional> |
23 | |
24 | namespace mlir { |
25 | #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL |
26 | #include "mlir/Conversion/Passes.h.inc" |
27 | } // namespace mlir |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::amdgpu; |
31 | |
32 | static Value createI32Constant(ConversionPatternRewriter &rewriter, |
33 | Location loc, int32_t value) { |
34 | Type llvmI32 = rewriter.getI32Type(); |
35 | return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value); |
36 | } |
37 | |
38 | static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, |
39 | bool value) { |
40 | Type llvmI1 = rewriter.getI1Type(); |
41 | return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value); |
42 | } |
43 | |
44 | namespace { |
45 | /// Define lowering patterns for raw buffer ops |
46 | template <typename GpuOp, typename Intrinsic> |
47 | struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { |
48 | RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
49 | : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {} |
50 | |
51 | Chipset chipset; |
52 | static constexpr uint32_t maxVectorOpWidth = 128; |
53 | |
54 | LogicalResult |
55 | matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor, |
56 | ConversionPatternRewriter &rewriter) const override { |
57 | Location loc = gpuOp.getLoc(); |
58 | Value memref = adaptor.getMemref(); |
59 | Value unconvertedMemref = gpuOp.getMemref(); |
60 | MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType()); |
61 | |
62 | if (chipset.majorVersion < 9) |
63 | return gpuOp.emitOpError("raw buffer ops require GCN or higher" ); |
64 | |
65 | Value storeData = adaptor.getODSOperands(0)[0]; |
66 | if (storeData == memref) // no write component to this op |
67 | storeData = Value(); |
68 | Type wantedDataType; |
69 | if (storeData) |
70 | wantedDataType = storeData.getType(); |
71 | else |
72 | wantedDataType = gpuOp.getODSResults(0)[0].getType(); |
73 | |
74 | Value atomicCmpData = Value(); |
75 | // Operand index 1 of a load is the indices, trying to read them can crash. |
76 | if (storeData) { |
77 | Value maybeCmpData = adaptor.getODSOperands(1)[0]; |
78 | if (maybeCmpData != memref) |
79 | atomicCmpData = maybeCmpData; |
80 | } |
81 | |
82 | Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); |
83 | |
84 | Type i32 = rewriter.getI32Type(); |
85 | Type llvmI32 = this->typeConverter->convertType(i32); |
86 | Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type()); |
87 | |
88 | int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8; |
89 | Value byteWidthConst = createI32Constant(rewriter, loc, value: elementByteWidth); |
90 | |
91 | // If we want to load a vector<NxT> with total size <= 32 |
92 | // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 |
93 | // and the total load size is >= 32, use a vector load of N / (bitsize(T) / |
94 | // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands, |
95 | // so bitcast any floats to integers. On top of all this, cast bfloat |
96 | // (vectors) to i16 since the backend doesn't currently support bfloat on |
97 | // these operations. |
98 | Type llvmBufferValType = llvmWantedDataType; |
99 | if (wantedDataType.isBF16()) |
100 | llvmBufferValType = rewriter.getI16Type(); |
101 | if (auto wantedVecType = dyn_cast<VectorType>(wantedDataType)) |
102 | if (wantedVecType.getElementType().isBF16()) |
103 | llvmBufferValType = wantedVecType.clone(rewriter.getI16Type()); |
104 | if (atomicCmpData) { |
105 | if (isa<VectorType>(Val: wantedDataType)) |
106 | return gpuOp.emitOpError("vector compare-and-swap does not exist" ); |
107 | if (auto floatType = dyn_cast<FloatType>(Val&: wantedDataType)) |
108 | llvmBufferValType = this->getTypeConverter()->convertType( |
109 | rewriter.getIntegerType(floatType.getWidth())); |
110 | } |
111 | if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) { |
112 | uint32_t elemBits = dataVector.getElementTypeBitWidth(); |
113 | uint32_t totalBits = elemBits * dataVector.getNumElements(); |
114 | if (totalBits > maxVectorOpWidth) |
115 | return gpuOp.emitOpError( |
116 | "Total width of loads or stores must be no more than " + |
117 | Twine(maxVectorOpWidth) + " bits, but we call for " + |
118 | Twine(totalBits) + |
119 | " bits. This should've been caught in validation" ); |
120 | if (elemBits < 32) { |
121 | if (totalBits > 32) { |
122 | if (totalBits % 32 != 0) |
123 | return gpuOp.emitOpError("Load or store of more than 32-bits that " |
124 | "doesn't fit into words. Can't happen\n" ); |
125 | llvmBufferValType = this->typeConverter->convertType( |
126 | VectorType::get(totalBits / 32, i32)); |
127 | } else { |
128 | llvmBufferValType = this->typeConverter->convertType( |
129 | rewriter.getIntegerType(totalBits)); |
130 | } |
131 | } |
132 | } |
133 | |
134 | SmallVector<Value, 6> args; |
135 | if (storeData) { |
136 | if (llvmBufferValType != llvmWantedDataType) { |
137 | Value castForStore = |
138 | rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData); |
139 | args.push_back(Elt: castForStore); |
140 | } else { |
141 | args.push_back(Elt: storeData); |
142 | } |
143 | } |
144 | |
145 | if (atomicCmpData) { |
146 | if (llvmBufferValType != llvmWantedDataType) { |
147 | Value castForCmp = rewriter.create<LLVM::BitcastOp>( |
148 | loc, llvmBufferValType, atomicCmpData); |
149 | args.push_back(Elt: castForCmp); |
150 | } else { |
151 | args.push_back(Elt: atomicCmpData); |
152 | } |
153 | } |
154 | |
155 | // Construct buffer descriptor from memref, attributes |
156 | int64_t offset = 0; |
157 | SmallVector<int64_t, 5> strides; |
158 | if (failed(getStridesAndOffset(memrefType, strides, offset))) |
159 | return gpuOp.emitOpError("Can't lower non-stride-offset memrefs" ); |
160 | |
161 | MemRefDescriptor memrefDescriptor(memref); |
162 | |
163 | Value ptr = memrefDescriptor.alignedPtr(builder&: rewriter, loc); |
164 | // The stride value is always 0 for raw buffers. This also disables |
165 | // swizling. |
166 | Value stride = rewriter.create<LLVM::ConstantOp>( |
167 | loc, llvmI16, rewriter.getI16IntegerAttr(0)); |
168 | Value numRecords; |
169 | if (memrefType.hasStaticShape()) { |
170 | numRecords = createI32Constant( |
171 | rewriter, loc, |
172 | value: static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth)); |
173 | } else { |
174 | Value maxIndex; |
175 | for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { |
176 | Value size = memrefDescriptor.size(builder&: rewriter, loc, pos: i); |
177 | Value stride = memrefDescriptor.stride(builder&: rewriter, loc, pos: i); |
178 | stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst); |
179 | Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride); |
180 | maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex, |
181 | maxThisDim) |
182 | : maxThisDim; |
183 | } |
184 | numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex); |
185 | } |
186 | |
187 | // Flag word: |
188 | // bits 0-11: dst sel, ignored by these intrinsics |
189 | // bits 12-14: data format (ignored, must be nonzero, 7=float) |
190 | // bits 15-18: data format (ignored, must be nonzero, 4=32bit) |
191 | // bit 19: In nested heap (0 here) |
192 | // bit 20: Behavior on unmap (0 means "return 0 / ignore") |
193 | // bits 21-22: Index stride for swizzles (N/A) |
194 | // bit 23: Add thread ID (0) |
195 | // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) |
196 | // bits 25-26: Reserved (0) |
197 | // bit 27: Buffer is non-volatile (CDNA only) |
198 | // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = |
199 | // none, 3 = either swizzles or testing against offset field) RDNA only |
200 | // bits 30-31: Type (must be 0) |
201 | uint32_t flags = (7 << 12) | (4 << 15); |
202 | if (chipset.majorVersion >= 10) { |
203 | flags |= (1 << 24); |
204 | uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2; |
205 | flags |= (oob << 28); |
206 | } |
207 | Value flagsConst = createI32Constant(rewriter, loc, value: flags); |
208 | Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); |
209 | Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>( |
210 | loc, rsrcType, ptr, stride, numRecords, flagsConst); |
211 | args.push_back(Elt: resource); |
212 | |
213 | // Indexing (voffset) |
214 | Value voffset = createI32Constant(rewriter, loc, value: 0); |
215 | for (auto pair : llvm::enumerate(adaptor.getIndices())) { |
216 | size_t i = pair.index(); |
217 | Value index = pair.value(); |
218 | Value strideOp; |
219 | if (ShapedType::isDynamic(strides[i])) { |
220 | strideOp = rewriter.create<LLVM::MulOp>( |
221 | loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst); |
222 | } else { |
223 | strideOp = |
224 | createI32Constant(rewriter, loc, value: strides[i] * elementByteWidth); |
225 | } |
226 | index = rewriter.create<LLVM::MulOp>(loc, index, strideOp); |
227 | voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index); |
228 | } |
229 | if (adaptor.getIndexOffset()) { |
230 | int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth; |
231 | Value = createI32Constant(rewriter, loc, value: indexOffset); |
232 | voffset = |
233 | voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst) |
234 | : extraOffsetConst; |
235 | } |
236 | args.push_back(Elt: voffset); |
237 | |
238 | Value sgprOffset = adaptor.getSgprOffset(); |
239 | if (!sgprOffset) |
240 | sgprOffset = createI32Constant(rewriter, loc, value: 0); |
241 | if (ShapedType::isDynamic(offset)) |
242 | sgprOffset = rewriter.create<LLVM::AddOp>( |
243 | loc, memrefDescriptor.offset(rewriter, loc), sgprOffset); |
244 | else if (offset > 0) |
245 | sgprOffset = rewriter.create<LLVM::AddOp>( |
246 | loc, sgprOffset, createI32Constant(rewriter, loc, offset)); |
247 | args.push_back(Elt: sgprOffset); |
248 | |
249 | // bit 0: GLC = 0 (atomics drop value, less coherency) |
250 | // bits 1-2: SLC, DLC = 0 (similarly) |
251 | // bit 3: swizzled (0 for raw) |
252 | args.push_back(Elt: createI32Constant(rewriter, loc, value: 0)); |
253 | |
254 | llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(), |
255 | llvmBufferValType); |
256 | Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args, |
257 | ArrayRef<NamedAttribute>()); |
258 | if (lowered->getNumResults() == 1) { |
259 | Value replacement = lowered->getResult(idx: 0); |
260 | if (llvmBufferValType != llvmWantedDataType) { |
261 | replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType, |
262 | replacement); |
263 | } |
264 | rewriter.replaceOp(gpuOp, replacement); |
265 | } else { |
266 | rewriter.eraseOp(op: gpuOp); |
267 | } |
268 | return success(); |
269 | } |
270 | }; |
271 | |
272 | struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { |
273 | LDSBarrierOpLowering(LLVMTypeConverter &converter, Chipset chipset) |
274 | : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {} |
275 | |
276 | Chipset chipset; |
277 | |
278 | LogicalResult |
279 | matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, |
280 | ConversionPatternRewriter &rewriter) const override { |
281 | bool requiresInlineAsm = |
282 | chipset.majorVersion < 9 || |
283 | (chipset.majorVersion == 9 && chipset.minorVersion < 0x0a) || |
284 | (chipset.majorVersion == 11); |
285 | |
286 | if (requiresInlineAsm) { |
287 | auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), |
288 | LLVM::AsmDialect::AD_ATT); |
289 | const char *asmStr = |
290 | ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier" ; |
291 | const char *constraints = "" ; |
292 | rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>( |
293 | op, |
294 | /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(), |
295 | /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true, |
296 | /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, |
297 | /*operand_attrs=*/ArrayAttr()); |
298 | return success(); |
299 | } |
300 | constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8); |
301 | constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8); |
302 | // Left in place in case someone disables the inline ASM path or future |
303 | // chipsets use the same bit pattern. |
304 | constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4); |
305 | |
306 | int32_t ldsOnlyBits; |
307 | if (chipset.majorVersion == 11) |
308 | ldsOnlyBits = ldsOnlyBitsGfx11; |
309 | else if (chipset.majorVersion == 10) |
310 | ldsOnlyBits = ldsOnlyBitsGfx10; |
311 | else if (chipset.majorVersion <= 9) |
312 | ldsOnlyBits = ldsOnlyBitsGfx6789; |
313 | else |
314 | return op.emitOpError( |
315 | "don't know how to lower this for chipset major version" ) |
316 | << chipset.majorVersion; |
317 | |
318 | Location loc = op->getLoc(); |
319 | rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits); |
320 | rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op); |
321 | return success(); |
322 | } |
323 | }; |
324 | } // namespace |
325 | |
326 | /// If `input` is a vector of bytes, concatentate those bytes in little-endian |
327 | /// order to form a single integer of size 8 * [vector length]. This works |
328 | /// around a wart in the AMDGPU intrinsics where operations that logically take |
329 | /// vectors of bytes instead integers. Since we do not want to expose this |
330 | /// implementation detail to MLIR, we correct for it here. |
331 | /// |
332 | /// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU |
333 | /// MFMA intrinsics pre-date the bfloat type. |
334 | static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, |
335 | Location loc, Value input) { |
336 | Type inputType = input.getType(); |
337 | if (auto vectorType = dyn_cast<VectorType>(inputType)) { |
338 | if (vectorType.getElementType().isBF16()) |
339 | return rewriter.create<LLVM::BitcastOp>( |
340 | loc, vectorType.clone(rewriter.getI16Type()), input); |
341 | |
342 | if (!vectorType.getElementType().isInteger(8)) |
343 | return input; |
344 | int64_t numBytes = vectorType.getNumElements(); |
345 | Type destType = rewriter.getIntegerType(numBytes * 8); |
346 | Value result = rewriter.create<LLVM::ConstantOp>( |
347 | loc, destType, rewriter.getIntegerAttr(destType, 0)); |
348 | for (int64_t i = 0; i < numBytes; ++i) { |
349 | Value idxConst = createI32Constant(rewriter, loc, value: i); |
350 | Value element = |
351 | rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst); |
352 | Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element); |
353 | Value shiftConst = rewriter.create<LLVM::ConstantOp>( |
354 | loc, destType, rewriter.getIntegerAttr(destType, i * 8)); |
355 | Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst); |
356 | result = rewriter.create<LLVM::OrOp>(loc, result, shifted); |
357 | } |
358 | return result; |
359 | } |
360 | return input; |
361 | } |
362 | |
363 | /// Push an input operand. If it is a float type, nothing to do. If it is |
364 | /// an integer type, then we need to also push its signdness (1 for signed, 0 |
365 | /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 |
366 | /// vector. We also need to convert bfloat inputs to i16 to account for the lack |
367 | /// of bfloat support in the WMMA intrinsics themselves. |
368 | static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, |
369 | Location loc, |
370 | const TypeConverter *typeConverter, |
371 | bool isUnsigned, Value llvmInput, |
372 | SmallVector<Value, 4> &operands) { |
373 | Type inputType = llvmInput.getType(); |
374 | auto vectorType = dyn_cast<VectorType>(inputType); |
375 | Type elemType = vectorType.getElementType(); |
376 | |
377 | if (elemType.isBF16()) |
378 | llvmInput = rewriter.create<LLVM::BitcastOp>( |
379 | loc, vectorType.clone(rewriter.getI16Type()), llvmInput); |
380 | if (!elemType.isInteger(width: 8)) { |
381 | operands.push_back(Elt: llvmInput); |
382 | return; |
383 | } |
384 | |
385 | int64_t numBytes = vectorType.getNumElements(); |
386 | Type i32 = rewriter.getI32Type(); |
387 | VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32); |
388 | auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits); |
389 | |
390 | Value result = rewriter.createOrFold<LLVM::BitcastOp>( |
391 | loc, llvmVectorType32bits, llvmInput); |
392 | |
393 | // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag |
394 | bool localIsUnsigned = isUnsigned; |
395 | if (elemType.isUnsignedInteger(width: 8)) { |
396 | localIsUnsigned = true; |
397 | } else if (elemType.isSignedInteger(width: 8)) { |
398 | localIsUnsigned = false; |
399 | } |
400 | Value sign = createI1Constant(rewriter, loc, value: !localIsUnsigned); |
401 | operands.push_back(Elt: sign); |
402 | operands.push_back(Elt: result); |
403 | } |
404 | |
405 | /// Push the output operand. For many cases this is only pushing the output in |
406 | /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics, |
407 | /// since the same numbers of VGPRs is used, we need to decide if to store the |
408 | /// result in the upper 16 bits of the VGPRs or in the lower part. To store the |
409 | /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will |
410 | /// be stored it in the upper part |
411 | static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, |
412 | Location loc, |
413 | const TypeConverter *typeConverter, |
414 | Value output, int32_t subwordOffset, |
415 | bool clamp, SmallVector<Value, 4> &operands) { |
416 | Type inputType = output.getType(); |
417 | auto vectorType = dyn_cast<VectorType>(inputType); |
418 | Type elemType = vectorType.getElementType(); |
419 | if (elemType.isBF16()) |
420 | output = rewriter.create<LLVM::BitcastOp>( |
421 | loc, vectorType.clone(rewriter.getI16Type()), output); |
422 | operands.push_back(Elt: output); |
423 | if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(width: 16)) { |
424 | operands.push_back(Elt: createI1Constant(rewriter, loc, value: subwordOffset)); |
425 | } else if (elemType.isInteger(width: 32)) { |
426 | operands.push_back(Elt: createI1Constant(rewriter, loc, value: clamp)); |
427 | } |
428 | } |
429 | |
430 | /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma` |
431 | /// if one exists. This includes checking to ensure the intrinsic is supported |
432 | /// on the architecture you are compiling for. |
433 | static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, |
434 | Chipset chipset) { |
435 | uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(), |
436 | b = mfma.getBlocks(); |
437 | Type sourceElem = mfma.getSourceA().getType(); |
438 | if (auto sourceType = dyn_cast<VectorType>(sourceElem)) |
439 | sourceElem = sourceType.getElementType(); |
440 | Type destElem = mfma.getDestC().getType(); |
441 | if (auto destType = dyn_cast<VectorType>(destElem)) |
442 | destElem = destType.getElementType(); |
443 | |
444 | if (sourceElem.isF32() && destElem.isF32()) { |
445 | if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) { |
446 | if (m == 32 && n == 32 && k == 4 && b == 1) |
447 | return ROCDL::mfma_f32_32x32x4_xf32::getOperationName(); |
448 | if (m == 16 && n == 16 && k == 8 && b == 1) |
449 | return ROCDL::mfma_f32_16x16x8_xf32::getOperationName(); |
450 | } |
451 | if (m == 32 && n == 32 && k == 1 && b == 2) |
452 | return ROCDL::mfma_f32_32x32x1f32::getOperationName(); |
453 | if (m == 16 && n == 16 && k == 1 && b == 4) |
454 | return ROCDL::mfma_f32_16x16x1f32::getOperationName(); |
455 | if (m == 4 && n == 4 && k == 1 && b == 16) |
456 | return ROCDL::mfma_f32_4x4x1f32::getOperationName(); |
457 | if (m == 32 && n == 32 && k == 2 && b == 1) |
458 | return ROCDL::mfma_f32_32x32x2f32::getOperationName(); |
459 | if (m == 16 && n == 16 && k == 4 && b == 1) |
460 | return ROCDL::mfma_f32_16x16x4f32::getOperationName(); |
461 | } |
462 | |
463 | if (sourceElem.isF16() && destElem.isF32()) { |
464 | if (m == 32 && n == 32 && k == 4 && b == 2) |
465 | return ROCDL::mfma_f32_32x32x4f16::getOperationName(); |
466 | if (m == 16 && n == 16 && k == 4 && b == 4) |
467 | return ROCDL::mfma_f32_16x16x4f16::getOperationName(); |
468 | if (m == 4 && n == 4 && k == 4 && b == 16) |
469 | return ROCDL::mfma_f32_4x4x4f16::getOperationName(); |
470 | if (m == 32 && n == 32 && k == 8 && b == 1) |
471 | return ROCDL::mfma_f32_32x32x8f16::getOperationName(); |
472 | if (m == 16 && n == 16 && k == 16 && b == 1) |
473 | return ROCDL::mfma_f32_16x16x16f16::getOperationName(); |
474 | } |
475 | |
476 | if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) { |
477 | if (m == 32 && n == 32 && k == 4 && b == 2) |
478 | return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName(); |
479 | if (m == 16 && n == 16 && k == 4 && b == 4) |
480 | return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName(); |
481 | if (m == 4 && n == 4 && k == 4 && b == 16) |
482 | return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName(); |
483 | if (m == 32 && n == 32 && k == 8 && b == 1) |
484 | return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName(); |
485 | if (m == 16 && n == 16 && k == 16 && b == 1) |
486 | return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName(); |
487 | } |
488 | |
489 | if (sourceElem.isBF16() && destElem.isF32()) { |
490 | if (m == 32 && n == 32 && k == 2 && b == 2) |
491 | return ROCDL::mfma_f32_32x32x2bf16::getOperationName(); |
492 | if (m == 16 && n == 16 && k == 2 && b == 4) |
493 | return ROCDL::mfma_f32_16x16x2bf16::getOperationName(); |
494 | if (m == 4 && n == 4 && k == 2 && b == 16) |
495 | return ROCDL::mfma_f32_4x4x2bf16::getOperationName(); |
496 | if (m == 32 && n == 32 && k == 4 && b == 1) |
497 | return ROCDL::mfma_f32_32x32x4bf16::getOperationName(); |
498 | if (m == 16 && n == 16 && k == 8 && b == 1) |
499 | return ROCDL::mfma_f32_16x16x8bf16::getOperationName(); |
500 | } |
501 | |
502 | if (isa<IntegerType>(Val: sourceElem) && destElem.isInteger(width: 32)) { |
503 | if (m == 32 && n == 32 && k == 4 && b == 2) |
504 | return ROCDL::mfma_i32_32x32x4i8::getOperationName(); |
505 | if (m == 16 && n == 16 && k == 4 && b == 4) |
506 | return ROCDL::mfma_i32_16x16x4i8::getOperationName(); |
507 | if (m == 4 && n == 4 && k == 4 && b == 16) |
508 | return ROCDL::mfma_i32_4x4x4i8::getOperationName(); |
509 | if (m == 32 && n == 32 && k == 8 && b == 1) |
510 | return ROCDL::mfma_i32_32x32x8i8::getOperationName(); |
511 | if (m == 16 && n == 16 && k == 16 && b == 1) |
512 | return ROCDL::mfma_i32_16x16x16i8::getOperationName(); |
513 | if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40) |
514 | return ROCDL::mfma_i32_32x32x16_i8::getOperationName(); |
515 | if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40) |
516 | return ROCDL::mfma_i32_16x16x32_i8::getOperationName(); |
517 | } |
518 | |
519 | if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) { |
520 | if (m == 16 && n == 16 && k == 4 && b == 1) |
521 | return ROCDL::mfma_f64_16x16x4f64::getOperationName(); |
522 | if (m == 4 && n == 4 && k == 4 && b == 4) |
523 | return ROCDL::mfma_f64_4x4x4f64::getOperationName(); |
524 | } |
525 | |
526 | if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && |
527 | chipset.minorVersion >= 0x40) { |
528 | // Known to be correct because there are no scalar f8 instructions and |
529 | // because a length mismatch will have been caught by the verifier. |
530 | Type sourceBElem = |
531 | cast<VectorType>(mfma.getSourceB().getType()).getElementType(); |
532 | if (m == 16 && n == 16 && k == 32 && b == 1) { |
533 | if (sourceBElem.isFloat8E5M2FNUZ()) |
534 | return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); |
535 | if (sourceBElem.isFloat8E4M3FNUZ()) |
536 | return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); |
537 | } |
538 | if (m == 32 && n == 32 && k == 16 && b == 1) { |
539 | if (sourceBElem.isFloat8E5M2FNUZ()) |
540 | return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); |
541 | if (sourceBElem.isFloat8E4M3FNUZ()) |
542 | return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); |
543 | } |
544 | } |
545 | |
546 | if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && |
547 | chipset.minorVersion >= 0x40) { |
548 | Type sourceBElem = |
549 | cast<VectorType>(mfma.getSourceB().getType()).getElementType(); |
550 | if (m == 16 && n == 16 && k == 32 && b == 1) { |
551 | if (sourceBElem.isFloat8E5M2FNUZ()) |
552 | return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); |
553 | if (sourceBElem.isFloat8E4M3FNUZ()) |
554 | return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); |
555 | } |
556 | if (m == 32 && n == 32 && k == 16 && b == 1) { |
557 | if (sourceBElem.isFloat8E5M2FNUZ()) |
558 | return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); |
559 | if (sourceBElem.isFloat8E4M3FNUZ()) |
560 | return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); |
561 | } |
562 | } |
563 | |
564 | return std::nullopt; |
565 | } |
566 | |
567 | /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` |
568 | /// if one exists. This includes checking to ensure the intrinsic is supported |
569 | /// on the architecture you are compiling for. |
570 | static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, |
571 | Chipset chipset) { |
572 | auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType()); |
573 | auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType()); |
574 | auto elemSourceType = sourceVectorType.getElementType(); |
575 | auto elemDestType = destVectorType.getElementType(); |
576 | |
577 | if (elemSourceType.isF16() && elemDestType.isF32()) { |
578 | return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); |
579 | } |
580 | if (elemSourceType.isBF16() && elemDestType.isF32()) { |
581 | return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); |
582 | } else if (elemSourceType.isF16() && elemDestType.isF16()) { |
583 | return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); |
584 | } else if (elemSourceType.isBF16() && elemDestType.isBF16()) { |
585 | return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); |
586 | } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) { |
587 | return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); |
588 | } |
589 | return std::nullopt; |
590 | } |
591 | |
592 | namespace { |
593 | struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { |
594 | MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
595 | : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {} |
596 | |
597 | Chipset chipset; |
598 | |
599 | LogicalResult |
600 | matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor, |
601 | ConversionPatternRewriter &rewriter) const override { |
602 | Location loc = op.getLoc(); |
603 | Type outType = typeConverter->convertType(op.getDestD().getType()); |
604 | Type intrinsicOutType = outType; |
605 | if (auto outVecType = dyn_cast<VectorType>(outType)) |
606 | if (outVecType.getElementType().isBF16()) |
607 | intrinsicOutType = outVecType.clone(rewriter.getI16Type()); |
608 | |
609 | if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08) |
610 | return op->emitOpError("MFMA only supported on gfx908+" ); |
611 | uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp()); |
612 | if (op.getNegateA() || op.getNegateB() || op.getNegateC()) { |
613 | if (chipset.minorVersion < 0x40) |
614 | return op.emitOpError("negation unsupported on older than gfx840" ); |
615 | getBlgpField |= |
616 | op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2); |
617 | } |
618 | std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset); |
619 | if (!maybeIntrinsic.has_value()) |
620 | return op.emitOpError("no intrinsic matching MFMA size on given chipset" ); |
621 | OperationState loweredOp(loc, *maybeIntrinsic); |
622 | loweredOp.addTypes(newTypes: intrinsicOutType); |
623 | loweredOp.addOperands( |
624 | newOperands: {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()), |
625 | mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()), |
626 | adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()), |
627 | createI32Constant(rewriter, loc, op.getAbid()), |
628 | createI32Constant(rewriter, loc, value: getBlgpField)}); |
629 | Value lowered = rewriter.create(state: loweredOp)->getResult(idx: 0); |
630 | if (outType != intrinsicOutType) |
631 | lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered); |
632 | rewriter.replaceOp(op, lowered); |
633 | return success(); |
634 | } |
635 | }; |
636 | |
637 | struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { |
638 | WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) |
639 | : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {} |
640 | |
641 | Chipset chipset; |
642 | |
643 | LogicalResult |
644 | matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, |
645 | ConversionPatternRewriter &rewriter) const override { |
646 | Location loc = op.getLoc(); |
647 | Type outType = typeConverter->convertType(op.getDestD().getType()); |
648 | |
649 | if (chipset.majorVersion != 11) |
650 | return op->emitOpError("WMMA only supported on gfx11" ); |
651 | |
652 | std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); |
653 | |
654 | if (!maybeIntrinsic.has_value()) |
655 | return op.emitOpError("no intrinsic matching WMMA on the given chipset" ); |
656 | |
657 | OperationState loweredOp(loc, *maybeIntrinsic); |
658 | loweredOp.addTypes(newTypes: outType); |
659 | |
660 | SmallVector<Value, 4> operands; |
661 | wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), |
662 | adaptor.getSourceA(), operands); |
663 | wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), |
664 | adaptor.getSourceB(), operands); |
665 | wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), |
666 | op.getSubwordOffset(), op.getClamp(), operands); |
667 | |
668 | loweredOp.addOperands(newOperands: operands); |
669 | Operation *lowered = rewriter.create(state: loweredOp); |
670 | rewriter.replaceOp(op, lowered->getResults()); |
671 | |
672 | return success(); |
673 | } |
674 | }; |
675 | |
namespace {
/// Lowers amdgpu.ext_packed_fp8 (extract/extend packed fp8 values) to ROCDL;
/// implementation is out-of-line below.
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers amdgpu.packed_trunc_2xfp8 (truncate two floats into packed fp8) to
/// ROCDL; implementation is out-of-line below.
struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Lowers amdgpu.packed_stoch_round_fp8 (stochastic-rounding fp8 pack) to
/// ROCDL; implementation is out-of-line below.
struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace
714 | |
715 | LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( |
716 | ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, |
717 | ConversionPatternRewriter &rewriter) const { |
718 | Location loc = op.getLoc(); |
719 | if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) |
720 | return rewriter.notifyMatchFailure( |
721 | arg&: loc, msg: "Fp8 conversion instructions are not available on target " |
722 | "architecture and their emulation is not implemented" ); |
723 | Type v4i8 = |
724 | getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); |
725 | Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
726 | Type f32 = getTypeConverter()->convertType(op.getResult().getType()); |
727 | |
728 | Value source = adaptor.getSource(); |
729 | auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType()); |
730 | Type sourceElemType = getElementTypeOrSelf(op.getSource()); |
731 | // Extend to a v4i8 |
732 | if (!sourceVecType || sourceVecType.getNumElements() < 4) { |
733 | Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8); |
734 | if (!sourceVecType) { |
735 | longVec = rewriter.create<LLVM::InsertElementOp>( |
736 | loc, longVec, source, createI32Constant(rewriter, loc, 0)); |
737 | } else { |
738 | for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { |
739 | Value idx = createI32Constant(rewriter, loc, value: i); |
740 | Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); |
741 | longVec = |
742 | rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); |
743 | } |
744 | } |
745 | source = longVec; |
746 | } |
747 | Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); |
748 | Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); |
749 | if (sourceElemType.isFloat8E5M2FNUZ()) { |
750 | rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source, |
751 | wordSel); |
752 | } else if (sourceElemType.isFloat8E4M3FNUZ()) { |
753 | rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source, |
754 | wordSel); |
755 | } |
756 | return success(); |
757 | } |
758 | |
759 | LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( |
760 | PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, |
761 | ConversionPatternRewriter &rewriter) const { |
762 | Location loc = op.getLoc(); |
763 | if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) |
764 | return rewriter.notifyMatchFailure( |
765 | arg&: loc, msg: "Fp8 conversion instructions are not available on target " |
766 | "architecture and their emulation is not implemented" ); |
767 | Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
768 | |
769 | Type resultType = op.getResult().getType(); |
770 | Type resultElemType = getElementTypeOrSelf(type: resultType); |
771 | |
772 | Value sourceA = adaptor.getSourceA(); |
773 | Value sourceB = adaptor.getSourceB(); |
774 | if (!sourceB) |
775 | sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType()); |
776 | Value existing = adaptor.getExisting(); |
777 | if (existing) |
778 | existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); |
779 | else |
780 | existing = rewriter.create<LLVM::UndefOp>(loc, i32); |
781 | Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); |
782 | |
783 | Value result; |
784 | if (resultElemType.isFloat8E5M2FNUZ()) |
785 | result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB, |
786 | existing, wordSel); |
787 | else if (resultElemType.isFloat8E4M3FNUZ()) |
788 | result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB, |
789 | existing, wordSel); |
790 | |
791 | result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
792 | op, getTypeConverter()->convertType(resultType), result); |
793 | return success(); |
794 | } |
795 | |
796 | LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( |
797 | PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, |
798 | ConversionPatternRewriter &rewriter) const { |
799 | Location loc = op.getLoc(); |
800 | if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) |
801 | return rewriter.notifyMatchFailure( |
802 | arg&: loc, msg: "Fp8 conversion instructions are not available on target " |
803 | "architecture and their emulation is not implemented" ); |
804 | Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); |
805 | |
806 | Type resultType = op.getResult().getType(); |
807 | Type resultElemType = getElementTypeOrSelf(type: resultType); |
808 | |
809 | Value source = adaptor.getSource(); |
810 | Value stoch = adaptor.getStochiasticParam(); |
811 | Value existing = adaptor.getExisting(); |
812 | if (existing) |
813 | existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); |
814 | else |
815 | existing = rewriter.create<LLVM::UndefOp>(loc, i32); |
816 | Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); |
817 | |
818 | Value result; |
819 | if (resultElemType.isFloat8E5M2FNUZ()) |
820 | result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch, |
821 | existing, byteSel); |
822 | else if (resultElemType.isFloat8E4M3FNUZ()) |
823 | result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch, |
824 | existing, byteSel); |
825 | |
826 | result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( |
827 | op, getTypeConverter()->convertType(resultType), result); |
828 | return success(); |
829 | } |
830 | |
831 | struct ConvertAMDGPUToROCDLPass |
832 | : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> { |
833 | ConvertAMDGPUToROCDLPass() = default; |
834 | |
835 | void runOnOperation() override { |
836 | MLIRContext *ctx = &getContext(); |
837 | FailureOr<Chipset> maybeChipset = Chipset::parse(chipset); |
838 | if (failed(result: maybeChipset)) { |
839 | emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); |
840 | return signalPassFailure(); |
841 | } |
842 | |
843 | RewritePatternSet patterns(ctx); |
844 | LLVMTypeConverter converter(ctx); |
845 | populateAMDGPUToROCDLConversionPatterns(converter, patterns, chipset: *maybeChipset); |
846 | LLVMConversionTarget target(getContext()); |
847 | target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); |
848 | target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); |
849 | target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); |
850 | if (failed(applyPartialConversion(getOperation(), target, |
851 | std::move(patterns)))) |
852 | signalPassFailure(); |
853 | } |
854 | }; |
855 | } // namespace |
856 | |
857 | void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, |
858 | RewritePatternSet &patterns, |
859 | Chipset chipset) { |
860 | converter.addConversion(callback: [](BFloat16Type t) -> Type { |
861 | return IntegerType::get(t.getContext(), 16); |
862 | }); |
863 | converter.addConversion(callback: [&converter](VectorType t) -> std::optional<Type> { |
864 | if (!t.getElementType().isBF16()) |
865 | return std::nullopt; |
866 | return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16))); |
867 | }); |
868 | |
869 | patterns |
870 | .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, |
871 | RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, |
872 | RawBufferOpLowering<RawBufferAtomicFaddOp, |
873 | ROCDL::RawPtrBufferAtomicFaddOp>, |
874 | RawBufferOpLowering<RawBufferAtomicFmaxOp, |
875 | ROCDL::RawPtrBufferAtomicFmaxOp>, |
876 | RawBufferOpLowering<RawBufferAtomicSmaxOp, |
877 | ROCDL::RawPtrBufferAtomicSmaxOp>, |
878 | RawBufferOpLowering<RawBufferAtomicUminOp, |
879 | ROCDL::RawPtrBufferAtomicUminOp>, |
880 | RawBufferOpLowering<RawBufferAtomicCmpswapOp, |
881 | ROCDL::RawPtrBufferAtomicCmpSwap>, |
882 | LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering, |
883 | ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering, |
884 | PackedStochRoundFp8OpLowering>(converter, chipset); |
885 | } |
886 | |
887 | std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() { |
888 | return std::make_unique<ConvertAMDGPUToROCDLPass>(); |
889 | } |
890 | |