1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/Utils/Utils.h"
10#include "mlir/Dialect/Utils/IndexingUtils.h"
11#include "mlir/Dialect/Utils/StaticValueUtils.h"
12#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/TypeUtilities.h"
15
16#include "llvm/Support/Debug.h"
17
18#define DEBUG_TYPE "xegpu"
19
20namespace mlir {
21namespace xegpu {
22
23template <typename T>
24static std::string makeString(T array, bool breakline = false) {
25 std::string buf;
26 buf.clear();
27 llvm::raw_string_ostream os(buf);
28 os << "[";
29 for (size_t i = 1; i < array.size(); i++) {
30 os << array[i - 1] << ", ";
31 if (breakline)
32 os << "\n\t\t";
33 }
34 os << array.back() << "]";
35 return buf;
36}
37
38static SmallVector<int64_t> getShapeOf(Type type) {
39 SmallVector<int64_t> shape;
40 if (auto ty = llvm::dyn_cast<ShapedType>(Val&: type))
41 shape = SmallVector<int64_t>(ty.getShape());
42 else
43 shape.push_back(Elt: 1);
44 return shape;
45}
46
47static int64_t getRankOf(Value val) {
48 auto type = val.getType();
49 if (auto ty = llvm::dyn_cast<ShapedType>(Val&: type))
50 return ty.getRank();
51 return 0;
52}
53
54static bool isReadHintOrNone(const CachePolicyAttr &attr) {
55 if (!attr)
56 return true;
57 auto kind = attr.getValue();
58 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
59 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
60}
61
62static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
63 if (!attr)
64 return true;
65 auto kind = attr.getValue();
66 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
67 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
68}
69
// Shared verifier for gather/scatter-style ops (load_gather/store_scatter).
// Checks, in order:
//  - the TensorDesc is scattered and a vector value is present,
//  - element types of value and TensorDesc agree,
//  - the mask shape equals the TensorDesc shape minus the chunk dimension,
//  - the value is either a valid SIMT distribution (1D vector of chunkSize
//    elements, no layout attribute) or exactly the TensorDesc shape (SIMD).
// `emitError` supplies the op-specific diagnostic factory.
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto chunkSize = tdescTy.getChunkSizeAsInt();

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  // The mask has one bit per scatter index; the trailing chunk dimension (if
  // chunkSize > 1) is not masked, so drop it from the expected mask shape.
  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // a valid shape for SIMT case
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
111
112//===----------------------------------------------------------------------===//
113// XeGPU_CreateNdDescOp
114//===----------------------------------------------------------------------===//
115void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
116 Type tdesc, TypedValue<MemRefType> source,
117 llvm::ArrayRef<OpFoldResult> offsets) {
118 [[maybe_unused]] auto ty = source.getType();
119 assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
120
121 llvm::SmallVector<int64_t> staticOffsets;
122 llvm::SmallVector<Value> dynamicOffsets;
123 dispatchIndexOpFoldResults(ofrs: offsets, dynamicVec&: dynamicOffsets, staticVec&: staticOffsets);
124
125 build(odsBuilder&: builder, odsState&: state, TensorDesc: tdesc, source, offsets: dynamicOffsets /* dynamic offsets */,
126 shape: ValueRange({}) /* empty dynamic shape */,
127 strides: ValueRange({}) /* empty dynamic strides */,
128 const_offsets: staticOffsets /* const offsets */, const_shape: {} /* empty const shape*/,
129 const_strides: {} /* empty const strides*/);
130}
131
132void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
133 Type tdesc, Value source,
134 llvm::ArrayRef<OpFoldResult> offsets,
135 llvm::ArrayRef<OpFoldResult> shape,
136 llvm::ArrayRef<OpFoldResult> strides) {
137 assert(shape.size() && offsets.size() && strides.size() &&
138 shape.size() == strides.size() && shape.size() == offsets.size());
139
140 Type srcTy = source.getType();
141 assert((isa<IntegerType, MemRefType>(srcTy)) &&
142 "Source has to be either int or memref.");
143
144 llvm::SmallVector<Value> dynamicOffsets;
145 llvm::SmallVector<Value> dynamicShape;
146 llvm::SmallVector<Value> dynamicStrides;
147
148 llvm::SmallVector<int64_t> staticOffsets;
149 llvm::SmallVector<int64_t> staticShape;
150 llvm::SmallVector<int64_t> staticStrides;
151
152 dispatchIndexOpFoldResults(ofrs: offsets, dynamicVec&: dynamicOffsets, staticVec&: staticOffsets);
153 dispatchIndexOpFoldResults(ofrs: shape, dynamicVec&: dynamicShape, staticVec&: staticShape);
154 dispatchIndexOpFoldResults(ofrs: strides, dynamicVec&: dynamicStrides, staticVec&: staticStrides);
155
156 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(values: staticOffsets);
157 auto staticShapeAttr = builder.getDenseI64ArrayAttr(values: staticShape);
158 auto staticStridesAttr = builder.getDenseI64ArrayAttr(values: staticStrides);
159
160 if (auto memrefTy = dyn_cast<MemRefType>(Val&: srcTy)) {
161 auto memrefShape = memrefTy.getShape();
162 auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
163
164 // if shape and strides are from Memref, we don't need attributes for them
165 // to keep the IR print clean.
166 if (staticShape == memrefShape && staticStrides == memrefStrides) {
167 staticShapeAttr = DenseI64ArrayAttr();
168 staticStridesAttr = DenseI64ArrayAttr();
169 }
170 }
171
172 build(odsBuilder&: builder, odsState&: state, TensorDesc: tdesc, source, offsets: dynamicOffsets, shape: dynamicShape,
173 strides: dynamicStrides, const_offsets: staticOffsetsAttr, const_shape: staticShapeAttr, const_strides: staticStridesAttr);
174}
175
176LogicalResult CreateNdDescOp::verify() {
177 auto rank = (int64_t)getMixedOffsets().size();
178 bool invalidRank = false;
179 bool invalidElemTy = false;
180
181 // Memory space of created TensorDesc should match with the source.
182 // Both source and TensorDesc are considered for global memory by default,
183 // if the memory scope attr is not specified. If source is an integer,
184 // it is considered as ptr to global memory.
185 auto srcMemorySpace = getSourceMemorySpace();
186 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
187 if (srcMemorySpace != tdescMemorySpace)
188 return emitOpError(message: "Memory space mismatch.")
189 << " Source: " << srcMemorySpace
190 << ", TensorDesc: " << tdescMemorySpace;
191
192 // check source type matches the rank if it is a memref.
193 // It also should have the same ElementType as TensorDesc.
194 auto memrefTy = dyn_cast<MemRefType>(Val: getSourceType());
195 if (memrefTy) {
196 invalidRank |= (memrefTy.getRank() != rank);
197 invalidElemTy |= memrefTy.getElementType() != getElementType();
198 }
199
200 // mismatches among shape, strides, and offsets are
201 // already handeled by OffsetSizeAndStrideOpInterface.
202 // So they are not check here.
203 if (invalidRank)
204 return emitOpError(
205 message: "Expecting the rank of shape, strides, offsets, and source (if source "
206 "is a memref) should match with each other.");
207
208 // check result TensorDesc rank
209 if (getType().getRank() > rank)
210 return emitOpError(
211 message: "Expecting the TensorDesc rank is not greater than the "
212 "ranks of shape, strides, offsets or the memref source.");
213
214 if (invalidElemTy)
215 return emitOpError(message: "TensorDesc should have the same element "
216 "type with the source if it is a memref.\n");
217
218 if (getType().isScattered())
219 return emitOpError(message: "Expects a non-scattered TensorDesc.\n");
220
221 return success();
222}
223
224//===----------------------------------------------------------------------===//
225// XeGPU_PrefetchNdOp
226//===----------------------------------------------------------------------===//
227LogicalResult PrefetchNdOp::verify() {
228 auto tdescTy = getTensorDescType();
229 if (tdescTy.isScattered())
230 return emitOpError(message: "Expects a non-scattered TensorDesc.\n");
231
232 if (!isReadHintOrNone(attr: getL1HintAttr()))
233 return emitOpError(message: "invalid l1_hint: ") << getL1HintAttr();
234
235 if (!isReadHintOrNone(attr: getL2HintAttr()))
236 return emitOpError(message: "invalid l2_hint: ") << getL2HintAttr();
237
238 if (!isReadHintOrNone(attr: getL3HintAttr()))
239 return emitOpError(message: "invalid l3_hint: ") << getL3HintAttr();
240
241 return success();
242}
243
244//===----------------------------------------------------------------------===//
245// XeGPU_LoadNdOp
246//===----------------------------------------------------------------------===//
// Verifier for load_nd: validates the tensor descriptor (non-scattered,
// rank <= 2), read-style cache hints, and that the result vector shape is
// either a valid SIMT distribution or matches the SIMD shape implied by the
// descriptor after applying transpose/packed/array_length adjustments.
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  // Loads only accept read-style (or absent) cache hints.
  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has less elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since subgroup size is arch dependent, we only check even
    // distribution here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode: start from the descriptor shape and fold in the
  // optional transpose/packed/array_length adjustments, then compare with
  // the result shape.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Make sure the transpose value is valid, and apply it. An out-of-range
    // permutation is only warned about, not rejected.
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      // Packed (VNNI) layout: dim 0 shrinks by the vnni factor, which is
      // appended as a new innermost dimension.
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  // array_length > 1 prepends an extra leading dimension to the result.
  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  return success();
}
330
331//===----------------------------------------------------------------------===//
332// XeGPU_StoreNdOp
333//===----------------------------------------------------------------------===//
// Verifier for store_nd: validates the destination descriptor (non-scattered,
// rank <= 2, no array_length), write-style cache hints, and that the stored
// value is either a valid SIMT distribution or matches the descriptor shape
// exactly (SIMD).
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  // Stores only accept write-style (or absent) cache hints.
  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp, if the value vector is 1D and has less elements than
  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
  // in tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // Only even distribution across lanes can be checked here; the actual
    // subgroup size is architecture dependent.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;

    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  return success();
}
390
391//===----------------------------------------------------------------------===//
392// XeGPU_UpdateNDOffsetOp
393//===----------------------------------------------------------------------===//
394LogicalResult UpdateNdOffsetOp::verify() {
395 auto ty = getTensorDescType();
396 if (ty.isScattered())
397 return emitOpError(message: "Expects a non-scattered TensorDesc.\n");
398
399 // number of offsets specified must match the rank of the tensor descriptor
400 if (ty.getRank() != (int64_t)getNumOffsets()) {
401 return emitOpError(message: "Invalid number of offsets.");
402 }
403 return success();
404}
405
406//===----------------------------------------------------------------------===//
407// XeGPU_CreateDescOp
408//===----------------------------------------------------------------------===//
409
410void CreateDescOp::build(OpBuilder &builder, OperationState &state,
411 TensorDescType TensorDesc, Value source,
412 llvm::ArrayRef<OpFoldResult> offsets) {
413 auto loc = source.getLoc();
414 int64_t size = static_cast<int64_t>(offsets.size());
415 auto type = VectorType::get(shape: size, elementType: builder.getIndexType());
416 auto values = getValueOrCreateConstantIndexOp(b&: builder, loc, valueOrAttrVec: offsets);
417 auto offset = builder.create<vector::FromElementsOp>(location: loc, args&: type, args&: values);
418 build(odsBuilder&: builder, odsState&: state, TensorDesc, source, offsets: offset);
419}
420
421void CreateDescOp::build(OpBuilder &builder, OperationState &state,
422 TensorDescType TensorDesc, Value source,
423 llvm::ArrayRef<int64_t> offsets) {
424 auto ofrs = getAsIndexOpFoldResult(ctx: builder.getContext(), values: offsets);
425 build(builder, state, TensorDesc, source, offsets: ofrs);
426}
427
428LogicalResult CreateDescOp::verify() {
429 auto tdescTy = getTensorDescType();
430
431 if (getRankOf(val: getSource()) > 1)
432 return emitOpError(
433 message: "Expecting the source is a 1D memref or pointer (uint64_t).");
434
435 if (!tdescTy.isScattered())
436 return emitOpError(message: "Expects a scattered TensorDesc.\n");
437
438 // Memory space of created TensorDesc should match with the source.
439 // Both source and TensorDesc are considered for global memory by default,
440 // if the memory scope attr is not specified. If source is an integer,
441 // it is considered as ptr to global memory.
442 auto srcMemorySpace = getSourceMemorySpace();
443 auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
444 if (srcMemorySpace != tdescMemorySpace)
445 return emitOpError(message: "Memory space mismatch.")
446 << " Source: " << srcMemorySpace
447 << ", TensorDesc: " << tdescMemorySpace;
448
449 // check total size
450 auto chunkSize = tdescTy.getChunkSizeAsInt();
451 SmallVector<int64_t> shape(getOffsetsType().getShape());
452 if (chunkSize != 1)
453 shape.push_back(Elt: chunkSize);
454
455 auto tdescShape = getShapeOf(type: tdescTy);
456 if (shape != tdescShape)
457 return emitOpError(message: "Incorrect TensorDesc shape. ")
458 << "Expected is " << makeString(array: shape) << "\n";
459
460 return success();
461}
462
463//===----------------------------------------------------------------------===//
464// XeGPU_PrefetchOp
465//===----------------------------------------------------------------------===//
466LogicalResult PrefetchOp::verify() {
467 auto tdescTy = getTensorDescType();
468 if (!tdescTy.isScattered())
469 return emitOpError(message: "Expects a scattered TensorDesc.\n");
470
471 if (!isReadHintOrNone(attr: getL1HintAttr()))
472 return emitOpError(message: "invalid l1_hint: ") << getL1HintAttr();
473
474 if (!isReadHintOrNone(attr: getL2HintAttr()))
475 return emitOpError(message: "invalid l2_hint: ") << getL2HintAttr();
476
477 if (!isReadHintOrNone(attr: getL3HintAttr()))
478 return emitOpError(message: "invalid l3_hint: ") << getL3HintAttr();
479
480 return success();
481}
482
483//===----------------------------------------------------------------------===//
484// XeGPU_LoadGatherOp
485//===----------------------------------------------------------------------===//
486LogicalResult LoadGatherOp::verify() {
487 auto tdescTy = getTensorDescType();
488 auto maskTy = getMaskType();
489 auto valueTy = getValueType();
490
491 if (!isReadHintOrNone(attr: getL1HintAttr()))
492 return emitOpError(message: "invalid l1_hint: ") << getL1HintAttr();
493
494 if (!isReadHintOrNone(attr: getL2HintAttr()))
495 return emitOpError(message: "invalid l2_hint: ") << getL2HintAttr();
496
497 if (!isReadHintOrNone(attr: getL3HintAttr()))
498 return emitOpError(message: "invalid l3_hint: ") << getL3HintAttr();
499
500 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
501 emitError: [&]() { return emitOpError(); });
502}
503
504//===----------------------------------------------------------------------===//
505// XeGPU_StoreScatterOp
506//===----------------------------------------------------------------------===//
507LogicalResult StoreScatterOp::verify() {
508 auto tdescTy = getTensorDescType();
509 auto maskTy = getMaskType();
510 auto valueTy = getValueType();
511
512 if (!isWriteHintOrNone(attr: getL1HintAttr()))
513 return emitOpError(message: "invalid l1_hint: ") << getL1HintAttr();
514
515 if (!isWriteHintOrNone(attr: getL2HintAttr()))
516 return emitOpError(message: "invalid l2_hint: ") << getL2HintAttr();
517
518 if (!isWriteHintOrNone(attr: getL3HintAttr()))
519 return emitOpError(message: "invalid l3_hint: ") << getL3HintAttr();
520
521 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
522 emitError: [&]() { return emitOpError(); });
523}
524
525//===----------------------------------------------------------------------===//
526// XeGPU_UpdateOffsetOp
527//===----------------------------------------------------------------------===//
528void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
529 mlir::Value tensorDesc,
530 llvm::ArrayRef<OpFoldResult> offsets) {
531 auto tdescTy = mlir::dyn_cast<TensorDescType>(Val: tensorDesc.getType());
532 assert(tdescTy && "Expecting the source is a TensorDescType value.");
533 auto loc = tensorDesc.getLoc();
534 int64_t size = static_cast<int64_t>(offsets.size());
535 auto type = VectorType::get(shape: {size}, elementType: builder.getIndexType());
536 auto values = getValueOrCreateConstantIndexOp(b&: builder, loc, valueOrAttrVec: offsets);
537 auto offset = builder.create<vector::FromElementsOp>(location: loc, args&: type, args&: values);
538 build(odsBuilder&: builder, odsState&: state, result: tdescTy, TensorDesc: tensorDesc, offsets: offset);
539}
540
541void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
542 Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
543 auto ofrs = getAsIndexOpFoldResult(ctx: builder.getContext(), values: offsets);
544 build(builder, state, tensorDesc, offsets: ofrs);
545}
546
547LogicalResult UpdateOffsetOp::verify() {
548 auto tdescTy = getTensorDescType();
549 if (!tdescTy.isScattered())
550 return emitOpError(message: "Expects a scattered TensorDesc.\n");
551
552 SmallVector<int64_t> expectedOffsetShape = getShapeOf(type: tdescTy);
553 SmallVector<int64_t> offsetShape = getShapeOf(type: getOffsetsType());
554 if (tdescTy.getChunkSizeAsInt() > 1)
555 expectedOffsetShape.pop_back();
556
557 if (expectedOffsetShape != offsetShape)
558 return emitOpError(
559 message: "Offsets should match TensorDesc except the chunk size dim.");
560
561 return success();
562}
563
564//===----------------------------------------------------------------------===//
565// XeGPU_DpasOp
566//===----------------------------------------------------------------------===//
567LogicalResult DpasOp::verify() {
568 int64_t lhsRank = getLhsType().getRank();
569 int64_t rhsRank = getRhsType().getRank();
570 int64_t resRank = getResultType().getRank();
571 auto lhsShape = getLhsType().getShape();
572 auto rhsShape = getRhsType().getShape();
573 auto resShape = getResultType().getShape();
574
575 if (getAcc() && getAcc().getType() != getResultType())
576 return emitOpError(message: "Expecting the acc type to be the same as result.");
577
578 // SIMT code: the size of the B operand has to be a multiple of 32 bits.
579 // It skips the semantic check since lack of architecture information.
580 // Users need to ensure the correctness.
581 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
582 auto numElems = getRhsType().getNumElements();
583 auto elemTy = getRhsType().getElementType();
584 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
585 if (numElems % factor != 0)
586 return emitOpError(message: "Expecting B operand to be a multiple of 32 bits.");
587 return success();
588 }
589
590 // SIMD code
591 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
592 return emitOpError(
593 message: "expecting lhs and result to be a 2D vector, and rhs to be either "
594 "2D or 3D (packed) vector.");
595 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
596 if (bK != lhsShape[1])
597 return emitOpError(message: "K-dimension mismatch.");
598 if (lhsShape[0] != resShape[0])
599 return emitOpError(message: "M-dimension mismatch.");
600 if (rhsShape[1] != resShape[1])
601 return emitOpError(message: "N-dimension mismatch.");
602
603 return success();
604}
605
606//===----------------------------------------------------------------------===//
607// XeGPU_ConvertLayoutOp
608//===----------------------------------------------------------------------===//
609LogicalResult ConvertLayoutOp::verify() {
610 auto srcMap = getSrcMapAttr();
611 auto resMap = getResMapAttr();
612 if (!srcMap)
613 return emitOpError(message: "expected srcMap.");
614 if (!resMap)
615 return emitOpError(message: "expected resMap.");
616
617 if (srcMap == resMap)
618 return emitOpError(message: "expected different srcMap and resMap.");
619
620 // both srcMap and resMap should be WgLayout or SgLayout at the same time.
621 if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
622 (!srcMap.isSgLayout() || !resMap.isSgLayout()))
623 return emitOpError(
624 message: "expected srcMap and resMap be WgLayout or SgLayout at the same time.");
625
626 auto shape = getSource().getType().getShape();
627 if (!XeGPUDialect::isEvenlyDistributable(shape, attr: srcMap))
628 return emitOpError(message: "invalid srcMap, data cannot be evenly distributed.");
629
630 if (!XeGPUDialect::isEvenlyDistributable(shape, attr: resMap))
631 return emitOpError(message: "invalid resMap, data cannot be evenly distributed.");
632
633 return mlir::success();
634}
635
636} // namespace xegpu
637} // namespace mlir
638
639#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
640#define GET_OP_CLASSES
641#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
642

source code of mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp