1//===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the GPU kernel-related operations and puts them in the
10// corresponding dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_DIALECT_GPU_IR_GPUDIALECT_H
15#define MLIR_DIALECT_GPU_IR_GPUDIALECT_H
16
17#include "mlir/Bytecode/BytecodeOpInterface.h"
18#include "mlir/Dialect/DLTI/Traits.h"
19#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/OpDefinition.h"
24#include "mlir/IR/OpImplementation.h"
25#include "mlir/IR/SymbolTable.h"
26#include "mlir/Interfaces/ControlFlowInterfaces.h"
27#include "mlir/Interfaces/FunctionInterfaces.h"
28#include "mlir/Interfaces/InferIntRangeInterface.h"
29#include "mlir/Interfaces/InferTypeOpInterface.h"
30#include "mlir/Interfaces/SideEffectInterfaces.h"
31#include "llvm/ADT/STLExtras.h"
32
33namespace mlir {
34namespace gpu {
35
36/// Utility class for the GPU dialect to represent triples of `Value`s
37/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
38struct KernelDim3 {
39 Value x;
40 Value y;
41 Value z;
42};
43
44class AsyncTokenType
45 : public Type::TypeBase<AsyncTokenType, Type, TypeStorage> {
46public:
47 // Used for generic hooks in TypeBase.
48 using Base::Base;
49
50 static constexpr StringLiteral name = "gpu.async_token";
51};
52
53/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
54/// and type.
55struct MMAMatrixStorageType : public TypeStorage {
56 MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
57 Type elementType, StringRef operand)
58 : dimShapes(dimShapes), numDims(numDims), elementType(elementType),
59 operand(operand) {}
60
61 /// The hash key for uniquing.
62 using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
63 bool operator==(const KeyTy &key) const {
64 return key == KeyTy(getShape(), elementType, operand);
65 }
66
67 /// Construction.
68 static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
69 const KeyTy &key) {
70 ArrayRef<int64_t> shape = allocator.copyInto(elements: std::get<0>(t: key));
71 StringRef operand = allocator.copyInto(str: std::get<2>(t: key));
72
73 return new (allocator.allocate<MMAMatrixStorageType>())
74 MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(t: key),
75 operand);
76 }
77
78 ArrayRef<int64_t> getShape() const {
79 return ArrayRef<int64_t>(dimShapes, numDims);
80 }
81
82 StringRef getOperand() const { return operand; }
83
84 /// Reference to the shape of the MMA matrix.
85 const int64_t *dimShapes;
86
87 /// Number of dimensions in the MMA matrix.
88 unsigned numDims;
89
90 /// Element type of elements held in the MMA matrix.
91 Type elementType;
92
93 /// MMA operand that this MMAMatrix holds. The general form of operation this
94 /// type supports is given by the equation C += A*B. This field specifies
95 /// which operand in the given equation is held by this type. The valid values
96 /// are "AOp", "BOp" and "COp".
97 StringRef operand;
98};
99
100/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
101/// accumulate operations. MMAMatrices are taken as direct operands by these
102/// operations and are also produced as results. These matrices are meant to
103/// reside in the registers. A limited number of pointwise operations can be
104/// performed on these matrices, i.e., operations which operate uniformly on
105/// all the elements in the matrix and do not change the order of matrix
106/// elements. The above conditions exist because the layout of matrix elements
107/// inside the matrix is opaque i.e., the elements may be present in the
108/// matrix in any order. The general usage of this type is shown as follows:-
109///
110/// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
111/// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
112///
113/// The MMAMatrixType describes the shape of the matrix being loaded and the
114/// operand being loaded too. The operand needs to be specified to aid the
115/// lowering of this type to dialects such as NVVM where each workitem may
116/// hold different amount of elements depending on the elementType of the
117/// matrix. For e.g., Each workitem holds 4 vector<2xf16>s for f16 data type
118/// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage
119/// are:-
120///
121/// %3 = gpu.subgroup_mma_compute %0, %1, %2 :
122/// !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
123/// -> !gpu.mma_matrix<16x16xf32, "COp">
124///
125///
126/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
127/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
128// TODO: consider moving this to ODS.
129class MMAMatrixType
130 : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
131public:
132 using Base::Base;
133
134 static constexpr StringLiteral name = "gpu.mma_matrix";
135
136 /// Get MMAMatrixType and verify construction Invariants.
137 static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
138 StringRef operand);
139
140 /// Get MMAMatrixType at a particular location and verify construction
141 /// Invariants.
142 static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
143 ArrayRef<int64_t> shape, Type elementType,
144 StringRef operand);
145
146 /// Check if a type is valid a MMAMatrixType elementType.
147 static bool isValidElementType(Type elementType);
148
149 /// Verify that shape and elementType are actually allowed for the
150 /// MMAMatrixType.
151 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
152 ArrayRef<int64_t> shape, Type elementType,
153 StringRef operand);
154
155 /// Get number of dims.
156 unsigned getNumDims() const;
157
158 /// Get shape of the matrix.
159 ArrayRef<int64_t> getShape() const;
160
161 /// Get elementType of a single element.
162 Type getElementType() const;
163
164 /// The general form of operation this type supports is given by the equation
165 /// C += A*B. This function returns which operand in the given equation is
166 /// held by this type. String returned can be one of"AOp", "BOp" and "COp".
167 StringRef getOperand() const;
168};
169
170// Adds a `gpu.async.token` to the front of the argument list.
171void addAsyncDependency(Operation *op, Value token);
172
/// Kinds of handle types used by the GPU sparse operations; there is one kind
/// per `Sparse*HandleType` declared in this header.
enum class SparseHandleKind {
  SpMat = 0,    ///< Sparse matrix handle.
  DnTensor = 1, ///< Dense tensor handle.
  SpGEMMOp = 2, ///< SpGEMM operation handle.
};
175
176class SparseDnTensorHandleType
177 : public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> {
178public:
179 using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type,
180 TypeStorage>::Base;
181 using Base::Base;
182
183 static constexpr StringLiteral name = "gpu.sparse.dntensor_handle";
184};
185
186class SparseSpMatHandleType
187 : public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> {
188public:
189 using Base =
190 typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base;
191 using Base::Base;
192
193 static constexpr StringLiteral name = "gpu.sparse.spmat_handle";
194};
195
196class SparseSpGEMMOpHandleType
197 : public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> {
198public:
199 using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type,
200 TypeStorage>::Base;
201 using Base::Base;
202
203 static constexpr StringLiteral name = "gpu.sparse.spgemmop_handle";
204};
205
206} // namespace gpu
207} // namespace mlir
208
209#include "mlir/Dialect/GPU/IR/GPUOpsEnums.h.inc"
210
211#include "mlir/Dialect/GPU/IR/GPUOpsDialect.h.inc"
212
213#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc"
214
215#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
216
217#define GET_ATTRDEF_CLASSES
218#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc"
219
220#define GET_OP_CLASSES
221#include "mlir/Dialect/GPU/IR/GPUOps.h.inc"
222
223#endif // MLIR_DIALECT_GPU_IR_GPUDIALECT_H
224

source code of mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h