1//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
9#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
10#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11#include "mlir/Dialect/Utils/StaticValueUtils.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "llvm/ADT/TypeSwitch.h"
14#include "llvm/Support/FileSystem.h"
15#include "llvm/Support/MathExtras.h"
16
17using namespace mlir;
18using namespace mlir::xevm;
19
20#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
21#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
22
23namespace {
/// Subgroup size, in elements, used by the 2D block load verifiers in this
/// file (divisor for the expected result size).
constexpr uint32_t subgroupSize{16};
25
26template <typename Op>
27LogicalResult verifyMatrixInput(Op op) {
28 static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
29 BlockPrefetch2dOp>::value,
30 "Unexpected template parameter");
31
32 std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
33 std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
34 if (pitch && width && *pitch < *width)
35 return op->emitOpError(
36 "4th operand (base pitch) should be >= 2nd operand (base width)");
37
38 uint32_t elemSize = op.getElemSizeInBits();
39 if (elemSize < 8 || !llvm::isPowerOf2_32(Value: elemSize) || elemSize > 32)
40 return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
41
42 uint32_t tileHeight = op.getTileHeight();
43 if (tileHeight > 32 || !llvm::isPowerOf2_32(Value: tileHeight))
44 return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
45
46 uint32_t vBlocks = op.getVBlocks();
47 if (vBlocks > 8 || !llvm::isPowerOf2_32(Value: vBlocks))
48 return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
49
50 return success();
51}
52
53LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
54 VectorType resTy = op.getRes().getType();
55 if (!resTy.getElementType().isIntOrFloat())
56 return op.emitOpError()
57 << "expecting result element type to be int or float";
58 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
59 unsigned resSize = resTy.getNumElements() * resElemTySize;
60 unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
61 op.getTileWidth() * op.getVBlocks() / subgroupSize;
62 if (resSize != expectedSize)
63 return op.emitOpError() << "result size of " << resSize
64 << " bits does not match the expected size of "
65 << expectedSize << " bits";
66
67 if (op.getTranspose() && op.getPackRegister())
68 return op.emitOpError(message: "transpose and pack_register are mutually exclusive");
69
70 if (!op.getTranspose() && !op.getPackRegister()) {
71 uint32_t tileHeight = op.getTileHeight();
72 if (tileHeight < 1 || tileHeight > 32)
73 return op.emitOpError(message: "expecting tile_height to be between 1 and 32");
74
75 uint32_t tileWidth = op.getTileWidth();
76 uint32_t vBlocks = op.getVBlocks();
77 switch (op.getElemSizeInBits()) {
78 case 8:
79 if (tileWidth < 4 || tileWidth > 64)
80 return op.emitOpError(message: "expecting tile_width to be between 4 and 64");
81 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
82 return op.emitOpError(message: "expecting v_blocks to be 1, 2, or 4");
83 if (tileWidth * vBlocks > 64)
84 return op.emitOpError(
85 message: "tile_width * v_blocks should be less than or equal "
86 "to 64 for 8 bit elements");
87 break;
88 case 16:
89 if (tileWidth < 2 || tileWidth > 32)
90 return op.emitOpError(message: "expecting tile_width to be between 2 and 32");
91 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
92 return op.emitOpError(message: "expecting v_blocks to be 1, 2, or 4");
93 if (tileWidth * vBlocks > 32)
94 return op.emitOpError(
95 message: "tile_width * v_blocks should be less than or equal "
96 "to 32 for 16 bit elements");
97 break;
98 case 32:
99 if (tileWidth < 1 || tileWidth > 16)
100 return op.emitOpError(message: "expecting tile_width to be between 1 and 16");
101 if (vBlocks != 1 && vBlocks != 2)
102 return op.emitOpError(message: "expecting v_blocks to be 1 or 2");
103 if (tileWidth * vBlocks > 16)
104 return op.emitOpError(
105 message: "tile_width * v_blocks should be less than or equal "
106 "to 16 for 32 bit elements");
107 break;
108 case 64:
109 if (tileWidth < 1 || tileWidth > 8)
110 return op.emitOpError(message: "expecting tile_width to be between 1 and 8");
111 if (vBlocks != 1)
112 return op.emitOpError(message: "expecting v_blocks to be 1");
113 break;
114 default:
115 return op.emitOpError(
116 message: "expecting elem_size_in_bits to be 8, 16, 32, or 64");
117 }
118
119 return success();
120 }
121
122 if (op.getTranspose()) {
123 assert(!op.getPackRegister() && "Expecting pack_register should be false");
124
125 uint32_t vBlocks = op.getVBlocks();
126 if (vBlocks != 1)
127 return op.emitOpError(message: "expecting v_blocks to be 1");
128
129 uint32_t tileHeight = op.getTileHeight();
130 uint32_t tileWidth = op.getTileWidth();
131 switch (op.getElemSizeInBits()) {
132 case 32:
133 if (tileHeight < 1 || tileHeight > 32)
134 return op.emitOpError(message: "expecting tile_height to be between 1 and 32");
135 if (tileWidth < 1 || tileWidth > 8)
136 return op.emitOpError(message: "expecting tile_width to be between 1 and 8");
137 break;
138 case 64:
139 if (tileHeight != 8)
140 return op.emitOpError(
141 message: "expecting tile_height to be 8 for 64 bit elements");
142 if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
143 return op.emitOpError(message: "expecting tile_width to be 1, 2, or 4");
144 break;
145 default:
146 return op.emitOpError(message: "transpose is only supported for 32 and 64 bit "
147 "elements");
148 }
149
150 return success();
151 }
152
153 assert(op.getPackRegister() && !op.getTranspose() &&
154 "Expecting pack_register should be true and transpose should be "
155 "false");
156
157 uint32_t vBlocks = op.getVBlocks();
158 if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
159 return op.emitOpError(message: "expecting v_blocks to be 1, 2, or 4");
160
161 uint32_t tileHeight = op.getTileHeight();
162 uint32_t tileWidth = op.getTileWidth();
163 switch (op.getElemSizeInBits()) {
164 case 8:
165 if (tileHeight < 4 || tileHeight > 32)
166 return op.emitOpError(message: "expecting tile_height to be between 4 and 32");
167 if (tileWidth < 4 || tileWidth > 16)
168 return op.emitOpError(message: "expecting tile_width to be between 4 and 16");
169 break;
170 case 16:
171 if (tileHeight < 2 || tileHeight > 32)
172 return op.emitOpError(message: "expecting tile_height to be between 2 and 32");
173 if (tileWidth < 2 || tileWidth > 16)
174 return op.emitOpError(message: "expecting tile_width to be between 2 and 16");
175 if (tileWidth * vBlocks > 32)
176 return op.emitOpError(
177 message: "tile_width * v_blocks should be less than or equal "
178 "to 32 for 16 bit elements");
179 break;
180 default:
181 return op.emitOpError(message: "pack_register is only supported for 8 and 16 bit "
182 "elements");
183 }
184
185 return success();
186}
187
188static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
189 uint32_t tileHeight = op.getTileHeight();
190 if (tileHeight < 1 || tileHeight > 8)
191 return op.emitOpError(message: "expecting tile_height to be between 1 and 8");
192
193 uint32_t tileWidth = op.getTileWidth();
194 switch (op.getElemSizeInBits()) {
195 case 8:
196 if (tileWidth < 4 || tileWidth > 64)
197 return op.emitOpError(message: "expecting tile_width to be between 4 and 64");
198 break;
199 case 16:
200 if (tileWidth < 2 || tileWidth > 32)
201 return op.emitOpError(message: "expecting tile_width to be between 2 and 32");
202 break;
203 case 32:
204 if (tileWidth < 1 || tileWidth > 16)
205 return op.emitOpError(message: "expecting tile_width to be between 1 and 16");
206 break;
207 case 64:
208 if (tileWidth < 1 || tileWidth > 8)
209 return op.emitOpError(message: "expecting tile_width to be between 1 and 8");
210 break;
211 default:
212 return op.emitOpError(message: "expecting elem_size_in_bits to be 8, 16, 32, or 64");
213 }
214
215 uint32_t vBlocks = op.getVBlocks();
216 if (vBlocks != 1)
217 return op.emitOpError(message: "expecting v_blocks to be 1");
218 return success();
219}
220
221} // namespace
222
223LogicalResult BlockLoad2dOp::verify() {
224 if (verify2DBlockLoadRestriction(op: *this).failed())
225 return failure();
226
227 if (verifyMatrixInput(op: *this).failed())
228 return failure();
229
230 VectorType resTy = getRes().getType();
231 if (!resTy.getElementType().isIntOrFloat())
232 return emitOpError() << "expecting result element type to be int of float";
233 unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
234 if (getElemSizeInBits() == 32 || getPackRegister()) {
235 if (resElemTySize != 32)
236 return emitOpError() << "expecting result element type to be 32 bits";
237 }
238
239 uint32_t tileWidth = getTileWidth();
240 if (getPackRegister()) {
241 if (tileWidth != 16)
242 return emitOpError(
243 message: "tile_width when pack_register is true should be equal "
244 "to subgroup size (16 elements)");
245 return success();
246 }
247
248 return success();
249}
250
251LogicalResult BlockStore2dOp::verify() {
252 if (verify2DBlockStoreRestriction(op: *this).failed())
253 return failure();
254
255 if (verifyMatrixInput(op: *this).failed())
256 return failure();
257
258 uint32_t tileWidth = getTileWidth();
259 switch (getElemSizeInBits()) {
260 case 8:
261 if (tileWidth != 16 && tileWidth != 32)
262 return emitOpError(message: "tile_width for 8 bit elements should be equal to "
263 "16 or 32");
264 break;
265 case 16:
266 if (tileWidth != 16)
267 return emitOpError(message: "tile_width for 16 bit elements should be equal "
268 "to 16");
269 break;
270 case 32:
271 if (tileWidth != 16)
272 return emitOpError(message: "tile_width for 32 bit elements should be equal "
273 "to 16");
274 break;
275 default:
276 llvm_unreachable("unexpected element size");
277 }
278
279 return success();
280}
281
282LogicalResult BlockPrefetch2dOp::verify() {
283 if (verifyMatrixInput(op: *this).failed())
284 return failure();
285
286 uint32_t tileWidth = getTileWidth();
287 switch (getElemSizeInBits()) {
288 case 8:
289 if (tileWidth != 16 && tileWidth != 32)
290 return emitOpError(message: "tile_width for 8 bit elements should be equal to "
291 "16 or 32");
292 break;
293 case 16:
294 if (tileWidth != 16)
295 return emitOpError(message: "tile_width for 16 bit elements should be equal "
296 "to 16");
297 break;
298 case 32:
299 if (tileWidth != 8 && tileWidth != 16)
300 return emitOpError(
301 message: "tile_width for 32 bit elements should be equal to 8 or 16");
302 break;
303 default:
304 llvm_unreachable("unexpected element size");
305 }
306
307 return success();
308}
309
310LogicalResult MMAOp::verify() {
311 if (getC()) {
312 if (getResult().getType() != getC().getType())
313 return emitOpError(message: "type of C operand must match result type");
314 }
315 return success();
316}
317
318LogicalResult
319XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
320 StringRef triple, StringRef chip, DictionaryAttr flags,
321 ArrayAttr linkFiles) {
322 if (O < 0 || O > 3) {
323 return emitError()
324 << "The optimization level must be a number between 0 and 3.";
325 }
326 if (triple.empty()) {
327 return emitError() << "The target triple cannot be empty.";
328 }
329 if (chip.empty()) {
330 return emitError() << "The target chip cannot be empty.";
331 }
332 if (linkFiles) {
333 for (Attribute fileAttr : linkFiles) {
334 if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(Val&: fileAttr)) {
335 StringRef filePath = fileStrAttr.getValue();
336 if (filePath.empty()) {
337 return emitError() << "File paths in linkFiles cannot be empty.";
338 }
339 if (!llvm::sys::fs::exists(Path: filePath)) {
340 return emitError() << "File '" << filePath << "' does not exist.";
341 }
342 }
343 }
344 }
345 return success();
346}
347
348void XeVMDialect::initialize() {
349 addOperations<
350#define GET_OP_LIST
351#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
352 >();
353
354 addAttributes<
355#define GET_ATTRDEF_LIST
356#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
357 >();
358 declarePromisedInterface<mlir::gpu::TargetAttrInterface,
359 mlir::xevm::XeVMTargetAttr>();
360}
361
362#define GET_OP_CLASSES
363#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
364
365#define GET_ATTRDEF_CLASSES
366#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
367

// Source: mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp