1 | //===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
10 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
11 | #include "mlir/IR/Builders.h" |
12 | #include "mlir/IR/TypeUtilities.h" |
13 | |
14 | #include "llvm/Support/Debug.h" |
15 | |
16 | #define DEBUG_TYPE "xegpu" |
17 | |
18 | namespace mlir { |
19 | namespace xegpu { |
20 | |
21 | static void transpose(llvm::ArrayRef<int64_t> trans, |
22 | SmallVector<int64_t> &shape) { |
23 | SmallVector<int64_t> old = shape; |
24 | for (size_t i = 0; i < trans.size(); i++) |
25 | shape[i] = old[trans[i]]; |
26 | } |
27 | |
28 | template <typename T> |
29 | static std::string makeString(T array, bool breakline = false) { |
30 | std::string buf; |
31 | buf.clear(); |
32 | llvm::raw_string_ostream os(buf); |
33 | os << "[" ; |
34 | for (size_t i = 1; i < array.size(); i++) { |
35 | os << array[i - 1] << ", " ; |
36 | if (breakline) |
37 | os << "\n\t\t" ; |
38 | } |
39 | os << array.back() << "]" ; |
40 | os.flush(); |
41 | return buf; |
42 | } |
43 | |
44 | static SmallVector<int64_t> getShapeOf(Type type) { |
45 | SmallVector<int64_t> shape; |
46 | if (auto ty = llvm::dyn_cast<ShapedType>(type)) |
47 | shape = SmallVector<int64_t>(ty.getShape()); |
48 | else |
49 | shape.push_back(Elt: 1); |
50 | return shape; |
51 | } |
52 | |
53 | static int64_t getRankOf(Value val) { |
54 | auto type = val.getType(); |
55 | if (auto ty = llvm::dyn_cast<ShapedType>(type)) |
56 | return ty.getRank(); |
57 | return 0; |
58 | } |
59 | |
60 | static bool isReadHintOrNone(const CachePolicyAttr &attr) { |
61 | if (!attr) |
62 | return true; |
63 | auto kind = attr.getValue(); |
64 | return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || |
65 | kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE; |
66 | } |
67 | |
68 | static bool isWriteHintOrNone(const CachePolicyAttr &attr) { |
69 | if (!attr) |
70 | return true; |
71 | auto kind = attr.getValue(); |
72 | return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || |
73 | kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; |
74 | } |
75 | |
76 | //===----------------------------------------------------------------------===// |
77 | // XeGPU_CreateNdDescOp |
78 | //===----------------------------------------------------------------------===// |
79 | void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, |
80 | Type tdesc, TypedValue<MemRefType> source, |
81 | llvm::ArrayRef<OpFoldResult> offsets) { |
82 | [[maybe_unused]] auto ty = source.getType(); |
83 | assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank()); |
84 | |
85 | llvm::SmallVector<int64_t> staticOffsets; |
86 | llvm::SmallVector<Value> dynamicOffsets; |
87 | dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
88 | |
89 | build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, |
90 | ValueRange({}) /* empty dynamic shape */, |
91 | ValueRange({}) /* empty dynamic strides */, |
92 | staticOffsets /* const offsets */, {} /* empty const shape*/, |
93 | {} /* empty const strides*/); |
94 | } |
95 | |
96 | void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, |
97 | Type tdesc, TypedValue<IntegerType> source, |
98 | llvm::ArrayRef<OpFoldResult> offsets, |
99 | llvm::ArrayRef<OpFoldResult> shape, |
100 | llvm::ArrayRef<OpFoldResult> strides) { |
101 | assert(shape.size() && offsets.size() && strides.size() && |
102 | shape.size() == strides.size() && shape.size() == offsets.size()); |
103 | |
104 | llvm::SmallVector<int64_t> staticOffsets; |
105 | llvm::SmallVector<int64_t> staticShape; |
106 | llvm::SmallVector<int64_t> staticStrides; |
107 | llvm::SmallVector<Value> dynamicOffsets; |
108 | llvm::SmallVector<Value> dynamicShape; |
109 | llvm::SmallVector<Value> dynamicStrides; |
110 | |
111 | dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
112 | dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); |
113 | dispatchIndexOpFoldResults(strides, dynamicStrides, staticOffsets); |
114 | |
115 | auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); |
116 | auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); |
117 | auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); |
118 | |
119 | build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, |
120 | dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); |
121 | } |
122 | |
123 | LogicalResult CreateNdDescOp::verify() { |
124 | auto rank = (int64_t)getMixedOffsets().size(); |
125 | bool invalidRank = (rank != 2); |
126 | bool invalidElemTy = false; |
127 | |
128 | // check source type matches the rank if it is a memref. |
129 | // It also should have the same ElementType as TensorDesc. |
130 | auto memrefTy = dyn_cast<MemRefType>(getSourceType()); |
131 | if (memrefTy) { |
132 | invalidRank |= (memrefTy.getRank() != rank); |
133 | invalidElemTy |= memrefTy.getElementType() != getElementType(); |
134 | } |
135 | |
136 | // check result type matches the rank |
137 | invalidRank = (getType().getRank() != rank); |
138 | |
139 | // mismatches among shape, strides, and offsets are |
140 | // already handeled by OffsetSizeAndStrideOpInterface. |
141 | // So they are not check here. |
142 | if (invalidRank) |
143 | return emitOpError( |
144 | "Expecting the rank of shape, strides, offsets, " |
145 | "source memref type (if source is a memref) and TensorDesc " |
146 | "should match with each other. They currenlty are 2D." ); |
147 | |
148 | if (invalidElemTy) |
149 | return emitOpError("TensorDesc should have the same element " |
150 | "type with the source if it is a memref.\n" ); |
151 | |
152 | if (getType().getScattered()) |
153 | return emitOpError("Expects a non-scattered TensorDesc.\n" ); |
154 | |
155 | return success(); |
156 | } |
157 | |
158 | //===----------------------------------------------------------------------===// |
159 | // XeGPU_PrefetchNdOp |
160 | //===----------------------------------------------------------------------===// |
161 | LogicalResult PrefetchNdOp::verify() { |
162 | auto tdescTy = getTensorDescType(); |
163 | if (tdescTy.getScattered()) |
164 | return emitOpError("Expects a non-scattered TensorDesc.\n" ); |
165 | |
166 | if (!isReadHintOrNone(getL1HintAttr())) |
167 | return emitOpError("invlid l1_hint: " ) << getL1HintAttr(); |
168 | |
169 | if (!isReadHintOrNone(getL2HintAttr())) |
170 | return emitOpError("invlid l2_hint: " ) << getL2HintAttr(); |
171 | |
172 | if (!isReadHintOrNone(getL3HintAttr())) |
173 | return emitOpError("invlid l3_hint: " ) << getL3HintAttr(); |
174 | |
175 | return success(); |
176 | } |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // XeGPU_LoadNdOp |
180 | //===----------------------------------------------------------------------===// |
181 | LogicalResult LoadNdOp::verify() { |
182 | auto tdescTy = getTensorDescType(); |
183 | auto valueTy = getType(); |
184 | |
185 | if (tdescTy.getRank() != 2) |
186 | return emitOpError("Expecting a 2D TensorDesc.\n" ); |
187 | |
188 | if (tdescTy.getScattered()) |
189 | return emitOpError("Expects a non-scattered TensorDesc.\n" ); |
190 | |
191 | if (!valueTy) |
192 | return emitOpError("Invalid result, it should be a VectorType.\n" ); |
193 | |
194 | if (!isReadHintOrNone(getL1HintAttr())) |
195 | return emitOpError("invlid l1_hint: " ) << getL1HintAttr(); |
196 | |
197 | if (!isReadHintOrNone(getL2HintAttr())) |
198 | return emitOpError("invlid l2_hint: " ) << getL2HintAttr(); |
199 | |
200 | if (!isReadHintOrNone(getL3HintAttr())) |
201 | return emitOpError("invlid l3_hint: " ) << getL3HintAttr(); |
202 | |
203 | auto array_len = tdescTy.getArrayLength(); |
204 | auto tdescShape = getShapeOf(tdescTy); |
205 | auto valueShape = getShapeOf(valueTy); |
206 | |
207 | if (getTranspose()) { |
208 | auto trans = getTranspose().value(); |
209 | if (tdescShape.size() >= trans.size()) |
210 | transpose(trans, tdescShape); |
211 | else |
212 | emitWarning("Invalid transpose attr. It is ignored." ); |
213 | } |
214 | |
215 | if (getVnniAxis()) { |
216 | auto axis = getVnniAxis().value(); |
217 | auto vnni_factor = valueShape.back(); |
218 | tdescShape[axis] /= vnni_factor; |
219 | tdescShape.push_back(vnni_factor); |
220 | } |
221 | |
222 | if (array_len > 1) { |
223 | auto it = tdescShape.begin(); |
224 | tdescShape.insert(it, array_len); |
225 | } |
226 | |
227 | if (tdescShape != valueShape) |
228 | return emitOpError() << "Result shape doesn't match TensorDesc shape." |
229 | << "The expected shape is " << makeString(tdescShape) |
230 | << ". But the given shape is " |
231 | << makeString(valueShape) << ".\n" ; |
232 | return success(); |
233 | } |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // XeGPU_StoreNdOp |
237 | //===----------------------------------------------------------------------===// |
238 | LogicalResult StoreNdOp::verify() { |
239 | auto dstTy = getTensorDescType(); // Tile |
240 | auto valTy = getValueType(); // Vector |
241 | |
242 | if (dstTy.getRank() != 2) |
243 | return emitOpError("Expecting a 2D TensorDesc.\n" ); |
244 | |
245 | if (dstTy.getScattered()) |
246 | return emitOpError("Expects a non-scattered TensorDesc.\n" ); |
247 | |
248 | if (!valTy) |
249 | return emitOpError("Exepcting a VectorType result.\n" ); |
250 | |
251 | if (!isWriteHintOrNone(getL1HintAttr())) |
252 | return emitOpError("invlid l1_hint: " ) << getL1HintAttr(); |
253 | |
254 | if (!isWriteHintOrNone(getL2HintAttr())) |
255 | return emitOpError("invlid l2_hint: " ) << getL2HintAttr(); |
256 | |
257 | if (!isWriteHintOrNone(getL3HintAttr())) |
258 | return emitOpError("invlid l3_hint: " ) << getL3HintAttr(); |
259 | |
260 | return success(); |
261 | } |
262 | |
//===----------------------------------------------------------------------===//
// XeGPU_UpdateNdOffsetOp
//===----------------------------------------------------------------------===//
266 | LogicalResult UpdateNdOffsetOp::verify() { |
267 | auto ty = getTensorDescType(); |
268 | if (ty.getScattered()) |
269 | return emitOpError("Expects a non-scattered TensorDesc.\n" ); |
270 | |
271 | // number of offsets specified must match the rank of the tensor descriptor |
272 | if (ty.getRank() != (int64_t)getNumOffsets()) { |
273 | return emitOpError("Invalid number of offsets." ); |
274 | } |
275 | return success(); |
276 | } |
277 | |
278 | //===----------------------------------------------------------------------===// |
279 | // XeGPU_CreateDescOp |
280 | //===----------------------------------------------------------------------===// |
281 | void CreateDescOp::build(OpBuilder &builder, OperationState &state, |
282 | TensorDescType TensorDesc, Value source, |
283 | llvm::ArrayRef<OpFoldResult> offsets, |
284 | uint32_t chunk_size) { |
285 | llvm::SmallVector<int64_t> staticOffsets; |
286 | llvm::SmallVector<Value> dynamicOffsets; |
287 | dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); |
288 | build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets, |
289 | chunk_size); |
290 | } |
291 | |
292 | LogicalResult CreateDescOp::verify() { |
293 | auto tdescTy = getTensorDescType(); |
294 | auto chunkSize = getChunkSize(); |
295 | |
296 | if (getRankOf(getSource()) > 1) |
297 | return emitOpError( |
298 | "Expecting the source is a 1D memref or pointer (uint64_t)." ); |
299 | |
300 | if (!tdescTy.getScattered()) |
301 | return emitOpError("Expects a scattered TensorDesc.\n" ); |
302 | |
303 | SmallVector<int64_t> shape({(int64_t)getNumOffsets()}); |
304 | if (chunkSize != 1) |
305 | shape.push_back(chunkSize); |
306 | |
307 | auto tdescShape = getShapeOf(tdescTy); |
308 | if (shape != tdescShape) |
309 | return emitOpError("Incorrect TensorDesc shape. " ) |
310 | << "Expected is " << makeString(shape) << "\n" ; |
311 | |
312 | return success(); |
313 | } |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // XeGPU_PrefetchOp |
317 | //===----------------------------------------------------------------------===// |
318 | LogicalResult PrefetchOp::verify() { |
319 | auto tdescTy = getTensorDescType(); |
320 | if (!tdescTy.getScattered()) |
321 | return emitOpError("Expects a scattered TensorDesc.\n" ); |
322 | |
323 | if (!isReadHintOrNone(getL1HintAttr())) |
324 | return emitOpError("invlid l1_hint: " ) << getL1HintAttr(); |
325 | |
326 | if (!isReadHintOrNone(getL2HintAttr())) |
327 | return emitOpError("invlid l2_hint: " ) << getL2HintAttr(); |
328 | |
329 | if (!isReadHintOrNone(getL3HintAttr())) |
330 | return emitOpError("invlid l3_hint: " ) << getL3HintAttr(); |
331 | |
332 | return success(); |
333 | } |
334 | |
335 | //===----------------------------------------------------------------------===// |
336 | // XeGPU_LoadGatherOp |
337 | //===----------------------------------------------------------------------===// |
338 | LogicalResult LoadGatherOp::verify() { |
339 | auto tdescTy = getTensorDescType(); |
340 | auto maskTy = getMaskType(); |
341 | auto valueTy = getValueType(); |
342 | |
343 | if (!tdescTy.getScattered()) |
344 | return emitOpError("Expects a scattered TensorDesc.\n" ); |
345 | |
346 | if (!isReadHintOrNone(getL1HintAttr())) |
347 | return emitOpError("invlid l1_hint: " ) << getL1HintAttr(); |
348 | |
349 | if (!isReadHintOrNone(getL2HintAttr())) |
350 | return emitOpError("invlid l2_hint: " ) << getL2HintAttr(); |
351 | |
352 | if (!isReadHintOrNone(getL3HintAttr())) |
353 | return emitOpError("invlid l3_hint: " ) << getL3HintAttr(); |
354 | |
355 | auto tdescElemTy = tdescTy.getElementType(); |
356 | auto valueElemTy = getElementType(); |
357 | if (tdescElemTy != valueElemTy) |
358 | return emitOpError( |
359 | "Value should have the same element type as TensorDesc." ); |
360 | |
361 | auto maskShape = getShapeOf(maskTy); |
362 | auto valueShape = getShapeOf(valueTy); |
363 | auto tdescShape = getShapeOf(tdescTy); |
364 | |
365 | if (tdescShape[0] != maskShape[0]) |
366 | return emitOpError("dim-0 of the Mask and TensorDesc should be the same." ); |
367 | |
368 | if (getTransposeAttr()) { |
369 | auto trans = getTranspose().value(); |
370 | if (tdescShape.size() < trans.size()) |
371 | emitWarning("Invalid transpose attr. It is ignored." ); |
372 | else |
373 | transpose(trans, tdescShape); |
374 | } |
375 | |
376 | if (valueShape != tdescShape) |
377 | return emitOpError("Unexpected result shape" ) |
378 | << "(Expected shape: " << makeString(tdescShape) |
379 | << ", Given shape: " << makeString(valueShape) << ").\n" ; |
380 | |
381 | return success(); |
382 | } |
383 | |
384 | //===----------------------------------------------------------------------===// |
385 | // XeGPU_StoreScatterOp |
386 | //===----------------------------------------------------------------------===// |
387 | LogicalResult StoreScatterOp::verify() { |
388 | auto tdescTy = getTensorDescType(); |
389 | if (!tdescTy.getScattered()) |
390 | return emitOpError("Expects a scattered TensorDesc.\n" ); |
391 | |
392 | if (!isWriteHintOrNone(getL1HintAttr())) |
393 | return emitOpError("invlid l1_hint: " ) << getL1HintAttr(); |
394 | |
395 | if (!isWriteHintOrNone(getL2HintAttr())) |
396 | return emitOpError("invlid l2_hint: " ) << getL2HintAttr(); |
397 | |
398 | if (!isWriteHintOrNone(getL3HintAttr())) |
399 | return emitOpError("invlid l3_hint: " ) << getL3HintAttr(); |
400 | |
401 | auto maskTy = getMaskType(); |
402 | auto maskShape = getShapeOf(maskTy); |
403 | auto tdescShape = getShapeOf(tdescTy); |
404 | if (tdescShape[0] != maskShape[0]) |
405 | return emitOpError("dim-0 of the Mask and TensorDesc should be the same." ); |
406 | |
407 | return success(); |
408 | } |
409 | |
410 | } // namespace xegpu |
411 | } // namespace mlir |
412 | |
413 | #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc> |
414 | #define GET_OP_CLASSES |
415 | #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> |
416 | |