1//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
9
10#include "mlir/Dialect/Affine/Utils.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Arith/Utils/Utils.h"
13#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14#include "mlir/Dialect/Index/IR/IndexDialect.h"
15#include "mlir/Dialect/Index/IR/IndexOps.h"
16#include "mlir/Dialect/Math/IR/Math.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/SCF/Transforms/Patterns.h"
19#include "mlir/Dialect/Utils/IndexingUtils.h"
20#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
21#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
22#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include <optional>
25
26namespace mlir {
27namespace xegpu {
28#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
29#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30} // namespace xegpu
31} // namespace mlir
32
33using namespace mlir;
34
35namespace {
36
37static std::pair<SmallVector<int64_t>, int>
38getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
39 int count = 1;
40 SmallVector<int64_t> sgShape(shape);
41
42 if (layout && layout.isWgLayout()) {
43 DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
44 auto sgLayout = llvm::to_vector_of<int64_t>(Range: sgLayoutAttr.asArrayRef());
45 if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
46 sgShape = llvm::to_vector_of<int64_t>(Range: sgDataAttr.asArrayRef());
47 else
48 sgShape = computeShapeRatio(shape, subShape: sgLayout).value_or(u&: sgShape);
49 SmallVector<int64_t> distUnit = computeElementwiseMul(v1: sgLayout, v2: sgShape);
50 // Clamp distUnit to the original shape to handle cases where data is
51 // shared among subgroups, which may cause distUnit to exceed the original
52 // shape.
53 for (size_t i = 0; i < distUnit.size(); ++i)
54 distUnit[i] = std::min(a: shape[i], b: distUnit[i]);
55 count = computeProduct(basis: shape) / computeProduct(basis: distUnit);
56 }
57 return std::make_pair(x&: sgShape, y&: count);
58}
59
60/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
61/// from a workgroup descriptor. It replaces the offsets and sizes with
62/// appropriate values for the subgroup.
63/// It uses round-robin assignment to distribute the work to the subgroups.
/// Following create_nd_desc operation:
65/// %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
66/// -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
67/// sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
68/// is converted to 9 subgroup level operations based on the sg_layout &
69/// sg_data:
70/// %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
71/// !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
72/// lane_data = [1, 1]>>
73///
74/// The sg_layout and sg_data attributes are dropped after the pass as they are
75/// no longer needed.
76///
77/// 24x24 matrix distribution example:
78/// sg_layout = [4, 4], sg_data = [2, 2]
79/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
80/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
81///
82/// +------------------------+
83/// | 8x8 | 8x8 | 8x8 | <- 3 tiles across
84/// |-----+-----+-----|
85/// | 8x8 | 8x8 | 8x8 | <- 3 tiles down
86/// |-----+-----+-----|
87/// | 8x8 | 8x8 | 8x8 |
88/// +------------------------+
89///
90/// Each 8x8 tile is further subdivided among subgroups:
91/// +------------------------+
92/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups across (each handles 2 columns)
93/// | 2x2 2x2 2x2 2x2 | <- 4 subgroups down (each handles 2 rows)
94/// | 2x2 2x2 2x2 2x2 |
95/// | 2x2 2x2 2x2 2x2 |
96/// +------------------------+
97///
98/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
99/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
100
101/// The pass currently has entire distribution logic in the WgToSgCreateNdOp
102/// pattern and all the other ops just follow.
103/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
104/// ops in the pass.
105struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
106 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
107
108 // Calculate offset for each subgroup
109 SmallVector<OpFoldResult>
110 calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
111 const SmallVector<OpFoldResult> &originalOffsets,
112 const SmallVector<Value> &localOffset,
113 const SmallVector<int64_t> &distUnitBaseAddr,
114 const SmallVector<int64_t> &distUnitShape) const {
115 assert(localOffset.size() == distUnitBaseAddr.size() &&
116 "localOffset and distUnitBaseAddr must have the same rank");
117
118 SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
119 originalOffsets.end());
120 size_t rank = localOffset.size();
121 for (size_t i = 0; i < rank; ++i) {
122 size_t dimIdx = originalOffsets.size() - rank + i;
123 Value constOffset =
124 rewriter.create<arith::ConstantIndexOp>(location: loc, args: distUnitBaseAddr[i]);
125 Value offset =
126 rewriter.createOrFold<index::AddOp>(location: loc, args: localOffset[i], args&: constOffset);
127 Value modValue =
128 rewriter.create<arith::ConstantIndexOp>(location: loc, args: distUnitShape[i]);
129 Value offsetMod =
130 rewriter.createOrFold<index::RemUOp>(location: loc, args&: offset, args&: modValue);
131 Value origOffset = getValueOrCreateConstantIndexOp(
132 b&: rewriter, loc, ofr: originalOffsets[dimIdx]);
133 Value globalOffset =
134 rewriter.createOrFold<index::AddOp>(location: loc, args&: origOffset, args&: offsetMod);
135 globalOffsets[dimIdx] = globalOffset;
136 }
137
138 return globalOffsets;
139 }
140
141 LogicalResult
142 matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter) const override {
144 Location loc = op.getLoc();
145 MLIRContext *ctx = op.getContext();
146 xegpu::TensorDescType tdescTy = op.getType();
147 auto layout = dyn_cast<xegpu::LayoutAttr>(Val: tdescTy.getLayout());
148 if (!layout)
149 return failure();
150 Type elemTy = tdescTy.getElementType();
151 ArrayRef<int64_t> wgShape = tdescTy.getShape();
152 // sgLayout must be present for workgroup-level distribution.
153 SmallVector<int64_t> sgLayout;
154 if (auto sgLayoutAttr = layout.getSgLayout())
155 sgLayout = llvm::to_vector_of<int64_t>(Range: sgLayoutAttr.asArrayRef());
156 else
157 return rewriter.notifyMatchFailure(
158 arg&: op, msg: "sgLayout attribute is required in layout");
159
160 SmallVector<int64_t> sgShape = getSgShapeAndCount(shape: wgShape, layout).first;
161
162 // TODO : Handle order attribute
163 // Get the subgroup ID
164 auto linearSgId =
165 rewriter.create<gpu::SubgroupIdOp>(location: loc, /*upper_bound=*/args: nullptr);
166
167 // Create constants for layout dimensions
168 SmallVector<Value> sgLayoutDim(sgLayout.size());
169 SmallVector<Value> sgDataDim(sgShape.size());
170
171 for (size_t i = 0; i < sgLayout.size(); i++) {
172 sgLayoutDim[i] =
173 rewriter.create<arith::ConstantIndexOp>(location: loc, args&: sgLayout[i]);
174 sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: sgShape[i]);
175 }
176
177 auto deLinearizeSgId =
178 affine::delinearizeIndex(b&: rewriter, loc, linearIndex: linearSgId, basis: sgLayoutDim);
179 if (failed(Result: deLinearizeSgId))
180 return failure();
181 SmallVector<Value> sgIds = *deLinearizeSgId;
182
183 // Calculate distribution unit shape and local offsets for subgroup
184 SmallVector<int64_t> distUnitShape(sgLayout.size());
185 SmallVector<Value> localOffset(sgLayout.size());
186 for (size_t i = 0; i < sgLayout.size(); i++) {
187 distUnitShape[i] = std::min(a: sgLayout[i] * sgShape[i], b: wgShape[i]);
188 localOffset[i] =
189 rewriter.createOrFold<index::MulOp>(location: loc, args&: sgIds[i], args&: sgDataDim[i]);
190 }
191
192 SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
193
194 xegpu::TensorDescType newTdescTy =
195 xegpu::TensorDescType::get(context: ctx, shape: sgShape, elementType: elemTy, encoding: tdescTy.getEncoding(),
196 layout: layout.dropSgLayoutAndData());
197 SmallVector<Value> newCreateNdOps;
198 for (SmallVector<int64_t> distUnitBaseAddr :
199 StaticTileOffsetRange(wgShape, distUnitShape)) {
200 SmallVector<OpFoldResult> globalOffsets =
201 calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
202 distUnitBaseAddr, distUnitShape);
203
204 auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
205 location: loc, args&: newTdescTy, args: op.getSource(), args&: globalOffsets, args: op.getMixedSizes(),
206 args: op.getMixedStrides());
207 newCreateNdOps.push_back(Elt: newCreateNdOp);
208 }
209
210 rewriter.replaceOpWithMultiple(op, newValues: {newCreateNdOps});
211 return success();
212 }
213};
214
215/// This pattern transforms the LoadNdOp to load subgroup data.
216struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
217 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
218 LogicalResult
219 matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
220 ConversionPatternRewriter &rewriter) const override {
221 SmallVector<Value> newLoadOps;
222 for (auto src : adaptor.getTensorDesc()) {
223 xegpu::TensorDescType tdescTy =
224 dyn_cast<xegpu::TensorDescType>(Val: src.getType());
225 ArrayRef<int64_t> srcShape = tdescTy.getShape();
226 VectorType newResTy = VectorType::get(shape: srcShape, elementType: tdescTy.getElementType());
227 auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(location: op.getLoc(), args&: newResTy,
228 args&: src, args: op->getAttrs());
229 newLoadOps.push_back(Elt: newLoadOp);
230 }
231 rewriter.replaceOpWithMultiple(op, newValues: {newLoadOps});
232 return mlir::success();
233 }
234};
235
236/// This pattern transforms the StoreNdOp to store to a subgroup descriptor
237/// It creates a StoreNdOp op to store the updated values to the new subgroup
238/// src tensor descriptors.
239struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
240 using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
241 LogicalResult
242 matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
243 ConversionPatternRewriter &rewriter) const override {
244 for (auto [v, t] : llvm::zip(t: adaptor.getValue(), u: adaptor.getTensorDesc()))
245 rewriter.create<xegpu::StoreNdOp>(location: op.getLoc(), args&: v, args&: t, args: op.getL1HintAttr(),
246 args: op.getL2HintAttr(), args: op.getL3HintAttr());
247
248 rewriter.eraseOp(op);
249 return success();
250 }
251};
252
253/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
254/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
255/// offsets of the new subgroup src tensor descriptors.
256struct WgToSgUpdateNdOffsetOp
257 : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
258 using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
259 LogicalResult
260 matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter) const override {
262 llvm::SmallVector<Value> newUpdateTileOffsetOps;
263 for (auto tDesc : adaptor.getTensorDesc()) {
264 auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
265 location: op.getLoc(), args: tDesc.getType(), args&: tDesc, args: op.getOffsets(),
266 args: op.getConstOffsets());
267 newUpdateTileOffsetOps.push_back(Elt: newUpdateTileOffsetOp);
268 }
269
270 rewriter.replaceOpWithMultiple(op, newValues: {newUpdateTileOffsetOps});
271 return success();
272 }
273};
274
275/// This pattern transforms the DpasOp to work at subgroup level.
276struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
277 using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
278 LogicalResult
279 matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
280 ConversionPatternRewriter &rewriter) const override {
281 Location loc = op.getLoc();
282 VectorType resultTy = op.getResult().getType();
283 if (resultTy.getRank() != 2)
284 return failure();
285
286 auto originalLayout = xegpu::getLayoutAttr(value: op.getResult());
287 if (!originalLayout)
288 return failure();
289
290 size_t i = 0;
291 SmallVector<Value> newDpasOps;
292 for (auto aVec : adaptor.getLhs()) {
293 for (auto bVec : adaptor.getRhs()) {
294
295 llvm::SmallVector<Value> operands({aVec, bVec});
296 Value tmpC;
297 if (op.getAcc()) {
298 tmpC = adaptor.getAcc()[i++];
299 operands.push_back(Elt: tmpC);
300 }
301
302 ArrayRef<int64_t> aVecShape =
303 llvm::cast<VectorType>(Val: aVec.getType()).getShape();
304 ArrayRef<int64_t> bVecShape =
305 llvm::cast<VectorType>(Val: bVec.getType()).getShape();
306 VectorType resTy = VectorType::get(shape: {aVecShape[0], bVecShape[1]},
307 elementType: resultTy.getElementType());
308 tmpC = rewriter.create<xegpu::DpasOp>(location: loc, args&: resTy, args&: operands);
309 xegpu::setLayoutAttr(operandOrResult: cast<OpResult>(Val&: tmpC),
310 layout: originalLayout.dropSgLayoutAndData());
311
312 newDpasOps.push_back(Elt: tmpC);
313 }
314 }
315 rewriter.replaceOpWithMultiple(op, newValues: {newDpasOps});
316 return success();
317 }
318};
319
320/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
321struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
322 using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
323 LogicalResult
324 matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
325 ConversionPatternRewriter &rewriter) const override {
326 for (auto src : adaptor.getTensorDesc())
327 rewriter.create<xegpu::PrefetchNdOp>(location: op.getLoc(), args: TypeRange(), args&: src,
328 args: op->getAttrs());
329 rewriter.eraseOp(op);
330 return success();
331 }
332};
333
334// This pattern transforms elementwise ops to work at subgroup level.
335struct WgToSgElementwiseOp : public ConversionPattern {
336 WgToSgElementwiseOp(MLIRContext *ctx)
337 : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
338
339 LogicalResult
340 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
341 ConversionPatternRewriter &rewriter) const override {
342 // Only match ops with elementwise trait and single result.
343 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
344 return failure();
345
346 auto resultType = dyn_cast<VectorType>(Val: op->getResult(idx: 0).getType());
347 assert(resultType && "Expected result to be a VectorType");
348
349 ArrayRef<int64_t> wgShape = resultType.getShape();
350
351 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value: op->getResult(idx: 0));
352 if (!layout || !layout.getSgLayout())
353 return failure();
354
355 SmallVector<int64_t> sgShape = getSgShapeAndCount(shape: wgShape, layout).first;
356
357 size_t numVariants = operands.empty() ? 0 : operands.front().size();
358
359 if (llvm::any_of(Range&: operands, P: [&](const ValueRange &operandVec) {
360 return operandVec.size() != numVariants;
361 }))
362 return failure();
363
364 SmallVector<Value> newResults;
365 VectorType newResultType =
366 VectorType::get(shape: sgShape, elementType: resultType.getElementType());
367
368 for (size_t i = 0; i < numVariants; ++i) {
369 SmallVector<Value> opOperands;
370 for (auto &operandVec : operands)
371 opOperands.push_back(Elt: operandVec[i]);
372
373 OperationState state(op->getLoc(), op->getName());
374 state.addOperands(newOperands: opOperands);
375 state.addTypes(newTypes: newResultType);
376 // Copy all attributes, but update "layout_result_0" to drop
377 // sgLayout/sgData
378 for (auto attr : op->getAttrs()) {
379 if (auto layout = dyn_cast<xegpu::LayoutAttr>(Val: attr.getValue())) {
380 if (auto newLayout = layout.dropSgLayoutAndData())
381 state.addAttribute(name: attr.getName(), attr: newLayout);
382 } else {
383 state.addAttribute(name: attr.getName(), attr: attr.getValue());
384 }
385 }
386 Operation *newOp = rewriter.create(state);
387 newResults.push_back(Elt: newOp->getResult(idx: 0));
388 }
389
390 rewriter.replaceOpWithMultiple(op, newValues: {newResults});
391 return success();
392 }
393};
394
395// Handles UnrealizedConversionCastOp generated during
396// SCFStructuralTypeConversions (step 1). This op may appear as either a
397// target or source materialization for Vector values, e.g.:
398// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
399// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
400// it could be either 1:N or N:1 cast. In both cases, the pattern
401// simply forwards the inputs to the outputs using 1:1 or 1:N interface.
402// for example, the following scf::forOp
403// ```
404// %for = scf.for ... iter_args(%arg1 = %0)->(vector<128x128xf16>) {
405// %n = use(%arg1): vector<128x128xf16>
406// scf.yield %n : vector<128x128xf16>
407// }
408// ```
409// Could be converted to:
410// ```
411// %1 = unrealized_conversion_cast %0
412// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
413// %for:2 = scf.for ... iter_args(%arg1 = %1#1, %arg2 = %1#2)
414// -> (vector<16x16xf16>, vector<16x16xf16) {
415// %m = unrealized_conversion_cast %arg1, %arg2
416// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
417// %n = use(%m): vector<128x128xf16>
418// %b = unrealized_conversion_cast %n
419// : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
420// scf.yield %b#1, %b#2 : vector<16x16xf16>, vector<16x16xf16>
421// }
422// %cast = unrealized_conversion_cast %for:2
423// : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
424// ```
425// TODO: remove it when context-aware type converter is ready.
426struct UnrealizedConversionCastOpPattern
427 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
428 using OpConversionPattern<
429 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
430
431 mlir::LogicalResult
432 matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
433 ConversionPatternRewriter &rewriter) const override {
434 SmallVector<Value> inputs = xegpu::flattenValues(values: adaptor.getInputs());
435
436 auto inputTy = dyn_cast<VectorType>(Val: inputs[0].getType());
437 auto outputTy = dyn_cast<VectorType>(Val: op->getOpResult(idx: 0).getType());
438
439 if (!inputTy || !outputTy || !llvm::all_equal(Range: op->getResultTypes()) ||
440 !llvm::all_equal(Range: ValueRange(inputs).getTypes()))
441 return failure();
442
443 // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
444 // It is generated by source materialization (e.g., inits to scf forOp).
445 // The input values provided by the adaptor should already be distributed,
446 // and their types should correspond exactly to the result types of the
447 // operation.
448 if (op.getNumOperands() == 1 &&
449 llvm::equal(LRange: ValueRange(inputs).getTypes(), RRange: op->getResultTypes())) {
450 rewriter.replaceOp(op, newValues: inputs);
451 return success();
452 }
453
454 // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
455 // It is generated by target materialization (e.g., arguments/results
456 // of scf forOp). All input values must have the same vector type, and
457 // their shape must be evenly divisible by the output vector's shape
458 // (determined by the nature of the workgroup to subgroup distribution).
459 // TODO: it is not safe to do such forward, since such N:1 cast could be
460 // from others.
461 if (op.getNumResults() == 1 &&
462 computeShapeRatio(shape: outputTy.getShape(), subShape: inputTy.getShape())) {
463 rewriter.replaceOpWithMultiple(op, newValues: {inputs});
464 return success();
465 }
466
467 return mlir::failure();
468 }
469};
470
471} // namespace
472
473namespace mlir {
474namespace xegpu {
475void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
476 patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
477 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
478 UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
479 arg: patterns.getContext());
480}
481} // namespace xegpu
482} // namespace mlir
483
484namespace {
/// Pass that drives the workgroup-to-subgroup distribution; the actual
/// conversion logic lives in runOnOperation() below.
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
489} // namespace
490
491void XeGPUWgToSgDistributePass::runOnOperation() {
492 // Track existing UnrealizedConversionCastOps
493 SmallVector<Operation *> existingCastOps;
494 getOperation()->walk(callback: [&](UnrealizedConversionCastOp castOp) {
495 existingCastOps.push_back(Elt: castOp.getOperation());
496 });
497
498 {
499 // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
500 // VectorType operands. This first converts such operands to
501 // RankedTensorType, propagates the layout attribute into the encoding
502 // attribute, and finally converts the RankedTensorType to VectorType based
503 // on the encoding.
504
505 TypeConverter converter;
506 converter.addConversion(callback: [&](Type type) -> Type { return type; });
507 converter.addConversion(
508 callback: [&](RankedTensorType type,
509 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
510 Type elemTy = type.getElementType();
511 ArrayRef<int64_t> shape = type.getShape();
512
513 int count;
514 SmallVector<int64_t> subShape;
515 std::tie(args&: subShape, args&: count) = getSgShapeAndCount(
516 shape,
517 layout: dyn_cast_if_present<xegpu::LayoutAttr>(Val: type.getEncoding()));
518
519 auto newTy = VectorType::get(shape: subShape, elementType: elemTy);
520 result.append(NumInputs: count, Elt: newTy);
521 return success();
522 });
523
524 xegpu::doSCFStructuralTypeConversionWithTensorType(op: getOperation(),
525 converter);
526 }
527
528 // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
529 // as well as XeGPU, Arith, and Vector operations.
530 MLIRContext *ctx = &getContext();
531 RewritePatternSet patterns(ctx);
532 ConversionTarget target(*ctx);
533 TypeConverter converter;
534 converter.addConversion(callback: [&](Type type) -> Type { return type; });
535 converter.addConversion(
536 callback: [&](xegpu::TensorDescType type,
537 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
538 Type elemTy = type.getElementType();
539 ArrayRef<int64_t> shape = type.getShape();
540
541 int count;
542 SmallVector<int64_t> subShape;
543 xegpu::LayoutAttr layout = type.getLayoutAttr();
544 std::tie(args&: subShape, args&: count) = getSgShapeAndCount(shape, layout);
545
546 if (layout)
547 layout = layout.dropSgLayoutAndData();
548
549 auto newTy = xegpu::TensorDescType::get(
550 context: type.getContext(), shape: subShape, elementType: elemTy, encoding: type.getEncoding(), layout);
551 result.append(NumInputs: count, Elt: newTy);
552 return success();
553 });
554
555 auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
556 if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(Val: op))
557 return createOp.getType();
558 if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(Val: op))
559 return loadOp.getTensorDescType();
560 if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(Val: op))
561 return storeOp.getTensorDescType();
562 if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(Val: op))
563 return updateOp.getType();
564 if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(Val: op))
565 return prefetchOp.getTensorDescType();
566 return xegpu::TensorDescType();
567 };
568
569 auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
570 return !layout || !layout.isWgLayout();
571 };
572
573 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
574 xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
575 xegpu::PrefetchNdOp>(callback: [=](Operation *op) -> bool {
576 auto tdescTy = getTensorDescType(op);
577 auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(Val: tdescTy.getLayout());
578 return isLegal(layout);
579 });
580
581 target.addDynamicallyLegalOp<xegpu::DpasOp>(callback: [=](xegpu::DpasOp op) -> bool {
582 auto layout = xegpu::getLayoutAttr(value: op.getResult());
583 return isLegal(layout);
584 });
585
586 target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
587 callback: [=](Operation *op) -> std::optional<bool> {
588 // Only handle elementwise mappable ops
589 if (!OpTrait::hasElementwiseMappableTraits(op))
590 return true;
591
592 VectorType resultType =
593 dyn_cast<VectorType>(Val: op->getResult(idx: 0).getType());
594 if (!resultType)
595 return true;
596
597 // Check if all operands are vectors of the same shape
598 // TODO: Support other types.
599 for (Value operand : op->getOperands()) {
600 VectorType operandType = dyn_cast<VectorType>(Val: operand.getType());
601 if (!operandType || operandType.getShape() != resultType.getShape()) {
602 return true;
603 }
604 }
605
606 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value: op->getResult(idx: 0));
607 return isLegal(layout);
608 });
609
610 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
611 callback: [=](UnrealizedConversionCastOp op) {
612 return llvm::is_contained(Range: existingCastOps, Element: op.getOperation());
613 });
614
615 target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; });
616
617 scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns,
618 target);
619 xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
620 if (failed(
621 Result: applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns))))
622 return signalPassFailure();
623
624 // Remove sg_layout and sg_data attributes from the Layout
625 // attribute for each VectorType result of the operation.
626 // For Structured Control Flow ops, the layout is simply removed,
627 // since in 1:N case, the layout for new results are missing.
628 // Layout propagation pass will activated.
629 getOperation()->walk(callback: [](Operation *op) {
630 for (OpResult result : op->getOpResults()) {
631 std::string name = xegpu::getLayoutName(result);
632 if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
633 op->removeAttr(name);
634 if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(Val: op)) {
635 if (auto newLayout = layout.dropSgLayoutAndData())
636 op->setAttr(name, value: newLayout);
637 }
638 }
639 }
640 });
641}
642

source code of mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp