//===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::amdgpu;

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
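// Both packed conversion ops may write into lanes of an optional `existing`
// vector, carrying the unwritten lanes through to the result, so when
// `existing` is present its type must match the result type exactly.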
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
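// Shared verifier for the raw buffer ops: buffer descriptors can only address
// global memory, so the memref must be ranked, fully indexed, and placed in an
// address space that lowers to a global pointer. As an illustrative (assumed,
// abbreviated) example, a load such as
//   %v = amdgpu.raw_buffer_load {boundsCheck = true} %src[%i, %j]
//       : memref<64x64xf32>, i32, i32 -> f32
// verifies, while the same load from a memref in workgroup (LDS) memory is
// rejected.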
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  Attribute memorySpace = bufferType.getMemorySpace();
  bool isGlobal = false;
  if (!memorySpace)
    isGlobal = true;
  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  else if (auto gpuMemorySpace =
               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;

  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}

LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}

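/// If `v` is a compile-time constant i32, return its (zero-extended) value;
/// otherwise return std::nullopt.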
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}

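// Determine whether a bounds-checked buffer access provably falls outside the
// memref it addresses. Worked example: a memref<16x16xf32> has 256 elements,
// strides [16, 1], and offset 0, so constant indices [20, 0] give a linear
// index of 20 * 16 + 0 = 320 >= 256 and the access is statically out of
// bounds.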
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(bufferType, strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflowing the 32-bit hardware offset means we cannot statically prove
    // the access is out of bounds, so conservatively keep the op.
    return false;
  return result >= bufferType.getNumElements();
}

namespace {
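// With bounds checking enabled, the hardware returns 0 for an out-of-bounds
// buffer load and drops an out-of-bounds buffer store, so statically
// out-of-bounds accesses can be folded to a zero constant or erased.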
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};

template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    rw.eraseOp(op);
    return success();
  }
};
} // end namespace

void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

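// A compare-and-swap produces a result, so a statically out-of-bounds cmpswap
// folds to a zero constant (the value an out-of-bounds load returns) instead
// of being erased outright.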
void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
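// The verifier only checks that the sources and the destination agree on
// float-ness. As an illustrative (assumed, abbreviated) example, a
// multiply-accumulate such as
//   %d = amdgpu.wmma %a * %b + %c
//       : vector<16xf16>, vector<16xf16>, vector<8xf32>
// verifies because both sources and destination are floating point, while
// pairing i8 sources with an f32 destination would be rejected.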
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type destType = getDestC().getType();

  // The ODS type constraints guarantee vector operands, so checked casts are
  // unnecessary here.
  VectorType sourceVectorAType = cast<VectorType>(sourceAType);
  VectorType destVectorType = cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type destElemType = destVectorType.getElementType();

  bool isDestFloat =
      (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
  bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());

  if (isDestFloat && !isSrcFloat)
    return emitOpError("Expected float sources with float destination");

  if (!isDestFloat && isSrcFloat)
    return emitOpError("Expected int sources with int destination");

  return success();
}

//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
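// Worked shape example: a 32x32x8 f16 MFMA with blocks = 1 on a 64-lane
// wavefront requires (M * K * blocks) / 64 = (32 * 8 * 1) / 64 = 4 source
// values and (M * N * blocks) / 64 = (32 * 32 * 1) / 64 = 16 accumulator
// values per lane, i.e. sourceA is vector<4xf16> and destC is vector<16xf32>.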
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
      return emitOpError("expected both source operands to have f8 elements");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both f8 source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError(
          "expected both non-f8 source operand types to match exactly");
  }
  // Some MFMAs expect their 8-bit integer operands packed into i32 or i64
  // registers; normalize those wider element types to the equivalent number
  // of i8 elements before checking the shape.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"