1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/StaticValueUtils.h"
10#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/TypeUtilities.h"
13
14#include "llvm/Support/Debug.h"
15
16#define DEBUG_TYPE "xegpu"
17
18namespace mlir {
19namespace xegpu {
20
21static void transpose(llvm::ArrayRef<int64_t> trans,
22 SmallVector<int64_t> &shape) {
23 SmallVector<int64_t> old = shape;
24 for (size_t i = 0; i < trans.size(); i++)
25 shape[i] = old[trans[i]];
26}
27
28template <typename T>
29static std::string makeString(T array, bool breakline = false) {
30 std::string buf;
31 buf.clear();
32 llvm::raw_string_ostream os(buf);
33 os << "[";
34 for (size_t i = 1; i < array.size(); i++) {
35 os << array[i - 1] << ", ";
36 if (breakline)
37 os << "\n\t\t";
38 }
39 os << array.back() << "]";
40 os.flush();
41 return buf;
42}
43
44static SmallVector<int64_t> getShapeOf(Type type) {
45 SmallVector<int64_t> shape;
46 if (auto ty = llvm::dyn_cast<ShapedType>(type))
47 shape = SmallVector<int64_t>(ty.getShape());
48 else
49 shape.push_back(Elt: 1);
50 return shape;
51}
52
53static int64_t getRankOf(Value val) {
54 auto type = val.getType();
55 if (auto ty = llvm::dyn_cast<ShapedType>(type))
56 return ty.getRank();
57 return 0;
58}
59
60static bool isReadHintOrNone(const CachePolicyAttr &attr) {
61 if (!attr)
62 return true;
63 auto kind = attr.getValue();
64 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
65 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
66}
67
68static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
69 if (!attr)
70 return true;
71 auto kind = attr.getValue();
72 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
73 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
74}
75
76//===----------------------------------------------------------------------===//
77// XeGPU_CreateNdDescOp
78//===----------------------------------------------------------------------===//
79void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
80 Type tdesc, TypedValue<MemRefType> source,
81 llvm::ArrayRef<OpFoldResult> offsets) {
82 [[maybe_unused]] auto ty = source.getType();
83 assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
84
85 llvm::SmallVector<int64_t> staticOffsets;
86 llvm::SmallVector<Value> dynamicOffsets;
87 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
88
89 build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
90 ValueRange({}) /* empty dynamic shape */,
91 ValueRange({}) /* empty dynamic strides */,
92 staticOffsets /* const offsets */, {} /* empty const shape*/,
93 {} /* empty const strides*/);
94}
95
96void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
97 Type tdesc, TypedValue<IntegerType> source,
98 llvm::ArrayRef<OpFoldResult> offsets,
99 llvm::ArrayRef<OpFoldResult> shape,
100 llvm::ArrayRef<OpFoldResult> strides) {
101 assert(shape.size() && offsets.size() && strides.size() &&
102 shape.size() == strides.size() && shape.size() == offsets.size());
103
104 llvm::SmallVector<int64_t> staticOffsets;
105 llvm::SmallVector<int64_t> staticShape;
106 llvm::SmallVector<int64_t> staticStrides;
107 llvm::SmallVector<Value> dynamicOffsets;
108 llvm::SmallVector<Value> dynamicShape;
109 llvm::SmallVector<Value> dynamicStrides;
110
111 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
112 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
113 dispatchIndexOpFoldResults(strides, dynamicStrides, staticOffsets);
114
115 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
116 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
117 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
118
119 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
120 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
121}
122
123LogicalResult CreateNdDescOp::verify() {
124 auto rank = (int64_t)getMixedOffsets().size();
125 bool invalidRank = (rank != 2);
126 bool invalidElemTy = false;
127
128 // check source type matches the rank if it is a memref.
129 // It also should have the same ElementType as TensorDesc.
130 auto memrefTy = dyn_cast<MemRefType>(getSourceType());
131 if (memrefTy) {
132 invalidRank |= (memrefTy.getRank() != rank);
133 invalidElemTy |= memrefTy.getElementType() != getElementType();
134 }
135
136 // check result type matches the rank
137 invalidRank = (getType().getRank() != rank);
138
139 // mismatches among shape, strides, and offsets are
140 // already handeled by OffsetSizeAndStrideOpInterface.
141 // So they are not check here.
142 if (invalidRank)
143 return emitOpError(
144 "Expecting the rank of shape, strides, offsets, "
145 "source memref type (if source is a memref) and TensorDesc "
146 "should match with each other. They currenlty are 2D.");
147
148 if (invalidElemTy)
149 return emitOpError("TensorDesc should have the same element "
150 "type with the source if it is a memref.\n");
151
152 if (getType().getScattered())
153 return emitOpError("Expects a non-scattered TensorDesc.\n");
154
155 return success();
156}
157
158//===----------------------------------------------------------------------===//
159// XeGPU_PrefetchNdOp
160//===----------------------------------------------------------------------===//
161LogicalResult PrefetchNdOp::verify() {
162 auto tdescTy = getTensorDescType();
163 if (tdescTy.getScattered())
164 return emitOpError("Expects a non-scattered TensorDesc.\n");
165
166 if (!isReadHintOrNone(getL1HintAttr()))
167 return emitOpError("invlid l1_hint: ") << getL1HintAttr();
168
169 if (!isReadHintOrNone(getL2HintAttr()))
170 return emitOpError("invlid l2_hint: ") << getL2HintAttr();
171
172 if (!isReadHintOrNone(getL3HintAttr()))
173 return emitOpError("invlid l3_hint: ") << getL3HintAttr();
174
175 return success();
176}
177
178//===----------------------------------------------------------------------===//
179// XeGPU_LoadNdOp
180//===----------------------------------------------------------------------===//
181LogicalResult LoadNdOp::verify() {
182 auto tdescTy = getTensorDescType();
183 auto valueTy = getType();
184
185 if (tdescTy.getRank() != 2)
186 return emitOpError("Expecting a 2D TensorDesc.\n");
187
188 if (tdescTy.getScattered())
189 return emitOpError("Expects a non-scattered TensorDesc.\n");
190
191 if (!valueTy)
192 return emitOpError("Invalid result, it should be a VectorType.\n");
193
194 if (!isReadHintOrNone(getL1HintAttr()))
195 return emitOpError("invlid l1_hint: ") << getL1HintAttr();
196
197 if (!isReadHintOrNone(getL2HintAttr()))
198 return emitOpError("invlid l2_hint: ") << getL2HintAttr();
199
200 if (!isReadHintOrNone(getL3HintAttr()))
201 return emitOpError("invlid l3_hint: ") << getL3HintAttr();
202
203 auto array_len = tdescTy.getArrayLength();
204 auto tdescShape = getShapeOf(tdescTy);
205 auto valueShape = getShapeOf(valueTy);
206
207 if (getTranspose()) {
208 auto trans = getTranspose().value();
209 if (tdescShape.size() >= trans.size())
210 transpose(trans, tdescShape);
211 else
212 emitWarning("Invalid transpose attr. It is ignored.");
213 }
214
215 if (getVnniAxis()) {
216 auto axis = getVnniAxis().value();
217 auto vnni_factor = valueShape.back();
218 tdescShape[axis] /= vnni_factor;
219 tdescShape.push_back(vnni_factor);
220 }
221
222 if (array_len > 1) {
223 auto it = tdescShape.begin();
224 tdescShape.insert(it, array_len);
225 }
226
227 if (tdescShape != valueShape)
228 return emitOpError() << "Result shape doesn't match TensorDesc shape."
229 << "The expected shape is " << makeString(tdescShape)
230 << ". But the given shape is "
231 << makeString(valueShape) << ".\n";
232 return success();
233}
234
235//===----------------------------------------------------------------------===//
236// XeGPU_StoreNdOp
237//===----------------------------------------------------------------------===//
238LogicalResult StoreNdOp::verify() {
239 auto dstTy = getTensorDescType(); // Tile
240 auto valTy = getValueType(); // Vector
241
242 if (dstTy.getRank() != 2)
243 return emitOpError("Expecting a 2D TensorDesc.\n");
244
245 if (dstTy.getScattered())
246 return emitOpError("Expects a non-scattered TensorDesc.\n");
247
248 if (!valTy)
249 return emitOpError("Exepcting a VectorType result.\n");
250
251 if (!isWriteHintOrNone(getL1HintAttr()))
252 return emitOpError("invlid l1_hint: ") << getL1HintAttr();
253
254 if (!isWriteHintOrNone(getL2HintAttr()))
255 return emitOpError("invlid l2_hint: ") << getL2HintAttr();
256
257 if (!isWriteHintOrNone(getL3HintAttr()))
258 return emitOpError("invlid l3_hint: ") << getL3HintAttr();
259
260 return success();
261}
262
263//===----------------------------------------------------------------------===//
264// XeGPU_UpdateNDOffsetOp
265//===----------------------------------------------------------------------===//
266LogicalResult UpdateNdOffsetOp::verify() {
267 auto ty = getTensorDescType();
268 if (ty.getScattered())
269 return emitOpError("Expects a non-scattered TensorDesc.\n");
270
271 // number of offsets specified must match the rank of the tensor descriptor
272 if (ty.getRank() != (int64_t)getNumOffsets()) {
273 return emitOpError("Invalid number of offsets.");
274 }
275 return success();
276}
277
278//===----------------------------------------------------------------------===//
279// XeGPU_CreateDescOp
280//===----------------------------------------------------------------------===//
281void CreateDescOp::build(OpBuilder &builder, OperationState &state,
282 TensorDescType TensorDesc, Value source,
283 llvm::ArrayRef<OpFoldResult> offsets,
284 uint32_t chunk_size) {
285 llvm::SmallVector<int64_t> staticOffsets;
286 llvm::SmallVector<Value> dynamicOffsets;
287 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
288 build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets,
289 chunk_size);
290}
291
292LogicalResult CreateDescOp::verify() {
293 auto tdescTy = getTensorDescType();
294 auto chunkSize = getChunkSize();
295
296 if (getRankOf(getSource()) > 1)
297 return emitOpError(
298 "Expecting the source is a 1D memref or pointer (uint64_t).");
299
300 if (!tdescTy.getScattered())
301 return emitOpError("Expects a scattered TensorDesc.\n");
302
303 SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
304 if (chunkSize != 1)
305 shape.push_back(chunkSize);
306
307 auto tdescShape = getShapeOf(tdescTy);
308 if (shape != tdescShape)
309 return emitOpError("Incorrect TensorDesc shape. ")
310 << "Expected is " << makeString(shape) << "\n";
311
312 return success();
313}
314
315//===----------------------------------------------------------------------===//
316// XeGPU_PrefetchOp
317//===----------------------------------------------------------------------===//
318LogicalResult PrefetchOp::verify() {
319 auto tdescTy = getTensorDescType();
320 if (!tdescTy.getScattered())
321 return emitOpError("Expects a scattered TensorDesc.\n");
322
323 if (!isReadHintOrNone(getL1HintAttr()))
324 return emitOpError("invlid l1_hint: ") << getL1HintAttr();
325
326 if (!isReadHintOrNone(getL2HintAttr()))
327 return emitOpError("invlid l2_hint: ") << getL2HintAttr();
328
329 if (!isReadHintOrNone(getL3HintAttr()))
330 return emitOpError("invlid l3_hint: ") << getL3HintAttr();
331
332 return success();
333}
334
335//===----------------------------------------------------------------------===//
336// XeGPU_LoadGatherOp
337//===----------------------------------------------------------------------===//
338LogicalResult LoadGatherOp::verify() {
339 auto tdescTy = getTensorDescType();
340 auto maskTy = getMaskType();
341 auto valueTy = getValueType();
342
343 if (!tdescTy.getScattered())
344 return emitOpError("Expects a scattered TensorDesc.\n");
345
346 if (!isReadHintOrNone(getL1HintAttr()))
347 return emitOpError("invlid l1_hint: ") << getL1HintAttr();
348
349 if (!isReadHintOrNone(getL2HintAttr()))
350 return emitOpError("invlid l2_hint: ") << getL2HintAttr();
351
352 if (!isReadHintOrNone(getL3HintAttr()))
353 return emitOpError("invlid l3_hint: ") << getL3HintAttr();
354
355 auto tdescElemTy = tdescTy.getElementType();
356 auto valueElemTy = getElementType();
357 if (tdescElemTy != valueElemTy)
358 return emitOpError(
359 "Value should have the same element type as TensorDesc.");
360
361 auto maskShape = getShapeOf(maskTy);
362 auto valueShape = getShapeOf(valueTy);
363 auto tdescShape = getShapeOf(tdescTy);
364
365 if (tdescShape[0] != maskShape[0])
366 return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
367
368 if (getTransposeAttr()) {
369 auto trans = getTranspose().value();
370 if (tdescShape.size() < trans.size())
371 emitWarning("Invalid transpose attr. It is ignored.");
372 else
373 transpose(trans, tdescShape);
374 }
375
376 if (valueShape != tdescShape)
377 return emitOpError("Unexpected result shape")
378 << "(Expected shape: " << makeString(tdescShape)
379 << ", Given shape: " << makeString(valueShape) << ").\n";
380
381 return success();
382}
383
384//===----------------------------------------------------------------------===//
385// XeGPU_StoreScatterOp
386//===----------------------------------------------------------------------===//
387LogicalResult StoreScatterOp::verify() {
388 auto tdescTy = getTensorDescType();
389 if (!tdescTy.getScattered())
390 return emitOpError("Expects a scattered TensorDesc.\n");
391
392 if (!isWriteHintOrNone(getL1HintAttr()))
393 return emitOpError("invlid l1_hint: ") << getL1HintAttr();
394
395 if (!isWriteHintOrNone(getL2HintAttr()))
396 return emitOpError("invlid l2_hint: ") << getL2HintAttr();
397
398 if (!isWriteHintOrNone(getL3HintAttr()))
399 return emitOpError("invlid l3_hint: ") << getL3HintAttr();
400
401 auto maskTy = getMaskType();
402 auto maskShape = getShapeOf(maskTy);
403 auto tdescShape = getShapeOf(tdescTy);
404 if (tdescShape[0] != maskShape[0])
405 return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
406
407 return success();
408}
409
410} // namespace xegpu
411} // namespace mlir
412
413#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
414#define GET_OP_CLASSES
415#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
416

// source code of mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp