1//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/IR/OpImplementation.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/TypeUtilities.h"
27#include "llvm/ADT/DenseMap.h"
28#include "llvm/ADT/TypeSwitch.h"
29
30#include <limits>
31#include <optional>
32
33using namespace mlir;
34using namespace mlir::amdgpu;
35
36#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
37
// Register the dialect's operations and attributes. The operation and
// attribute lists are generated by TableGen into the included .cpp.inc files.
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
48
49//===----------------------------------------------------------------------===//
50// 8-bit float ops
51//===----------------------------------------------------------------------===//
52LogicalResult PackedTrunc2xFp8Op::verify() {
53 if (getExisting() && getExisting().getType() != getResult().getType())
54 return emitOpError(message: "existing values must have same type as result");
55 return success();
56}
57
58LogicalResult PackedStochRoundFp8Op::verify() {
59 if (getExisting() && getExisting().getType() != getResult().getType())
60 return emitOpError(message: "existing values must have same type as result");
61 return success();
62}
63
64//===----------------------------------------------------------------------===//
65// mxfp float ops
66//===----------------------------------------------------------------------===//
67LogicalResult PackedScaledTruncOp::verify() {
68 if (getExisting() && getExisting().getType() != getResult().getType())
69 return emitOpError(message: "existing values must have same type as result");
70 return success();
71}
72
73//===----------------------------------------------------------------------===//
74// FatRawBufferCastOp
75//===----------------------------------------------------------------------===//
76
77/// Convert the type `source` to one with the same sizes and strides - and
78/// offset, unless `stripOffset` is true, in which case the offset is reset to
79/// 0, if the offset should be reset but the layout of `source` isn't either the
80/// identity layout or a strided layout, this function fails.
81static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
82 bool resetOffset) {
83 MLIRContext *ctx = source.getContext();
84 MemRefType::Builder mb(source);
85 mb.setMemorySpace(
86 amdgpu::AddressSpaceAttr::get(context: ctx, value: amdgpu::AddressSpace::FatRawBuffer));
87 MemRefLayoutAttrInterface layout = source.getLayout();
88 if (resetOffset && !layout.isIdentity()) {
89 auto stridedLayout = dyn_cast<StridedLayoutAttr>(Val&: layout);
90 if (!stridedLayout)
91 return failure();
92 mb.setLayout(StridedLayoutAttr::get(context: ctx, offset: 0, strides: stridedLayout.getStrides()));
93 }
94 return (MemRefType)(mb);
95}
96
97LogicalResult FatRawBufferCastOp::inferReturnTypes(
98 MLIRContext *context, std::optional<Location> location, ValueRange operands,
99 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
100 SmallVectorImpl<Type> &inferredReturnTypes) {
101 Adaptor adaptor(operands, attributes, properties, regions);
102 auto sourceType =
103 dyn_cast_if_present<MemRefType>(Val: adaptor.getSource().getType());
104 if (!sourceType)
105 return failure();
106 FailureOr<MemRefType> resultType =
107 getFatRawBufferTypeLike(source: sourceType, resetOffset: adaptor.getResetOffset());
108 if (failed(Result: resultType))
109 return failure();
110 inferredReturnTypes = SmallVector<Type>{*resultType};
111 return success();
112}
113
114LogicalResult FatRawBufferCastOp::verify() {
115 FailureOr<MemRefType> expectedResultType =
116 getFatRawBufferTypeLike(source: getSource().getType(), resetOffset: getResetOffset());
117 if (failed(Result: expectedResultType))
118 return emitOpError(message: "source type ")
119 << getSource().getType() << " can't have its offset reset";
120 if (getResult().getType() != *expectedResultType)
121 return emitOpError(message: "expected result type to be ")
122 << *expectedResultType << " but got " << getResult().getType();
123 return success();
124}
125
126static bool hasGlobalMemorySpace(Attribute memorySpace) {
127 if (!memorySpace)
128 return true;
129 if (auto intMemorySpace = dyn_cast<IntegerAttr>(Val&: memorySpace))
130 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
131 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(Val&: memorySpace))
132 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
133 return false;
134}
135
136static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
137 if (auto intMemorySpace = dyn_cast<IntegerAttr>(Val&: memorySpace))
138 return intMemorySpace.getInt() == 3;
139 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(Val&: memorySpace))
140 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
141 return false;
142}
143
144static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
145 if (auto intMemorySpace = dyn_cast<IntegerAttr>(Val&: memorySpace))
146 return intMemorySpace.getInt() == 7;
147 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(Val&: memorySpace))
148 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
149 return false;
150}
151
152//===----------------------------------------------------------------------===//
153// RawBuffer*Op
154//===----------------------------------------------------------------------===//
155template <typename T>
156static LogicalResult verifyRawBufferOp(T &op) {
157 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
158 bool isGlobal = hasGlobalMemorySpace(memorySpace: bufferType.getMemorySpace());
159
160 if (!isGlobal)
161 return op.emitOpError(
162 "Buffer ops must operate on a memref in global memory");
163 if (!bufferType.hasRank())
164 return op.emitOpError(
165 "Cannot meaningfully buffer_store to an unranked memref");
166 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
167 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
168 " indices to memref");
169 return success();
170}
171
172LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(op&: *this); }
173
174LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(op&: *this); }
175
176LogicalResult RawBufferAtomicFaddOp::verify() {
177 return verifyRawBufferOp(op&: *this);
178}
179
180LogicalResult RawBufferAtomicFmaxOp::verify() {
181 return verifyRawBufferOp(op&: *this);
182}
183
184LogicalResult RawBufferAtomicSmaxOp::verify() {
185 return verifyRawBufferOp(op&: *this);
186}
187
188LogicalResult RawBufferAtomicUminOp::verify() {
189 return verifyRawBufferOp(op&: *this);
190}
191
192LogicalResult RawBufferAtomicCmpswapOp::verify() {
193 return verifyRawBufferOp(op&: *this);
194}
195
196static std::optional<uint32_t> getConstantUint32(Value v) {
197 APInt cst;
198 if (!v.getType().isInteger(width: 32))
199 return std::nullopt;
200 if (matchPattern(value: v, pattern: m_ConstantInt(bind_value: &cst)))
201 return cst.getZExtValue();
202 return std::nullopt;
203}
204
205template <typename OpType>
206static bool staticallyOutOfBounds(OpType op) {
207 if (!op.getBoundsCheck())
208 return false;
209 MemRefType bufferType = op.getMemref().getType();
210 if (!bufferType.hasStaticShape())
211 return false;
212 int64_t offset;
213 SmallVector<int64_t> strides;
214 if (failed(Result: bufferType.getStridesAndOffset(strides, offset)))
215 return false;
216 int64_t result = offset + op.getIndexOffset().value_or(0);
217 if (op.getSgprOffset()) {
218 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
219 if (!sgprOffset)
220 return false;
221 result += *sgprOffset;
222 }
223 if (strides.size() != op.getIndices().size())
224 return false;
225 int64_t indexVal = 0;
226 for (auto pair : llvm::zip(strides, op.getIndices())) {
227 int64_t stride = std::get<0>(pair);
228 Value idx = std::get<1>(pair);
229 std::optional<uint32_t> idxVal = getConstantUint32(v: idx);
230 if (!idxVal)
231 return false;
232 indexVal += stride * *idxVal;
233 }
234 result += indexVal;
235 if (result > std::numeric_limits<uint32_t>::max())
236 // Overflow means don't drop
237 return false;
238 return result >= bufferType.getNumElements();
239}
240
241namespace {
242template <typename OpType>
243struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
244 using OpRewritePattern<OpType>::OpRewritePattern;
245
246 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
247 if (!staticallyOutOfBounds(op))
248 return failure();
249 Type loadType = op.getResult().getType();
250 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
251 rw.getZeroAttr(type: loadType));
252 return success();
253 }
254};
255
256template <typename OpType>
257struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
258 using OpRewritePattern<OpType>::OpRewritePattern;
259
260 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
261 if (!staticallyOutOfBounds(op))
262 return failure();
263
264 rw.eraseOp(op);
265 return success();
266 }
267};
268} // end namespace
269
270void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
271 MLIRContext *context) {
272 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(arg&: context);
273}
274
275void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
276 MLIRContext *context) {
277 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(arg&: context);
278}
279
280void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
281 RewritePatternSet &results, MLIRContext *context) {
282 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(arg&: context);
283}
284
285void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
286 RewritePatternSet &results, MLIRContext *context) {
287 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(arg&: context);
288}
289
290void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
291 RewritePatternSet &results, MLIRContext *context) {
292 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(arg&: context);
293}
294
295void RawBufferAtomicUminOp::getCanonicalizationPatterns(
296 RewritePatternSet &results, MLIRContext *context) {
297 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(arg&: context);
298}
299
300void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
301 RewritePatternSet &results, MLIRContext *context) {
302 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
303 arg&: context);
304}
305
306//===----------------------------------------------------------------------===//
307// WMMAOp
308//===----------------------------------------------------------------------===//
309LogicalResult WMMAOp::verify() {
310 Type sourceAType = getSourceA().getType();
311 Type sourceBType = getSourceB().getType();
312 Type destType = getDestC().getType();
313
314 VectorType sourceVectorAType = dyn_cast<VectorType>(Val&: sourceAType);
315 VectorType sourceVectorBType = dyn_cast<VectorType>(Val&: sourceBType);
316 VectorType destVectorType = dyn_cast<VectorType>(Val&: destType);
317
318 Type sourceAElemType = sourceVectorAType.getElementType();
319 Type sourceBElemType = sourceVectorBType.getElementType();
320 Type destElemType = destVectorType.getElementType();
321
322 if (sourceVectorAType.getNumElements() !=
323 sourceVectorBType.getNumElements()) {
324 return emitOpError(message: "source vectors have different lengths: ")
325 << sourceVectorAType << " vs. " << sourceVectorBType;
326 }
327
328 bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(Val: destElemType);
329 bool isSrcFloat =
330 isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
331 Val: sourceAElemType);
332
333 if (isDestFloat && !isSrcFloat) {
334 return emitOpError(message: "Expected float sources with float destination");
335 }
336
337 if (!isDestFloat && isSrcFloat) {
338 return emitOpError(message: "Expected int sources with int destination");
339 }
340
341 if (sourceAElemType != sourceBElemType &&
342 !(isa<Float8E5M2Type, Float8E4M3FNType>(Val: sourceAElemType) &&
343 isa<Float8E5M2Type, Float8E4M3FNType>(Val: sourceBElemType))) {
344 return emitOpError(
345 message: "source element types much match (except for fp8) but have ")
346 << sourceAType << " and " << sourceBType;
347 }
348 return success();
349}
350
351//===----------------------------------------------------------------------===//
352// MFMAOp
353//===----------------------------------------------------------------------===//
354LogicalResult MFMAOp::verify() {
355 constexpr uint32_t waveSize = 64;
356 Builder b(getContext());
357
358 Type sourceType = getSourceA().getType();
359 Type destType = getDestC().getType();
360
361 Type sourceElem = sourceType, destElem = destType;
362 uint32_t sourceLen = 1, destLen = 1;
363 if (auto sourceVector = llvm::dyn_cast<VectorType>(Val&: sourceType)) {
364 sourceLen = sourceVector.getNumElements();
365 sourceElem = sourceVector.getElementType();
366 }
367 if (auto destVector = llvm::dyn_cast<VectorType>(Val&: destType)) {
368 destLen = destVector.getNumElements();
369 destElem = destVector.getElementType();
370 }
371
372 Type sourceBType = getSourceB().getType();
373 if (sourceElem.isFloat(width: 8) || sourceElem.isFloat(width: 6) || sourceElem.isFloat(width: 4)) {
374 int64_t sourceBLen = 1;
375 Type sourceBElem = sourceBType;
376 if (auto sourceBVector = llvm::dyn_cast<VectorType>(Val&: sourceBType)) {
377 sourceBLen = sourceBVector.getNumElements();
378 sourceBElem = sourceBVector.getElementType();
379 }
380 if (!sourceBElem.isFloat(width: 8) && !sourceBElem.isFloat(width: 6) &&
381 !sourceBElem.isFloat(width: 4))
382 return emitOpError(message: "expected both source operands to have small-float "
383 "elements if one does");
384 if (sourceLen != sourceBLen)
385 return emitOpError(
386 message: "expected both small-float source vectors to have the same length");
387 } else {
388 if (sourceType != sourceBType)
389 return emitOpError(message: "expected both non-small-float source operand types "
390 "to match exactly");
391 }
392 // Normalize the wider integer types the compiler expects to i8
393 if (sourceElem.isInteger(width: 32)) {
394 sourceLen *= 4;
395 sourceElem = b.getI8Type();
396 }
397 if (sourceElem.isInteger(width: 64)) {
398 sourceLen *= 8;
399 sourceElem = b.getI8Type();
400 }
401
402 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
403 if (sourceLen != numSourceElems)
404 return emitOpError(message: "expected " + Twine(numSourceElems) +
405 " source values for this operation but got " +
406 Twine(sourceLen));
407
408 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
409 if (destLen != numDestElems)
410 return emitOpError(message: "expected " + Twine(numDestElems) +
411 " result values for this operation but got " +
412 Twine(destLen));
413
414 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
415 return emitOpError(
416 message: "double-precision ops do not support permuting lanes of B");
417 if (destElem.isF64() && getCbsz() != 0)
418 return emitOpError(
419 message: "double-precision ops do not support permuting lanes of A");
420 if (getAbid() >= (1u << getCbsz()))
421 return emitOpError(
422 message: "block ID for permuting A (abid) must be below 2 ** cbsz");
423
424 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
425 return emitOpError(
426 message: "negation flags only available for double-precision operations");
427
428 return success();
429}
430
431//===----------------------------------------------------------------------===//
432// DPPOp
433//===----------------------------------------------------------------------===//
434LogicalResult DPPOp::verify() {
435 Type srcType = getSrc().getType();
436 if (srcType.getIntOrFloatBitWidth() > 64) {
437 return emitOpError(message: "integer and floating point types larger than 64 bits "
438 "are not supported");
439 }
440
441 DPPPerm kind = getKind();
442 Attribute permArgument = getPermArgument().value_or(u: Attribute{});
443
444 switch (kind) {
445
446 case DPPPerm::quad_perm: {
447 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(Val&: permArgument);
448 if (!quadPermAttr || quadPermAttr.size() != 4) {
449 return emitOpError(message: "quad_perm attribute must have exactly 4 elements");
450 }
451 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
452 int32_t num = elem.getInt();
453 if (num < 0 || num > 3) {
454 return emitOpError(
455 message: "Each element of quad_perm must be in the range [0, 3]");
456 }
457 }
458 } break;
459
460 case DPPPerm::row_shl:
461 case DPPPerm::row_shr:
462 case DPPPerm::row_ror: {
463 if (!permArgument) {
464 return emitOpError(message: "Attribute '" + Twine(stringifyDPPPerm(kind)) +
465 "' value not specified");
466 }
467 if (auto intAttr = dyn_cast<IntegerAttr>(Val&: permArgument)) {
468 uint32_t attrValue = intAttr.getInt();
469 if (attrValue < 1 || attrValue > 15) {
470 return emitOpError(message: "Attribute value must be between 1 and 15");
471 }
472 }
473 } break;
474
475 case DPPPerm::wave_shl:
476 case DPPPerm::wave_shr:
477 case DPPPerm::wave_rol:
478 case DPPPerm::wave_ror:
479 case DPPPerm::row_mirror:
480 case DPPPerm::row_half_mirror:
481 case DPPPerm::row_bcast_15:
482 case DPPPerm::row_bcast_31: {
483 if (permArgument && !isa<UnitAttr>(Val: permArgument)) {
484 return emitOpError(message: "Expected unit attribute for permArgument, but found "
485 "non-trivial argument");
486 }
487 break;
488 }
489 }
490 return success();
491}
492
493LogicalResult GatherToLDSOp::verify() {
494 MemRefType srcType = cast<MemRefType>(Val: getSrc().getType());
495 MemRefType dstType = cast<MemRefType>(Val: getDst().getType());
496
497 if (!dstType.areTrailingDimsContiguous(n: dstType.getRank()))
498 return emitOpError(message: "destination types must be contiguous");
499
500 auto elemType = srcType.getElementType();
501 // Check $src and $dst element types are the same.
502 if (elemType != dstType.getElementType())
503 return emitOpError(message: "source and destination element types must match");
504
505 // copy type sizes should be 1, 2, 4, 12 or 16 bytes.
506 auto transferType = getTransferType();
507 int transferSize;
508 if (auto vectorTransfer = dyn_cast<VectorType>(Val&: transferType)) {
509 transferSize = vectorTransfer.getNumElements() *
510 vectorTransfer.getElementTypeBitWidth();
511 } else {
512 transferSize = transferType.getIntOrFloatBitWidth();
513 }
514 if (!llvm::is_contained(Set: {8, 16, 32, 96, 128}, Element: transferSize))
515 return emitOpError(
516 message: "Transfering type size must be 8, 16, 32, 96 or 128 bits");
517
518 if (!hasGlobalMemorySpace(memorySpace: srcType.getMemorySpace()) &&
519 !hasFatRawBufferMemorySpace(memorySpace: srcType.getMemorySpace()))
520 return emitOpError(
521 message: "source memory address space must be global or fat raw buffer");
522
523 if (!hasWorkgroupMemorySpace(memorySpace: dstType.getMemorySpace()))
524 return emitOpError(message: "destination memory address space must be Workgroup");
525
526 return success();
527}
528
529LogicalResult TransposeLoadOp::verify() {
530 MemRefType srcType = cast<MemRefType>(Val: getSrc().getType());
531
532 if (!hasWorkgroupMemorySpace(memorySpace: srcType.getMemorySpace()))
533 return emitOpError(message: "source memory address space must be Workgroup");
534
535 auto transferType = cast<VectorType>(Val: getType());
536 size_t numElements = transferType.getNumElements();
537 size_t elementTypeSize =
538 transferType.getElementType().getIntOrFloatBitWidth();
539
540 // ElementSize -> NumElements
541 const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
542 {4, 16},
543 {6, 16},
544 {8, 8},
545 {16, 4},
546 };
547
548 auto validNumElems = KValidLoadSizeMap.find(Val: elementTypeSize);
549 if (validNumElems == KValidLoadSizeMap.end()) {
550 return emitOpError(message: "Unsupported element type size for transpose load: ")
551 << elementTypeSize << " bits";
552 }
553 if (numElements != validNumElems->second) {
554 return emitOpError(
555 message: "Transferring type size mismatch: expected num of elements: ")
556 << validNumElems->second;
557 }
558
559 return success();
560}
561
562#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
563
564#define GET_ATTRDEF_CLASSES
565#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
566
567#define GET_OP_CLASSES
568#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
569

// Source file: mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp