1//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMDGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/Diagnostics.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/OpImplementation.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/IR/TypeUtilities.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27#include <limits>
28#include <optional>
29
30using namespace mlir;
31using namespace mlir::amdgpu;
32
33#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
34
/// Registers all operations and attributes of the AMDGPU dialect.
/// The op and attribute lists are pulled in from the TableGen-generated
/// .cpp.inc files via the GET_OP_LIST / GET_ATTRDEF_LIST macro hooks.
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
45
46//===----------------------------------------------------------------------===//
47// 8-bit float ops
48//===----------------------------------------------------------------------===//
49LogicalResult PackedTrunc2xFp8Op::verify() {
50 if (getExisting() && getExisting().getType() != getResult().getType())
51 return emitOpError("existing values must have same type as result");
52 return success();
53}
54
55LogicalResult PackedStochRoundFp8Op::verify() {
56 if (getExisting() && getExisting().getType() != getResult().getType())
57 return emitOpError("existing values must have same type as result");
58 return success();
59}
60
61//===----------------------------------------------------------------------===//
62// RawBuffer*Op
63//===----------------------------------------------------------------------===//
64template <typename T>
65static LogicalResult verifyRawBufferOp(T &op) {
66 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
67 Attribute memorySpace = bufferType.getMemorySpace();
68 bool isGlobal = false;
69 if (!memorySpace)
70 isGlobal = true;
71 else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
72 isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
73 else if (auto gpuMemorySpace =
74 llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
75 isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
76
77 if (!isGlobal)
78 return op.emitOpError(
79 "Buffer ops must operate on a memref in global memory");
80 if (!bufferType.hasRank())
81 return op.emitOpError(
82 "Cannot meaningfully buffer_store to an unranked memref");
83 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
84 return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
85 " indices to memref");
86 return success();
87}
88
// All raw buffer ops share the same structural checks; each verifier simply
// forwards to verifyRawBufferOp above.
LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
112
113static std::optional<uint32_t> getConstantUint32(Value v) {
114 APInt cst;
115 if (!v.getType().isInteger(width: 32))
116 return std::nullopt;
117 if (matchPattern(v, m_ConstantInt(&cst)))
118 return cst.getZExtValue();
119 return std::nullopt;
120}
121
122template <typename OpType>
123static bool staticallyOutOfBounds(OpType op) {
124 if (!op.getBoundsCheck())
125 return false;
126 MemRefType bufferType = op.getMemref().getType();
127 if (!bufferType.hasStaticShape())
128 return false;
129 int64_t offset;
130 SmallVector<int64_t> strides;
131 if (failed(getStridesAndOffset(bufferType, strides, offset)))
132 return false;
133 int64_t result = offset + op.getIndexOffset().value_or(0);
134 if (op.getSgprOffset()) {
135 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
136 if (!sgprOffset)
137 return false;
138 result += *sgprOffset;
139 }
140 if (strides.size() != op.getIndices().size())
141 return false;
142 int64_t indexVal = 0;
143 for (auto pair : llvm::zip(strides, op.getIndices())) {
144 int64_t stride = std::get<0>(pair);
145 Value idx = std::get<1>(pair);
146 std::optional<uint32_t> idxVal = getConstantUint32(v: idx);
147 if (!idxVal)
148 return false;
149 indexVal += stride * *idxVal;
150 }
151 result += indexVal;
152 if (result > std::numeric_limits<uint32_t>::max())
153 // Overflow means don't drop
154 return false;
155 return result >= bufferType.getNumElements();
156}
157
namespace {
/// Rewrites a value-producing buffer op that is statically known to be out of
/// bounds into a zero constant of its result type.
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};

/// Erases a write-like buffer op that is statically known to be out of
/// bounds (the op has no results, so it can simply be removed).
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    rw.eraseOp(op);
    return success();
  }
};
} // end namespace
186
// Canonicalization: statically out-of-bounds loads (and cmpswap, which
// produces a value) become zero constants; out-of-bounds writes/atomics with
// no used result are erased.
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
222
223//===----------------------------------------------------------------------===//
224// WMMAOp
225//===----------------------------------------------------------------------===//
226LogicalResult WMMAOp::verify() {
227 Type sourceAType = getSourceA().getType();
228 Type destType = getDestC().getType();
229
230 VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
231 VectorType destVectorType = dyn_cast<VectorType>(destType);
232
233 Type sourceAElemType = sourceVectorAType.getElementType();
234 Type destElemType = destVectorType.getElementType();
235
236 bool isDestFloat =
237 (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
238 bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
239
240 if (isDestFloat && !isSrcFloat) {
241 return emitOpError("Expected float sources with float destination");
242 }
243
244 if (!isDestFloat && isSrcFloat) {
245 return emitOpError("Expected int sources with int destination");
246 }
247
248 return success();
249}
250
251//===----------------------------------------------------------------------===//
252// MFMAOp
253//===----------------------------------------------------------------------===//
/// Verifier for MFMA ops: checks source/destination element types and vector
/// lengths against the m/n/k/blocks attributes, and validates the lane
/// permutation (blgp/cbsz/abid) and negation modifiers.
LogicalResult MFMAOp::verify() {
  // Element-count checks below divide by the 64-lane wave size.
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  // Decompose source A and destination into (element type, length); scalar
  // types are treated as length-1.
  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  // f8 sources may mix the two f8 flavors between A and B, so only element
  // kind and vector length are checked; all other types must match exactly.
  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
      return emitOpError("expected both source operands to have f8 elements");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both f8 source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError(
          "expected both non-f8 source operand types to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  // Each lane holds (m * k * blocks) / waveSize source elements.
  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  // Each lane holds (m * n * blocks) / waveSize result elements.
  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  // Lane-permutation modifiers are rejected for double-precision results.
  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  // abid selects one of 2**cbsz blocks, so it must stay below that bound.
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  // Operand negation is only defined for f64 destinations.
  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
328
329#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
330
331#define GET_ATTRDEF_CLASSES
332#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
333
334#define GET_OP_CLASSES
335#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
336

// source code of mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp