//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>

using std::optional;

namespace mlir {
namespace xegpu {

void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}

// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
                                         xegpu::LayoutAttr attr) {
  assert(attr && "Layout attribute is missing.");

  // Checks whether the given shape can be evenly distributed using the
  // specified layout and data attributes. If successful, it returns the work
  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
  // size per compute unit is calculated as follows:
  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
  //   - If `data` is not null: newShape[i] = data[i]
  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
  // share the data.
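  // Illustrative example (numbers are hypothetical, chosen for exposition):
  // given
  //   shape     = [32, 64]
  //   sg_layout = [4, 4], sg_data = [8, 16]
  // the layout step yields [32/4, 64/4] = [8, 16], which matches sg_data, so
  // each subgroup works on an 8x16 tile. If shape were [16, 64] instead, the
  // forward ratio check fails (4 is not divisible by 8), but with round-robin
  // enabled the reverse ratio [8/4, 16/16] succeeds, and the per-subgroup
  // work size is still sg_data = [8, 16], with subgroups sharing data.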
  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
                           DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
                           bool rr = true) -> optional<SmallVector<int64_t>> {
    llvm::SmallVector<int64_t> newShape(shape);
    if (layout) {
      auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
      if (vec.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(shape, vec);
      if (!ratio.has_value())
        return std::nullopt;
      newShape = ratio.value();
    }

    if (data) {
      auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
      if (vec.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(newShape, vec);
      if (!ratio.has_value() && rr)
        ratio = computeShapeRatio(vec, newShape);
      if (!ratio.has_value())
        return std::nullopt;

      // If `data` is not null, always return it as the shape for the next
      // phase.
      newShape = vec;
    }
    return newShape;
  };

  // Check the sg_layout and sg_data attributes.
  auto maybeSgShape =
      tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
  if (!maybeSgShape)
    return false;
  auto sgShape = maybeSgShape.value();

  // Check inst_data; it has no layout and does not use round-robin
  // distribution.
  auto maybeInstShape =
      tryDistribute(sgShape, nullptr, attr.getInstData(), false);
  if (!maybeInstShape)
    return false;
  auto instShape = maybeInstShape.value();

  // Check the lane_layout and lane_data attributes.
  auto maybeLaneShape =
      tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
  return maybeLaneShape.has_value();
}

//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
                                             xegpu::MemorySpace memory_space,
                                             int array_length,
                                             bool boundary_check) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto lengthAttr =
      IntegerAttr::get(IntegerType::get(context, 64), array_length);
  auto boundaryAttr = BoolAttr::get(context, boundary_check);
  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}
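
// A minimal usage sketch (illustrative only; `ctx` stands for a valid
// MLIRContext* and is not defined in this file): build a block descriptor
// for global memory that covers two consecutive blocks with boundary
// checking enabled.
//
//   auto blockAttr = BlockTensorDescAttr::get(ctx, MemorySpace::Global,
//                                             /*array_length=*/2,
//                                             /*boundary_check=*/true);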

//===----------------------------------------------------------------------===//
// XeGPU_ScatterTensorDescAttr
//===----------------------------------------------------------------------===//
ScatterTensorDescAttr
ScatterTensorDescAttr::get(mlir::MLIRContext *context,
                           xegpu::MemorySpace memory_space, int chunk_size) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto chunkSizeAttr =
      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
  return Base::get(context, scopeAttr, chunkSizeAttr);
}
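
// A minimal usage sketch (illustrative only; `ctx` is an assumed
// MLIRContext*): build a scatter descriptor in global memory where each
// offset addresses a chunk of 8 contiguous elements. The chunk size must be
// one of the values accepted by `verify` below.
//
//   auto scatterAttr =
//       ScatterTensorDescAttr::get(ctx, MemorySpace::Global, /*chunk_size=*/8);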

LogicalResult ScatterTensorDescAttr::verify(
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
  int64_t chunkSize = chunk_size.getInt();
  SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                              16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitError() << "invalid chunk size";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LayoutAttr
//===----------------------------------------------------------------------===//
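// For intuition, a workgroup-level layout typically carries sg_layout (and
// optionally sg_data), while a subgroup-level layout carries lane_layout (and
// optionally lane_data). In textual IR this might look roughly like the
// following (illustrative sketches; the authoritative assembly syntax is
// defined by the attribute's tablegen format, not here):
//
//   #xegpu.layout<sg_layout = [4, 4], sg_data = [8, 16]>
//   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>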
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // A valid layout must include at least one of sg_layout, inst_data, or
  // lane_layout. sg_layout is essential for a Workgroup-level layout, while
  // lane_layout is required for a Subgroup-level layout.
  if (!sg_layout && !inst_data && !lane_layout) {
    return emitError()
           << "expected at least one of sg_layout, inst_data or lane_layout";
  }

  // sg_layout, inst_data, and lane_layout must all have the same rank when
  // present.

  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError()
           << "expected inst_data and lane_layout to have the same rank";
  }

  // sg_data is optional for Workgroup layout, but its presence requires
  // sg_layout.
  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  // lane_data is optional for Subgroup layout, but its presence requires
  // lane_layout.
  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//

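// The parser and printer below handle the custom textual form of the type: a
// dimension list and element type, optionally followed by an encoding
// attribute (BlockTensorDescAttr or ScatterTensorDescAttr) and a LayoutAttr.
// A hypothetical example, for illustration only (the authoritative grammar
// lives in the tablegen definition):
//
//   !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_space = global>>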
mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); },
      parser.getContext(), shape, elementType,
      encoding.value_or(mlir::Attribute()), layout.value_or(mlir::Attribute()));
}

void TensorDescType::print(::mlir::AsmPrinter &printer) const {
  printer << "<";

  auto shape = getShape();
  for (int64_t dim : shape) {
    if (mlir::ShapedType::isDynamic(dim))
      printer << '?';
    else
      printer << dim;
    printer << 'x';
  }

  printer << getElementType();

  if (auto encoding = getEncoding())
    printer << ", " << encoding;

  if (auto layout = getLayout())
    printer << ", " << layout;

  printer << ">";
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int array_length,
                                   bool boundary_check,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                       boundary_check);
  return Base::get(context, shape, elementType, attr, layout);
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int chunk_size,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
  return Base::get(context, shape, elementType, attr, layout);
}
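
// Minimal usage sketches for the two convenience builders above (illustrative
// only; `f16Ty` is an assumed element type obtained elsewhere, e.g. via
// Float16Type::get(ctx)):
//
//   // 8x16 block descriptor in global memory, a single block, with boundary
//   // checking and no layout attribute.
//   auto blockDesc = TensorDescType::get({8, 16}, f16Ty, /*array_length=*/1,
//                                        /*boundary_check=*/true,
//                                        MemorySpace::Global, Attribute());
//
//   // 16x8 scattered descriptor with a chunk size of 8.
//   auto scatterDesc = TensorDescType::get({16, 8}, f16Ty, /*chunk_size=*/8,
//                                          MemorySpace::Global, Attribute());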

LogicalResult TensorDescType::verify(
    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
    mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();
  // Low-precision types are packed in 32-bit units.
  int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth();
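  // For example, f16 gives a packing factor of 32/16 = 2 and i8 gives
  // 32/8 = 4, while a 32-bit element such as f32 gives 1.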
  if (rank != 1 && rank != 2)
    return emitError() << "expected 1D or 2D tensor";

  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
  if (scatterAttr) {
    // Expected tensor ranks for scattered data:
    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
    //   - 2D tensor for scattered blocks (chunk size > 1)
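    // For instance (illustrative forms only; the exact attribute syntax is
    // defined by the scatter descriptor's assembly format):
    //   tensor_desc<16xf32,   #xegpu.scatter_tdesc_attr<chunk_size = 1>>
    //   tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>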
    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
    if (rank == 1 && chunkSize != 1)
      return emitError() << "expected non-contiguous elements for 1D tensor";
    if (rank == 2 && chunkSize < 2)
      return emitError() << "expected chunk blocks for 2D tensor";
    // If chunk size > 1, the second dimension of the tensor shape must be
    // equal to chunk size and it must be a multiple of the packing factor.
    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected tensor shape[1] to match chunk size";
      if (shape.back() % packingFactor != 0)
        return emitError()
               << "expected tensor shape[1] to be a multiple of packing factor "
               << packingFactor;
    }
  }

  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (blockAttr) {
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank == 2 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is not supported for 2D block tensor";
  }

  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
  if (layoutAttr) {
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

    auto laneData = layoutAttr.getLaneData();
    if (scatterAttr && laneData) {
      // Validate subgroup mapping rules for scattered tensors.
      // A work-item's slice of the tensor with shape [sg_size] or
      // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
      // respectively, and the mapping should reflect that. This is because
      // each work item accesses data at 32-bit granularity.
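      // For example (values chosen for illustration), with f16 elements the
      // packing factor is 2, so the expected lane_data is [1, 2] for a 2D
      // scattered tensor and [2] for a 1D one.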

      if (rank > 1 && laneData[0] != 1)
        return emitError()
               << "cannot map over non-contiguous scattered row elements";
      if (laneData[rank - 1] != packingFactor)
        return emitError() << "work item data mapping must match the number of "
                              "contiguous elements";
    }

    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
      std::string shapeStr;
      llvm::raw_string_ostream stream(shapeStr);
      llvm::interleaveComma(shape, stream);
      return emitError() << "cannot distribute [" << shapeStr << "] using "
                         << layoutAttr;
    }
  }
  return success();
}

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>