1//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/IndexingUtils.h"
10#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
11#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "llvm/ADT/TypeSwitch.h"
15
16using std::optional;
17
18namespace mlir {
19namespace xegpu {
20
void XeGPUDialect::initialize() {
  // Register the dialect's types, operations, and attributes using the
  // TableGen-generated lists pulled in by the GET_*_LIST macros below.
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
35
36// Checks if the given shape can be evenly distributed based on the layout
37// and data factors provided by the LayoutAttr.
38bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
39 xegpu::LayoutAttr attr) {
40 assert(attr && "Layout attribute is missing.");
41
42 // Checks whether the given shape can be evenly distributed using the
43 // specified layout and data attributes. If successful, it returns the work
44 // size for each compute unit; otherwise, it returns `std::nullopt`. The work
45 // size per compute unit is calculated as follows:
46 // - If `data` is null: newShape[i] = shape[i] / layout[i]
47 // - If `data` is not null: newShape[i] = data[i]
48 // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
49 // smaller than `layout[i] * data[i]`, allowing multiple compute units to
50 // share the data.
51 auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
52 DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
53 bool rr = true) -> optional<SmallVector<int64_t>> {
54 llvm::SmallVector<int64_t> newShape(shape);
55 if (layout) {
56 auto vec = llvm::to_vector_of<int64_t>(Range: layout.asArrayRef());
57 if (vec.size() != shape.size())
58 return std::nullopt;
59 auto ratio = computeShapeRatio(shape, subShape: vec);
60 if (!ratio.has_value())
61 return std::nullopt;
62 newShape = ratio.value();
63 }
64
65 if (data) {
66 auto vec = llvm::to_vector_of<int64_t>(Range: data.asArrayRef());
67 if (vec.size() != shape.size())
68 return std::nullopt;
69 auto ratio = computeShapeRatio(shape: newShape, subShape: vec);
70 if (!ratio.has_value() && rr)
71 ratio = computeShapeRatio(shape: vec, subShape: newShape);
72 if (!ratio.has_value())
73 return std::nullopt;
74
75 // if data is not null, we always return it for next phase.
76 newShape = vec;
77 }
78 return newShape;
79 };
80
81 // check the sgLayout and sgData
82 auto maybeSgShape =
83 tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
84 if (!maybeSgShape)
85 return false;
86 auto sgShape = maybeSgShape.value();
87
88 // check InstData, it neither have layout nor need round-robin
89 auto maybeInstShape =
90 tryDistribute(sgShape, nullptr, attr.getInstData(), false);
91 if (!maybeInstShape)
92 return false;
93 auto instShape = maybeInstShape.value();
94
95 // check LaneLayout and LaneData
96 auto maybeLaneShape =
97 tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
98 return maybeLaneShape.has_value();
99}
100
101//===----------------------------------------------------------------------===//
102// XeGPU_BlockTensorDescAttr
103//===----------------------------------------------------------------------===//
104BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
105 xegpu::MemorySpace memory_space,
106 int array_length,
107 bool boundary_check) {
108 auto scopeAttr = MemorySpaceAttr::get(context, value: memory_space);
109 auto lengthAttr =
110 IntegerAttr::get(type: IntegerType::get(context, width: 64), value: array_length);
111 auto boundaryAttr = BoolAttr::get(context, value: boundary_check);
112 return Base::get(ctx: context, args&: scopeAttr, args&: lengthAttr, args&: boundaryAttr);
113}
114
115//===----------------------------------------------------------------------===//
116// XeGPU_ScatterTensorDescAttr
117//===----------------------------------------------------------------------===//
118ScatterTensorDescAttr
119ScatterTensorDescAttr::get(mlir::MLIRContext *context,
120 xegpu::MemorySpace memory_space, int chunk_size) {
121 auto scopeAttr = MemorySpaceAttr::get(context, value: memory_space);
122 auto chunkSizeAttr =
123 IntegerAttr::get(type: IntegerType::get(context, width: 64), value: chunk_size);
124 return Base::get(ctx: context, args&: scopeAttr, args&: chunkSizeAttr);
125}
126
127LogicalResult ScatterTensorDescAttr::verify(
128 llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
129 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
130 int64_t chunkSize = chunk_size.getInt();
131 if (chunkSize <= 0)
132 return emitError() << "invalid chunk size";
133
134 return success();
135}
136
137//===----------------------------------------------------------------------===//
138// XeGPU_LayoutAttr
139//===----------------------------------------------------------------------===//
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // A valid layout must include at least one of sg_layout, inst_data and
  // lane_layout. sg_layout is essential for workgroup-level layouts, while
  // lane_layout is required for subgroup-level layouts.
  if (!sg_layout && !inst_data && !lane_layout) {
    return emitError()
           << "expected at least one of sg_layout, inst_data or lane_layout";
  }

  // sg_layout, inst_data, and lane_layout must all have the same rank
  // whenever more than one of them is present (checked pairwise below).
  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError()
           << "expected inst_data and lane_layout to have the same rank";
  }

  // sg_data is optional for the workgroup layout, but its presence requires
  // sg_layout, and both must have the same rank.
  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  // lane_data is optional for the subgroup layout, but its presence requires
  // lane_layout, and both must have the same rank.
  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  // order, if present, must accompany sg_layout and/or lane_layout and match
  // their rank.
  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
208
209//===----------------------------------------------------------------------===//
210// XeGPU_TensorDescType
211//===----------------------------------------------------------------------===//
212
213mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
214 llvm::SmallVector<int64_t> shape;
215 mlir::Type elementType;
216 mlir::FailureOr<mlir::Attribute> encoding;
217 mlir::FailureOr<mlir::Attribute> layout;
218
219 // Parse literal '<'
220 if (parser.parseLess())
221 return {};
222
223 auto shapeLoc = parser.getCurrentLocation();
224 if (mlir::failed(Result: parser.parseDimensionList(dimensions&: shape))) {
225 parser.emitError(loc: shapeLoc, message: "failed to parse parameter 'shape'");
226 return {};
227 }
228
229 auto elemTypeLoc = parser.getCurrentLocation();
230 if (mlir::failed(Result: parser.parseType(result&: elementType))) {
231 parser.emitError(loc: elemTypeLoc, message: "failed to parse parameter 'elementType'");
232 return {};
233 }
234
235 // parse optional attributes
236 while (mlir::succeeded(Result: parser.parseOptionalComma())) {
237 mlir::Attribute attr;
238 ParseResult res = parser.parseAttribute(result&: attr);
239 if (mlir::succeeded(Result: res)) {
240 if (mlir::isa<LayoutAttr>(Val: attr)) {
241 layout = attr;
242 continue;
243 }
244 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(Val: attr)) {
245 encoding = attr;
246 continue;
247 }
248 }
249 return {};
250 }
251
252 // Parse literal '>'
253 if (parser.parseGreater())
254 return {};
255
256 return TensorDescType::getChecked(
257 emitErrorFn: [&]() { return parser.emitError(loc: parser.getNameLoc()); },
258 ctx: parser.getContext(), args: shape, args: elementType,
259 args: encoding.value_or(u: mlir::Attribute()), args: layout.value_or(u: mlir::Attribute()));
260}
261
262void TensorDescType::print(::mlir::AsmPrinter &printer) const {
263 printer << "<";
264
265 auto shape = getShape();
266 for (int64_t dim : shape) {
267 if (mlir::ShapedType::isDynamic(dValue: dim))
268 printer << '?';
269 else
270 printer << dim;
271 printer << 'x';
272 }
273
274 printer << getElementType();
275
276 if (auto encoding = getEncoding())
277 printer << ", " << encoding;
278
279 if (auto layout = getLayout())
280 printer << ", " << layout;
281
282 printer << ">";
283}
284
285TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
286 mlir::Type elementType, int array_length,
287 bool boundary_check,
288 MemorySpace memory_space,
289 mlir::Attribute layout) {
290 auto context = elementType.getContext();
291 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
292 boundary_check);
293 return Base::get(ctx: context, args&: shape, args&: elementType, args&: attr, args&: layout);
294}
295
296TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
297 mlir::Type elementType, int chunk_size,
298 MemorySpace memory_space,
299 mlir::Attribute layout) {
300 auto context = elementType.getContext();
301 auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
302 return Base::get(ctx: context, args&: shape, args&: elementType, args&: attr, args&: layout);
303}
304
305LogicalResult TensorDescType::verify(
306 llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
307 llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
308 mlir::Attribute encoding, mlir::Attribute layout) {
309 size_t rank = shape.size();
310
311 if (rank == 0)
312 return emitError() << "expected non-zero rank tensor";
313
314 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(Val&: encoding);
315 if (blockAttr) {
316 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
317 if (rank > 1 && memorySpaceAttr &&
318 memorySpaceAttr.getValue() == MemorySpace::SLM)
319 return emitError() << "SLM is only supported for 1D block tensor";
320 }
321
322 // for gather and scatter ops, Low-precision types are packed in 32-bit units.
323 unsigned bitWidth = elementType.getIntOrFloatBitWidth();
324 int chunkAlignmentFactor =
325 bitWidth < targetinfo::packedSizeInBitsForGatherScatter
326 ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
327 : 1;
328 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(Val&: encoding);
329 if (scatterAttr) {
330 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
331 if (rank == 1 && chunkSize != 1)
332 return emitError() << "expected non-contiguous elements for 1D tensor";
333
334 // If chunk size > 1, the second dimension of the tensor shape must be
335 // equal to chunk size and it must be a multiple of the
336 // chunkAlignmentFactor.
337 if (chunkSize > 1) {
338 if (shape.back() != chunkSize)
339 return emitError() << "expected last dim of tensor to match chunk size";
340 if (shape.back() % chunkAlignmentFactor != 0)
341 return emitError() << "expected last dim of tensor to be a multiple of "
342 << chunkAlignmentFactor;
343 }
344 }
345
346 auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(Val&: layout);
347 if (layoutAttr) {
348 if (rank != (size_t)layoutAttr.getRank())
349 return emitError() << "expected layout rank to match tensor rank";
350
351 auto laneData = layoutAttr.getLaneData();
352 if (scatterAttr && laneData) {
353 // Validate subgroup mapping rules for scattered tensors.
354 // if chunkSize > 1, the last dimension of the tensor should
355 // be distributed in the units divisible by chunkAlignmentFactor.
356 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
357 if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
358 return emitError()
359 << "expected last dim of lane_data to be a multiple of: "
360 << chunkAlignmentFactor;
361 }
362
363 if (!XeGPUDialect::isEvenlyDistributable(shape, attr: layoutAttr)) {
364 std::string shapeStr;
365 llvm::raw_string_ostream stream(shapeStr);
366 llvm::interleaveComma(c: shape, os&: stream);
367 return emitError() << "cannot distribute [" << shapeStr << "] using "
368 << layoutAttr;
369 }
370 }
371 return success();
372}
373
374} // namespace xegpu
375} // namespace mlir
376
377#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
378#define GET_ATTRDEF_CLASSES
379#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
380#define GET_TYPEDEF_CLASSES
381#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
382

source code of mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp