//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

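// Applies a permutation to `shape` in place: shape[i] becomes old[trans[i]].
// Illustrative example (assumed inputs): trans = {1, 0} turns an 8x16 shape
// into 16x8.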
static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}

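// Prints `array` as a bracketed, comma-separated list; e.g. an array holding
// {8, 16} renders as "[8, 16]" (illustrative output). With `breakline` set,
// a line break is inserted after each separator.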
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  // Guard against empty arrays, for which back() would be undefined behavior.
  if (!array.empty())
    os << array.back();
  os << "]";
  return buf;
}

static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static int64_t getRankOf(Value val) {
  auto type = val.getType();
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    return ty.getRank();
  return 0;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

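// Shared verifier for gather/scatter-style ops. Two forms are accepted: a
// rank-1 value of exactly chunk_size elements (the per-lane SIMT
// distribution), or a value whose shape matches the tensor descriptor
// (SIMD), transposed for rank-2 descriptors. Illustrative IR for the SIMD
// case (assumed syntax; see the op definitions for the authoritative form):
//   xegpu.load %tdesc, %mask <{transpose}>
//       : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>,
//         vector<16xi1> -> vector<8x16xf16>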
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy, UnitAttr transposeAttr,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  if (!valueTy)
    return emitError() << "Expecting a vector type result.";

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto chunkSize = tdescTy.getChunkSize();

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  if (tdescShape[0] != maskShape[0])
    return emitError()
           << "dim-0 of the Mask and TensorDesc should be the same.";

  // A valid shape for the SIMT case.
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    if (transposeAttr)
      return emitError() << "doesn't need TransposeAttr for SIMT code";
    return success();
  }

  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
    if (!transposeAttr)
      return emitError() << "rank-2 tensor has to be transposed.";
    transpose({1, 0}, tdescShape);
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        staticOffsets /* const offsets */, {} /* empty const shape */,
        {} /* empty const strides */);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  // Parenthesize the disjunction so the assertion message does not absorb
  // the second operand (&& binds tighter than ||).
  assert((isa<IntegerType>(srcTy) || isa<MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // If shape and strides can be recovered from the memref type, we don't
    // need attributes for them, which keeps the printed IR clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // The memory space of the created TensorDesc should match the source.
  // Both source and TensorDesc default to global memory if the memory scope
  // attribute is not specified. If the source is an integer, it is treated
  // as a pointer to global memory.
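  // Illustrative example (assumed attribute syntax): creating a
  // !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_space = slm>>
  // from a global memref<24x32xf16> source would be rejected here.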
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check that the source rank matches if it is a memref. It should also
  // have the same element type as the TensorDesc.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  // Mismatches among shape, strides, and offsets are already handled by
  // OffsetSizeAndStrideOpInterface, so they are not checked here.
  if (invalidRank)
    return emitOpError(
        "Expecting the ranks of shape, strides, offsets, and source (if "
        "source is a memref) to match each other.");

  // Check the result TensorDesc rank.
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank to be at most 2 and not greater than "
        "the ranks of shape, strides, offsets, or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in the
  // tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since the subgroup size is architecture-dependent, we only
    // check even distribution here.
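    // Illustrative example (assumed shapes): a tensor_desc<16x16xf16>
    // (256 elements) loaded as vector<16xf16> gives 16 elements per lane of
    // a 16-lane subgroup; 256 % 16 == 0, so it passes.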
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid.
    bool valid = llvm::all_of(
        trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

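  // The packed (VNNI) layout folds rows into the innermost dimension.
  // Illustrative example (assumed shapes): loading tensor_desc<16x16xf16>
  // with `packed` is expected to produce vector<8x16x2xf16>, since the
  // vnni_factor is 2 here.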
  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1) {
    tdescShape.insert(tdescShape.begin(), array_len);
  }

  if (tdescShape != valueShape) {
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp: if the value vector is 1D and has fewer elements
  // than the tensor descriptor, it is supposed to be a SIMT op. The layout
  // attribute in the tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems) {
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
    }
    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape) {
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNDOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source to be a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the created TensorDesc should match the source.
  // Both source and TensorDesc default to global memory if the memory scope
  // attribute is not specified. If the source is an integer, it is treated
  // as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check the total size.
  auto chunkSize = tdescTy.getChunkSize();
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports a chunk size of
    // 1. For 32-bit data, the hardware can support larger chunk sizes, so we
    // can bitcast 8-bit/16-bit data to 32-bit data for better performance.
    // This requires the total size to be 32-bit aligned for the optimization
    // to work.
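    // Illustrative example (assumed values): f16 with chunk_size = 2 gives
    // 32 bits per lane and is accepted; f16 with chunk_size = 3 gives 48
    // bits and is rejected.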
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

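  // Illustrative example (assumed values): 16 lanes x chunk_size 8 x f32
  // (4 bytes) is exactly 512 bytes and is accepted; doubling any factor
  // would exceed the limit.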
  auto lscConstraints = 512 * 8; // Each access is up to 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is up to 512 bytes.");

  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected: " << makeString(shape) << "\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                    getTransposeAttr(),
                                    [&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source to be a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The semantic check is skipped due to the lack of architecture
  // information. Users need to ensure the correctness.
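  // Illustrative example (assumed values): for f16 (16-bit elements) the
  // factor is 32 / 16 = 2, so the rank-1 B operand must hold an even number
  // of elements.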
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code
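  // Illustrative example (assumed shapes): A = vector<8x16xf16>,
  // B = vector<16x16xf16> (or packed vector<8x16x2xf16>), and
  // result = vector<8x16xf32> satisfy the M/N/K checks below.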
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::verify() {
  auto srcMap = getSrcMapAttr();
  auto resMap = getResMapAttr();
  if (!srcMap)
    return emitOpError("expected srcMap.");
  if (!resMap)
    return emitOpError("expected resMap.");

  if (srcMap == resMap)
    return emitOpError("expected different srcMap and resMap.");

  // Both srcMap and resMap should be WgLayout or SgLayout at the same time.
  if ((!srcMap.isWgLayout() || !resMap.isWgLayout()) &&
      (!srcMap.isSgLayout() || !resMap.isSgLayout()))
    return emitOpError("expected srcMap and resMap to be WgLayout or "
                       "SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
    return emitOpError("invalid srcMap, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
    return emitOpError("invalid resMap, data cannot be evenly distributed.");

  return mlir::success();
}

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>