//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::amdgpu;

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
36
37void AMDGPUDialect::initialize() {
38 addOperations<
39#define GET_OP_LIST
40#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
41 >();
42 addAttributes<
43#define GET_ATTRDEF_LIST
44#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
45 >();
46}
47
48//===----------------------------------------------------------------------===//
49// 8-bit float ops
50//===----------------------------------------------------------------------===//
51LogicalResult PackedTrunc2xFp8Op::verify() {
52 if (getExisting() && getExisting().getType() != getResult().getType())
53 return emitOpError("existing values must have same type as result");
54 return success();
55}
56
57LogicalResult PackedStochRoundFp8Op::verify() {
58 if (getExisting() && getExisting().getType() != getResult().getType())
59 return emitOpError("existing values must have same type as result");
60 return success();
61}
62
63//===----------------------------------------------------------------------===//
64// FatRawBufferCastOp
65//===----------------------------------------------------------------------===//
66
67/// Convert the type `source` to one with the same sizes and strides - and
68/// offset, unless `stripOffset` is true, in which case the offset is reset to
69/// 0, if the offset should be reset but the layout of `source` isn't either the
70/// identity layout or a strided layout, this function fails.
71static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
72 bool resetOffset) {
73 MLIRContext *ctx = source.getContext();
74 MemRefType::Builder mb(source);
75 mb.setMemorySpace(
76 amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
77 MemRefLayoutAttrInterface layout = source.getLayout();
78 if (resetOffset && !layout.isIdentity()) {
79 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
80 if (!stridedLayout)
81 return failure();
82 mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
83 }
84 return (MemRefType)(mb);
85}
86
87LogicalResult FatRawBufferCastOp::inferReturnTypes(
88 MLIRContext *context, std::optional<Location> location, ValueRange operands,
89 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
90 SmallVectorImpl<Type> &inferredReturnTypes) {
91 Adaptor adaptor(operands, attributes, properties, regions);
92 auto sourceType =
93 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
94 if (!sourceType)
95 return failure();
96 FailureOr<MemRefType> resultType =
97 getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
98 if (failed(resultType))
99 return failure();
100 inferredReturnTypes = SmallVector<Type>{*resultType};
101 return success();
102}
103
104LogicalResult FatRawBufferCastOp::verify() {
105 FailureOr<MemRefType> expectedResultType =
106 getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
107 if (failed(expectedResultType))
108 return emitOpError("source type ")
109 << getSource().getType() << " can't have its offset reset";
110 if (getResult().getType() != *expectedResultType)
111 return emitOpError("expected result type to be ")
112 << *expectedResultType << " but got " << getResult().getType();
113 return success();
114}
115
116static bool hasGlobalMemorySpace(Attribute memorySpace) {
117 if (!memorySpace)
118 return true;
119 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
120 return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
121 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
122 return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
123 return false;
124}
125
126static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
127 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
128 return intMemorySpace.getInt() == 3;
129 if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
130 return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
131 return false;
132}
133
134static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
135 if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
136 return intMemorySpace.getInt() == 7;
137 if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
138 return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
139 return false;
140}
141
142//===----------------------------------------------------------------------===//
143// RawBuffer*Op
144//===----------------------------------------------------------------------===//
145template <typename T>
146static LogicalResult verifyRawBufferOp(T &op) {
147 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
148 bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
149
150 if (!isGlobal)
151 return op.emitOpError(
152 "Buffer ops must operate on a memref in global memory");
153 if (!bufferType.hasRank())
154 return op.emitOpError(
155 "Cannot meaningfully buffer_store to an unranked memref");
156 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
157 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
158 " indices to memref");
159 return success();
160}
161
162LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
163
164LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
165
166LogicalResult RawBufferAtomicFaddOp::verify() {
167 return verifyRawBufferOp(*this);
168}
169
170LogicalResult RawBufferAtomicFmaxOp::verify() {
171 return verifyRawBufferOp(*this);
172}
173
174LogicalResult RawBufferAtomicSmaxOp::verify() {
175 return verifyRawBufferOp(*this);
176}
177
178LogicalResult RawBufferAtomicUminOp::verify() {
179 return verifyRawBufferOp(*this);
180}
181
182LogicalResult RawBufferAtomicCmpswapOp::verify() {
183 return verifyRawBufferOp(*this);
184}
185
186static std::optional<uint32_t> getConstantUint32(Value v) {
187 APInt cst;
188 if (!v.getType().isInteger(width: 32))
189 return std::nullopt;
190 if (matchPattern(v, m_ConstantInt(&cst)))
191 return cst.getZExtValue();
192 return std::nullopt;
193}
194
195template <typename OpType>
196static bool staticallyOutOfBounds(OpType op) {
197 if (!op.getBoundsCheck())
198 return false;
199 MemRefType bufferType = op.getMemref().getType();
200 if (!bufferType.hasStaticShape())
201 return false;
202 int64_t offset;
203 SmallVector<int64_t> strides;
204 if (failed(bufferType.getStridesAndOffset(strides, offset)))
205 return false;
206 int64_t result = offset + op.getIndexOffset().value_or(0);
207 if (op.getSgprOffset()) {
208 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
209 if (!sgprOffset)
210 return false;
211 result += *sgprOffset;
212 }
213 if (strides.size() != op.getIndices().size())
214 return false;
215 int64_t indexVal = 0;
216 for (auto pair : llvm::zip(strides, op.getIndices())) {
217 int64_t stride = std::get<0>(pair);
218 Value idx = std::get<1>(pair);
219 std::optional<uint32_t> idxVal = getConstantUint32(v: idx);
220 if (!idxVal)
221 return false;
222 indexVal += stride * *idxVal;
223 }
224 result += indexVal;
225 if (result > std::numeric_limits<uint32_t>::max())
226 // Overflow means don't drop
227 return false;
228 return result >= bufferType.getNumElements();
229}
230
231namespace {
232template <typename OpType>
233struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
234 using OpRewritePattern<OpType>::OpRewritePattern;
235
236 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
237 if (!staticallyOutOfBounds(op))
238 return failure();
239 Type loadType = op.getResult().getType();
240 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
241 rw.getZeroAttr(loadType));
242 return success();
243 }
244};
245
246template <typename OpType>
247struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
248 using OpRewritePattern<OpType>::OpRewritePattern;
249
250 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
251 if (!staticallyOutOfBounds(op))
252 return failure();
253
254 rw.eraseOp(op);
255 return success();
256 }
257};
258} // end namespace
259
260void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
261 MLIRContext *context) {
262 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
263}
264
265void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
266 MLIRContext *context) {
267 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
268}
269
270void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
271 RewritePatternSet &results, MLIRContext *context) {
272 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
273}
274
275void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
276 RewritePatternSet &results, MLIRContext *context) {
277 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
278}
279
280void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
281 RewritePatternSet &results, MLIRContext *context) {
282 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
283}
284
285void RawBufferAtomicUminOp::getCanonicalizationPatterns(
286 RewritePatternSet &results, MLIRContext *context) {
287 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
288}
289
290void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
291 RewritePatternSet &results, MLIRContext *context) {
292 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
293 context);
294}
295
296//===----------------------------------------------------------------------===//
297// WMMAOp
298//===----------------------------------------------------------------------===//
299LogicalResult WMMAOp::verify() {
300 Type sourceAType = getSourceA().getType();
301 Type sourceBType = getSourceB().getType();
302 Type destType = getDestC().getType();
303
304 VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
305 VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
306 VectorType destVectorType = dyn_cast<VectorType>(destType);
307
308 Type sourceAElemType = sourceVectorAType.getElementType();
309 Type sourceBElemType = sourceVectorBType.getElementType();
310 Type destElemType = destVectorType.getElementType();
311
312 if (sourceVectorAType.getNumElements() !=
313 sourceVectorBType.getNumElements()) {
314 return emitOpError("source vectors have different lengths: ")
315 << sourceVectorAType << " vs. " << sourceVectorBType;
316 }
317
318 bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
319 bool isSrcFloat =
320 isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
321 sourceAElemType);
322
323 if (isDestFloat && !isSrcFloat) {
324 return emitOpError("Expected float sources with float destination");
325 }
326
327 if (!isDestFloat && isSrcFloat) {
328 return emitOpError("Expected int sources with int destination");
329 }
330
331 if (sourceAElemType != sourceBElemType &&
332 !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
333 isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
334 return emitOpError(
335 "source element types much match (except for fp8) but have ")
336 << sourceAType << " and " << sourceBType;
337 }
338 return success();
339}
340
341//===----------------------------------------------------------------------===//
342// MFMAOp
343//===----------------------------------------------------------------------===//
344LogicalResult MFMAOp::verify() {
345 constexpr uint32_t waveSize = 64;
346 Builder b(getContext());
347
348 Type sourceType = getSourceA().getType();
349 Type destType = getDestC().getType();
350
351 Type sourceElem = sourceType, destElem = destType;
352 uint32_t sourceLen = 1, destLen = 1;
353 if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
354 sourceLen = sourceVector.getNumElements();
355 sourceElem = sourceVector.getElementType();
356 }
357 if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
358 destLen = destVector.getNumElements();
359 destElem = destVector.getElementType();
360 }
361
362 Type sourceBType = getSourceB().getType();
363 if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
364 int64_t sourceBLen = 1;
365 Type sourceBElem = sourceBType;
366 if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
367 sourceBLen = sourceBVector.getNumElements();
368 sourceBElem = sourceBVector.getElementType();
369 }
370 if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
371 !sourceBElem.isFloat(4))
372 return emitOpError("expected both source operands to have small-float "
373 "elements if one does");
374 if (sourceLen != sourceBLen)
375 return emitOpError(
376 "expected both small-float source vectors to have the same length");
377 } else {
378 if (sourceType != sourceBType)
379 return emitOpError("expected both non-small-float source operand types "
380 "to match exactly");
381 }
382 // Normalize the wider integer types the compiler expects to i8
383 if (sourceElem.isInteger(32)) {
384 sourceLen *= 4;
385 sourceElem = b.getI8Type();
386 }
387 if (sourceElem.isInteger(64)) {
388 sourceLen *= 8;
389 sourceElem = b.getI8Type();
390 }
391
392 int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
393 if (sourceLen != numSourceElems)
394 return emitOpError("expected " + Twine(numSourceElems) +
395 " source values for this operation but got " +
396 Twine(sourceLen));
397
398 int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
399 if (destLen != numDestElems)
400 return emitOpError("expected " + Twine(numDestElems) +
401 " result values for this operation but got " +
402 Twine(destLen));
403
404 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
405 return emitOpError(
406 "double-precision ops do not support permuting lanes of B");
407 if (destElem.isF64() && getCbsz() != 0)
408 return emitOpError(
409 "double-precision ops do not support permuting lanes of A");
410 if (getAbid() >= (1u << getCbsz()))
411 return emitOpError(
412 "block ID for permuting A (abid) must be below 2 ** cbsz");
413
414 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
415 return emitOpError(
416 "negation flags only available for double-precision operations");
417
418 return success();
419}
420
421//===----------------------------------------------------------------------===//
422// DPPOp
423//===----------------------------------------------------------------------===//
424LogicalResult DPPOp::verify() {
425 Type srcType = getSrc().getType();
426 if (srcType.getIntOrFloatBitWidth() > 64) {
427 return emitOpError("integer and floating point types larger than 64 bits "
428 "are not supported");
429 }
430
431 DPPPerm kind = getKind();
432 Attribute permArgument = getPermArgument().value_or(Attribute{});
433
434 switch (kind) {
435
436 case DPPPerm::quad_perm: {
437 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
438 if (!quadPermAttr || quadPermAttr.size() != 4) {
439 return emitOpError("quad_perm attribute must have exactly 4 elements");
440 }
441 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
442 int32_t num = elem.getInt();
443 if (num < 0 || num > 3) {
444 return emitOpError(
445 "Each element of quad_perm must be in the range [0, 3]");
446 }
447 }
448 } break;
449
450 case DPPPerm::row_shl:
451 case DPPPerm::row_shr:
452 case DPPPerm::row_ror: {
453 if (!permArgument) {
454 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
455 "' value not specified");
456 }
457 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
458 uint32_t attrValue = intAttr.getInt();
459 if (attrValue < 1 || attrValue > 15) {
460 return emitOpError("Attribute value must be between 1 and 15");
461 }
462 }
463 } break;
464
465 case DPPPerm::wave_shl:
466 case DPPPerm::wave_shr:
467 case DPPPerm::wave_rol:
468 case DPPPerm::wave_ror:
469 case DPPPerm::row_mirror:
470 case DPPPerm::row_half_mirror:
471 case DPPPerm::row_bcast_15:
472 case DPPPerm::row_bcast_31: {
473 if (permArgument && !isa<UnitAttr>(permArgument)) {
474 return emitOpError("Expected unit attribute for permArgument, but found "
475 "non-trivial argument");
476 }
477 break;
478 }
479 }
480 return success();
481}
482
483LogicalResult GatherToLDSOp::verify() {
484 MemRefType srcType = cast<MemRefType>(getSrc().getType());
485 MemRefType dstType = cast<MemRefType>(getDst().getType());
486
487 if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
488 return emitOpError("destination types must be contiguous");
489
490 auto elemType = srcType.getElementType();
491 // Check $src and $dst element types are the same.
492 if (elemType != dstType.getElementType())
493 return emitOpError("source and destination element types must match");
494
495 // copy type sizes should be 1, 2, or 4 bytes.
496 auto transferType = getTransferType();
497 size_t transferSize;
498 if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
499 transferSize = vectorTransfer.getNumElements() *
500 vectorTransfer.getElementTypeBitWidth();
501 } else {
502 transferSize = transferType.getIntOrFloatBitWidth();
503 }
504 if (transferSize != 8 && transferSize != 16 && transferSize != 32)
505 return emitOpError("Transfering type size must be 8, 16, or 32 bits");
506
507 if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
508 !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
509 return emitOpError(
510 "source memory address space must be global or fat raw buffer");
511
512 if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
513 return emitOpError("destination memory address space must be Workgroup");
514
515 return success();
516}
517
518#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
519
520#define GET_ATTRDEF_CLASSES
521#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
522
523#define GET_OP_CLASSES
524#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
525

// NOTE(review): trailing code-browser footer removed; this file was scraped
// from a web view of mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.