//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}
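
// For example (illustrative, not from a test): with shape = [8, 16] and
// trans = [1, 0], the loop computes shape[i] = old[trans[i]], leaving
// shape = [16, 8].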

template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  // Guard against an empty array, for which array.back() is undefined.
  if (!array.empty())
    os << array.back();
  os << "]";
  return buf;
}
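
// For example (illustrative): makeString(SmallVector<int64_t>{2, 4, 8})
// returns the string "[2, 4, 8]"; an empty array yields "[]".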

static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static int64_t getRankOf(Value val) {
  auto type = val.getType();
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    return ty.getRank();
  return 0;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy, UnitAttr transposeAttr,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto chunkSize = tdescTy.getChunkSize();

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  if (tdescShape[0] != maskShape[0])
    return emitError()
           << "dim-0 of the Mask and TensorDesc should be the same.";

  // A valid shape for the SIMT case.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (transposeAttr)
      return emitError() << "doesn't need TransposeAttr for SIMT code";
    return success();
  }

  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
    if (!transposeAttr)
      return emitError() << "rank-2 tensor has to be transposed.";
    transpose({1, 0}, tdescShape);
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
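
// Illustrative shapes only (simplified; not taken from the test suite):
//   SIMD: tdesc = !xegpu.tensor_desc<16x8xf32,
//                   #xegpu.scatter_tdesc_attr<chunk_size = 8>>,
//         mask = vector<16xi1>, value = vector<8x16xf32> with transpose;
//         tdescShape [16, 8] is transposed to [8, 16] and matches the value.
//   SIMT: the same tdesc with value = vector<8xf32> (rank 1, chunkSize
//         elements per lane) is accepted without LayoutAttr or TransposeAttr.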

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        staticOffsets /* const offsets */, {} /* empty const shape */,
        {} /* empty const strides */);
}
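
// How the mixed offsets are split (illustrative, with a hypothetical %v):
// dispatchIndexOpFoldResults([%v, 8], dynamicOffsets, staticOffsets) leaves
// dynamicOffsets = [%v] and staticOffsets = [ShapedType::kDynamic, 8], i.e.
// a Value is recorded as an operand and marked dynamic in the static array.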

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType>(srcTy) || isa<MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If the shape and strides come from the memref type itself, we don't
    // need attributes for them; this keeps the printed IR clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // The memory space of the created TensorDesc should match that of the
  // source. Both the source and the TensorDesc default to global memory
  // if the memory scope attribute is not specified. If the source is an
  // integer, it is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check that the source rank matches if it is a memref. It should also
  // have the same element type as the TensorDesc.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  // Mismatches among shape, strides, and offsets are already handled by
  // OffsetSizeAndStrideOpInterface, so they are not checked here.
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // Check the result TensorDesc rank.
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank is up to 2 and not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
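
// A minimal example of IR this verifier accepts (illustrative; %src and the
// shapes are assumed, not from a specific test):
//   %0 = xegpu.create_nd_tdesc %src[0, 0]
//        : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// Here the offsets' rank (2) matches the memref rank, the element types
// agree, and the result TensorDesc rank (2) does not exceed either.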

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // the tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since the subgroup size is architecture-dependent, we only
    // check for even distribution here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid.
    bool valid = llvm::all_of(
        trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1) {
    tdescShape.insert(tdescShape.begin(), array_len);
  }

  if (tdescShape != valueShape) {
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;
  }

  return success();
}
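
// A worked shape check for the packed (VNNI) path above (illustrative
// numbers): for a tdesc of 16x16xf16 loaded with packed into a result of
// vector<8x16x2xf16>, vnni_factor = valueShape.back() = 2, so tdescShape
// [16, 16] becomes [8, 16, 2], which matches the result shape exactly.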

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp: if the value vector is 1D and has fewer elements
  // than the tensor descriptor, it is supposed to be a SIMT op. The layout
  // attribute in the tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems) {
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
    }
    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape) {
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;
  }

  return success();
}
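
// Distribution arithmetic for the SIMT store path (illustrative numbers):
// a tdesc of 8x16xf32 holds 128 elements; a value of vector<8xf32> passes
// because 128 % 8 == 0, corresponding to a subgroup of 16 lanes each
// storing 8 elements. A value of vector<7xf32> would be rejected.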

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNDOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
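
// What these builders materialize (illustrative): offsets {0, 16, 32, 48}
// are wrapped as index OpFoldResults, turned into index constants by
// getValueOrCreateConstantIndexOp, and then packed into a single
// vector<4xindex> offset operand via vector::FromElementsOp.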

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source to be a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the created TensorDesc should match that of the
  // source. Both the source and the TensorDesc default to global memory
  // if the memory scope attribute is not specified. If the source is an
  // integer, it is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check the total access size.
  auto chunkSize = tdescTy.getChunkSize();
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports a chunk size of
    // 1. For 32-bit data, the hardware can support larger chunk sizes, so
    // we can bitcast 8-bit/16-bit data to 32-bit data for better
    // performance. This requires the total size to be 32-bit aligned for
    // the optimization to work.
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  auto lscConstraints = 512 * 8; // each access is up to 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is up to 512 bytes.");

  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
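
// Worked alignment check (illustrative numbers): for f16 (elemBits = 16)
// with chunk_size = 2, bitsPerLane = 32 and the op verifies; with
// chunk_size = 3, bitsPerLane = 48 and 48 % 32 != 0, so the op is rejected
// even though each individual element is well-formed.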

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The full semantic check is skipped here due to the lack of architecture
  // information; users need to ensure the correctness themselves.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}
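
// SIMD shape bookkeeping above, with illustrative numbers: for an f16 dpas
// with lhs = vector<8x16xf16>, packed rhs = vector<8x16x2xf16>, and
// res = vector<8x16xf32>, bK = rhsShape[0] * rhsShape[2] = 16, matching
// lhsShape[1]; M (8) and N (16) line up with the result dimensions.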

//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::verify() {
  auto srcMap = getSrcMapAttr();
  auto resMap = getResMapAttr();
  if (!srcMap)
    return emitOpError("expected srcMap.");
  if (!resMap)
    return emitOpError("expected resMap.");

  if (srcMap == resMap)
    return emitOpError("expected different srcMap and resMap.");

  // srcMap and resMap must both be WgLayout or both be SgLayout.
  if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
      (!srcMap.isSgLayout() || !resMap.isSgLayout()))
    return emitOpError(
        "expected srcMap and resMap be WgLayout or SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
    return emitOpError("invalid srcMap, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
    return emitOpError("invalid resMap, data cannot be evenly distributed.");

  return mlir::success();
}

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>