1//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "mlir/Dialect/GPU/IR/GPUDialect.h"
9#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
10#include "mlir/Dialect/MemRef/IR/MemRef.h"
11#include "mlir/Dialect/Vector/IR/VectorOps.h"
12#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
13#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
14#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
15#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
16#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
17#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
18#include "mlir/IR/AffineMap.h"
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/TypeRange.h"
27#include "mlir/IR/Value.h"
28#include "mlir/IR/Visitors.h"
29#include "mlir/Interfaces/FunctionInterfaces.h"
30#include "mlir/Support/LLVM.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33#include "mlir/Transforms/InliningUtils.h"
34#include "llvm/ADT/ArrayRef.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SmallVector.h"
37
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;

static const char *const resolveSIMTTypeMismatch =
    "resolve_simt_type_mismatch"; // Attribute name for identifying
                                  // UnrealizedConversionCastOps added to
                                  // resolve SIMT type mismatches.
54
55namespace {
56
57//===----------------------------------------------------------------------===//
58// SIMT Distribution Patterns
59//===----------------------------------------------------------------------===//
60
61/// Helper function to get distributed vector type for a source vector type
62/// according to the lane_layout. We simply divide each dimension of tensor
63/// descriptor shape by corresponding lane_layout dimension. If
64/// array_length > 1, that is appended to the front of the ditributed shape.
65/// NOTE: This is the vector type that will be returned by the
66/// gpu.warp_execute_on_lane0 op.
67///
68/// Examples:
69/// | original vector shape | lane_layout | distributed vector shape |
70/// |-----------------------|-------------|--------------------------|
71/// | 32x16 | [1, 16] | 32x1 |
72/// | 32x16 | [2, 8] | 16x2 |
73/// | 2x32x16 | [1, 16] | 2x32x1 |
74static FailureOr<VectorType>
75getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
76 VectorType originalType) {
77 if (!layout)
78 return failure();
79
80 auto laneLayout = layout.getLaneLayout().asArrayRef();
81 assert(originalType.getShape().size() >= laneLayout.size() &&
82 "Rank of the original vector type should be greater or equal to the "
83 "size of the lane layout to distribute the vector type.");
84 SmallVector<int64_t> distributedShape(originalType.getShape());
85 // Only distribute the last `laneLayout.size()` dimensions. The remaining
86 // dimensions are not distributed.
87 unsigned distributionStart = originalType.getRank() - laneLayout.size();
88 for (auto [i, dim] : llvm::enumerate(First: originalType.getShape())) {
89 if (i < distributionStart)
90 continue;
91
92 // Check if the dimension can be distributed evenly.
93 if (dim % laneLayout[i - distributionStart] != 0)
94 return failure();
95 distributedShape[i] = dim / laneLayout[i - distributionStart];
96 }
97 return VectorType::get(shape: distributedShape, elementType: originalType.getElementType());
98}
99
100/// Helper function to resolve types if the distributed type out of
101/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
102/// Example 1:
103/// distributed type: vector<8x1xf32>
104/// expected type: vector<8xf32>
105/// resolved using,
106/// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
107/// Example 2:
108/// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
109/// expected type: xegpu.tensor_desc<8x16xf32>
110/// resolved using,
111/// %0 = unrealized_conversion_cast %1 :
112/// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
113/// xegpu.tensor_desc<8x16xf32>
114template <typename T>
115static Value resolveDistributedTy(Value orig, T expected,
116 PatternRewriter &rewriter) {
117 // If orig and expected types are the same, return orig.
118 if (orig.getType() == expected)
119 return orig;
120 // If orig is a vector type, create a shape cast op to reconcile the types.
121 if (isa<VectorType>(Val: orig.getType())) {
122 auto castOp =
123 rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
124 return castOp.getResult();
125 }
126 // If orig is a tensor descriptor type, create an unrealized conversion cast
127 // op to reconcile the types.
128 if (isa<xegpu::TensorDescType>(Val: orig.getType())) {
129 auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
130 expected, orig);
131 castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr());
132 return castOp.getResult(0);
133 }
134 llvm_unreachable("Unsupported type for reconciliation");
135 return orig;
136}
137
138/// Helper function to filter out the temporary layout attributes attached
139/// during the layout assignment process. These are not needed after going to
140/// SIMT.
141static SmallVector<NamedAttribute>
142removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
143 SmallVector<NamedAttribute> newAttrs;
144 for (NamedAttribute attr : attrs) {
145 if (!isa<xegpu::LayoutAttr>(Val: attr.getValue()))
146 newAttrs.push_back(Elt: attr);
147 }
148 return newAttrs;
149}
150
151/// Helper function to check if the layout is packed. Layout is packed if it is
152/// 2D and lane_data[0] != 1 (data packed from col dimension).
153static bool hasPackedLayout(xegpu::LayoutAttr layout) {
154 if (layout == xegpu::LayoutAttr())
155 return false;
156 DenseI32ArrayAttr laneData = layout.getLaneData();
157 if (!laneData || laneData.size() != 2)
158 return false;
159 return laneData.asArrayRef()[0] != 1;
160}
161
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
/// Example:
///
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   ...
///   ...
///   gpu.return %result: vector<8x16xf32>
/// }
/// ```
/// To
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   %laneid = gpu.lane_id : index
///   %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
///     ...
///     ...
///     gpu.yield %result: vector<8x16xf32>
///   }
///   return %0
/// }
/// ```
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function already moved inside a warp_execute_on_lane0, skip.
    // This keeps the pattern from re-triggering on its own output.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature.
    auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
    // Create a WarpExecuteOnLane0Op with same arguments and results as the
    // original gpuFuncOp. The warp size comes from the XeGPU target info;
    // the lane id is materialized inside the new function.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = rewriter.create<gpu::LaneIdOp>(
        newGpuFunc.getLoc(), rewriter.getIndexType(),
        /** upperBound = **/ mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
        laneId.getLoc(), gpuFuncResultType, laneId,
        xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
        newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp, so
    // that the body can live inside the warp op region.
    auto origRetunOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origRetunOp);
    rewriter.create<gpu::YieldOp>(origRetunOp.getLoc(),
                                  origRetunOp.getOperands());
    rewriter.eraseOp(origRetunOp);
    // Move the original function body to the WarpExecuteOnLane0Op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    // The block that was created with the warp op is now redundant.
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the WarpExecuteOnLane0Op that forwards the
    // warp op results.
    rewriter.setInsertionPointAfter(warpOp);
    rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};
233
/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
/// still contain the original op that will not be used by the yield op (and
/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
/// arguments. Tensor descriptor shape is not distributed because it is a
/// uniform value across all work items within the subgroup. However, the
/// layout information is dropped in the new tensor descriptor type.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///     vector.yield %td
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
///     ...
///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///     vector.yield %arg0, %dead
///   }
///   %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
///                                 -> !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp op result that is produced by a create_nd_tdesc.
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldValues;
    SmallVector<Type> newYieldTypes;

    // Yield every operand of the create_nd_tdesc from the warp op so the
    // descriptor can be rebuilt outside of it.
    for (Value operand : descOp->getOperands()) {
      newYieldValues.push_back(operand);
      newYieldTypes.push_back(operand.getType());
    }
    rewriter.setInsertionPoint(subgroupOp);
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
        /* new yielded types = */ newYieldTypes, newRetIndices);

    // Collect the yielded operands outside the new warp op.
    SmallVector<Value> newDescOperands;
    for (size_t i : newRetIndices) {
      newDescOperands.push_back(newWarpOp.getResult(i));
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                        // does not contain layout info.
    Value newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type to the expected type.
    newDescOp =
        resolveDistributedTy(newDescOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};
316
/// Distribute a store_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
/// through the warp op interface they would be propagated as returned values.
/// Source vector is distributed based on lane layout. Appropriate cast ops
/// are inserted if the distributed types do not match the expected xegpu SIMT
/// types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   gpu.warp_execute_on_lane_0(%laneid) -> () {
///     ...
///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
///                                 !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
///     gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
///     #layout0>
///   }
///   %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
///   %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///   #layout0>
///     -> !xegpu.tensor_desc<4x8xf32>
///   xegpu.store_nd %0, %1: vector<4xf32>,
///     !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    // Only match a store_nd that is the last op before the warp terminator.
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    // Distributed type of the stored value as seen by the warp op.
    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /* new yielded values = */
        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
        /* new yielded types = */
        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
        newRetIndices);
    // Create a new store op outside the warp op with the distributed vector
    // type. Tensor descriptor is not distributed.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // For the value operand, there can be a mismatch between the vector type
    // distributed by the warp op and (xegpu-specific) distributed type
    // supported by the store op. Type mismatch must be resolved using
    // appropriate cast op.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // For the tensor descriptor operand, the layout attribute is dropped after
    // distribution. Types needs to be resolved in this case also.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));

    rewriter.create<xegpu::StoreNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
    rewriter.eraseOp(storeOp);
    return success();
  }
};
413
/// Distribute a load_nd op feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the load's arguments. Only the loaded vector is distributed
/// according to lane layout; the tensor descriptor type is not
/// distributed. Appropriate cast ops are inserted if the distributed types do
/// not match the expected xegpu SIMT types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<4x1xf32>) {
///     ...
///     %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0>
///     ->
///       vector<4x8xf32>
///     gpu.yield %ld
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> ->
///     vector<4x8xf32> gpu.yield %dead, %arg0
///   }
///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///        #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
///   %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
///
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::LoadNd op");
    // Make sure the load op is the last operation in the warp op body. This
    // ensures that the load op is not sunk earlier, violating any barrier
    // synchronizations.
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
      return failure();

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    // Type the warp op chose for the distributed load result.
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /* new yielded values = */ loadOp.getTensorDesc(),
        /* new yielded types = */ tensorDescTy, newRetIndices);

    // Create a new load op outside the warp op with the distributed vector
    // type.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
                                                  // descriptor type does not
                                                  // contain layout info.
    auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
        newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter),
        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(hasPackedLayout(layout));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the
    // warp op and (xegpu-specific) distributed type supported by the load
    // op. Resolve these mismatches by inserting a cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
514
/// Distribute a dpas op feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
/// distributed types do not match the expected xegpu SIMT types.
/// Example:
/// ```
///   #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
///   #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
///   #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (vector<8x1xf32>) {
///     ...
///     %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
///       vector<8x16xf32>
///     gpu.yield %dpas
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
///   vector<8x1xf16>, vector<16x1xf16>) {
///     ...
///     %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
///       -> vector<8x16xf32>
///     gpu.yield %dead, %arg0, %arg1
///   }
///   %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
///   %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
///   %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
///     vector<8xf32>
///   %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(subgroupOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // Layouts for A (operand 0), B (operand 1) and the result are carried as
    // named attributes on the dpas op.
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));

    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    // Types that the warp op will assign to the yielded A, B and result.
    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // Dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);

    // Types the xegpu SIMT dpas op itself expects for its operands/result.
    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;

    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

    // Cast each warp-op result to the type the SIMT dpas op expects.
    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    newDpasOp = resolveDistributedTy(
        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
    return success();
  }
};
644
/// Sink an update_nd_offset op feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
/// original op that will not be used by the yield op (and should be cleaned
/// up later). The yield op will bypass the updateOp's arguments. The tensor
/// descriptor type is not distributed. Appropriate cast ops are inserted if
/// the distributed types do not match the expected xegpu SIMT types.
/// Example:
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///     gpu.yield %update
///   }
///   ...
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
///     !xegpu.tensor_desc<4x8xf32, #layout0>,
///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
///     ...
///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///     gpu.yield %dead, %arg0, %c32, %c16
///   }
///   %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///        #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
///     !xegpu.tensor_desc<4x8xf32>
///   ...
/// ```
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // new update op does not have layout attribute.
    xegpu::TensorDescType newTensorDescTy =
        updateOp.getTensorDescType().dropLayouts();

    // Yield all update op operands from the warp op. The tensor descriptor
    // keeps its shape (only the layout is dropped); offsets pass through.
    SmallVector<Value, 3> newYieldValues;
    SmallVector<Type, 3> newYieldTypes;
    for (Value operand : updateOp->getOperands()) {
      newYieldValues.push_back(operand);
      if (isa<xegpu::TensorDescType>(operand.getType())) {
        newYieldTypes.push_back(newTensorDescTy);
      } else {
        newYieldTypes.push_back(operand.getType());
      }
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newUpdateOperands;
    for (size_t i : newRetIndices) {
      // For the tensor descriptor operand, the layout attribute is dropped
      // after distribution. Types needs to be resolved in this case.
      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
        newUpdateOperands.push_back(resolveDistributedTy(
            newWarpOp.getResult(i), newTensorDescTy, rewriter));
      } else {
        newUpdateOperands.push_back(newWarpOp.getResult(i));
      }
    }
    // Create a new update op outside the warp op.
    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the distributed type with the original type.
    newUpdateOp =
        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
    return success();
  }
};
731
732/// Distribute a prefetch_nd op at the end of enclosing
733/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
734/// through the warp op interface they would be propagated as returned values.
735/// Tensor descriptor shape is not distributed because it is a uniform value
736/// across all work items within the subgroup. Appropriate cast ops are inserted
/// if the distributed types do not match the expected xegpu SIMT types.
738///
739/// Example:
740///
741/// ```
742/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
743/// gpu.warp_execute_on_lane_0(%laneid) -> () {
744/// ...
745/// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
746/// }
747/// ```
748/// To
749/// ```
750/// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
751/// !xegpu.tensor_desc<4x8xf32, #layout0>) {
752/// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
753/// }
754/// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
755/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
756/// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
757///
758/// ```
759struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
760 using gpu::WarpDistributionPattern::WarpDistributionPattern;
761 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
762 PatternRewriter &rewriter) const override {
763 auto yield = cast<gpu::YieldOp>(
764 Val: subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
765 Operation *lastNode = yield->getPrevNode();
766 auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(Val: lastNode);
767 if (!prefetchOp)
768 return failure();
769 xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
770 if (!layout)
771 return rewriter.notifyMatchFailure(
772 arg&: prefetchOp, msg: "the source tensor descriptor lacks layout attribute");
773
774 SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
775 SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
776 SmallVector<size_t> newRetIndices;
777 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
778 rewriter, warpOp: subgroupOp, newYieldedValues: newYieldValues, newReturnTypes: newYieldTypes, indices&: newRetIndices);
779 // Create a new prefetch op outside the warp op with updated tensor
780 // descriptor type. Source tensor descriptor require type resolution.
781 xegpu::TensorDescType newTensorDescTy =
782 prefetchOp.getTensorDescType().dropLayouts();
783 rewriter.setInsertionPointAfter(newWarpOp);
784 SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
785 orig: newWarpOp.getResult(i: newRetIndices[0]), expected: newTensorDescTy, rewriter)};
786 rewriter.create<xegpu::PrefetchNdOp>(
787 location: newWarpOp.getLoc(), args: TypeRange{}, args&: newPrefetchOperands,
788 args: removeTemporaryLayoutAttributes(attrs: prefetchOp->getAttrs()));
789 rewriter.eraseOp(op: prefetchOp);
790 return success();
791 }
792};
793
794/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
795/// region. This will simply move the barrier op outside of the warp op.
796struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
797 using gpu::WarpDistributionPattern::WarpDistributionPattern;
798 LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
799 PatternRewriter &rewriter) const override {
800 auto yield = cast<gpu::YieldOp>(
801 Val: subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
802 Operation *lastNode = yield->getPrevNode();
803 // The last node must be a gpu::BarrierOp.
804 auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(Val: lastNode);
805 if (!barrierOp)
806 return failure();
807 // Move the barrier op outside of the warp op.
808 rewriter.setInsertionPointAfter(subgroupOp);
809 rewriter.create<gpu::BarrierOp>(
810 location: barrierOp.getLoc(), args: barrierOp->getResultTypes(),
811 args: barrierOp->getOperands(), args: barrierOp->getAttrs());
812 rewriter.eraseOp(op: barrierOp);
813 return success();
814 }
815};
816
817} // namespace
818
819namespace {
/// Pass that distributes XeGPU subgroup-level (SIMD) operations to
/// workitem-level (SIMT) operations. The actual work is done in
/// runOnOperation, defined out-of-line below.
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};
825} // namespace
826
827void xegpu::populateXeGPUSubgroupDistributePatterns(
828 RewritePatternSet &patterns) {
829 patterns.add<CreateNdDescDistribution, StoreNdDistribution,
830 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
831 UpdateNdOffsetDistribution, GpuBarrierDistribution>(
832 arg: patterns.getContext());
833}
834
835void XeGPUSubgroupDistributePass::runOnOperation() {
836 // Step 1: Attach layouts to op operands.
837 // TODO: Following assumptions are made:
838 // 1) It is assumed that there are no layout conflicts.
839 // 2) Any existing layout attributes attached to the operands are ignored.
840 Operation *op = getOperation();
841 op->walk(callback: [&](Operation *op) {
842 for (OpOperand &operand : op->getOpOperands()) {
843 // Layouts are needed for vector type only.
844 if (!isa<VectorType>(Val: operand.get().getType()))
845 continue;
846
847 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr: operand);
848 if (!layout) {
849 op->emitError(message: "Could not find layout attribute for operand ")
850 << operand.getOperandNumber() << " of operation " << op->getName();
851 signalPassFailure();
852 return;
853 }
854 xegpu::setLayoutAttr(operandOrResult: operand, layout);
855 }
856 });
857 // Step 2: Move all operations of a GPU function inside
858 // gpu.warp_execute_on_lane_0 operation.
859 {
860 RewritePatternSet patterns(&getContext());
861 patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(arg: &getContext());
862
863 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) {
864 signalPassFailure();
865 return;
866 }
867 // At this point, we have moved the entire function body inside the
868 // warpOp. Now move any scalar uniform code outside of the warpOp (like
869 // GPU index ops, scalar constants, etc.). This will simplify the
870 // later lowering and avoid custom patterns for these ops.
871 getOperation()->walk(callback: [&](Operation *op) {
872 if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(Val: op))
873 vector::moveScalarUniformCode(op: warpOp);
874 });
875 }
876 // Step 3: Apply subgroup to workitem distribution patterns.
877 RewritePatternSet patterns(&getContext());
878 xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
879 // distributionFn is used by vector distribution patterns to determine the
880 // distributed vector type for a given vector value. In XeGPU subgroup
881 // distribution context, we compute this based on lane layout.
882 auto distributionFn = [](Value val) {
883 VectorType vecType = dyn_cast<VectorType>(Val: val.getType());
884 int64_t vecRank = vecType ? vecType.getRank() : 0;
885 if (vecRank == 0)
886 return AffineMap::get(context: val.getContext());
887 // Get the layout of the vector type.
888 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value: val);
889 // If no layout is specified, assume the inner most dimension is distributed
890 // for now.
891 if (!layout)
892 return AffineMap::getMultiDimMapWithTargets(
893 numDims: vecRank, targets: {static_cast<unsigned int>(vecRank - 1)}, context: val.getContext());
894 SmallVector<unsigned int> distributedDims;
895 // Get the distributed dimensions based on the layout.
896 ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
897 for (unsigned i = 0; i < laneLayout.size(); ++i) {
898 if (laneLayout[i] > 1)
899 distributedDims.push_back(Elt: i);
900 }
901 return AffineMap::getMultiDimMapWithTargets(numDims: vecRank, targets: distributedDims,
902 context: val.getContext());
903 };
904 // TODO: shuffleFn is not used.
905 auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
906 int64_t warpSz) { return Value(); };
907 vector::populatePropagateWarpVectorDistributionPatterns(
908 pattern&: patterns, distributionMapFn: distributionFn, warpShuffleFromIdxFn: shuffleFn);
909 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) {
910 signalPassFailure();
911 return;
912 }
913
914 // Step 4: Finllay, clean up UnrealizedConversionCastOps that were inserted
915 // due to tensor desc type mismatches created by using upstream distribution
916 // patterns (scf.for)
917 getOperation()->walk(callback: [&](mlir::UnrealizedConversionCastOp op) {
918 // We are only interested in UnrealizedConversionCastOps there were added
919 // for resolving SIMT type mismatches.
920 if (!op->getAttr(name: resolveSIMTTypeMismatch))
921 return WalkResult::skip();
922
923 Value input = op.getOperand(i: 0);
924 Value output = op.getResult(i: 0);
925
926 // Both input and output must have tensor descriptor types.
927 xegpu::TensorDescType inputDescType =
928 mlir::dyn_cast<xegpu::TensorDescType>(Val: input.getType());
929 xegpu::TensorDescType outputDescType =
930 mlir::dyn_cast<xegpu::TensorDescType>(Val: output.getType());
931 assert(inputDescType && outputDescType &&
932 "Unrealized conversion cast must have tensor descriptor types");
933
934 // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
935 // This occurs iside scf.for body to resolve the block argument type to
936 // SIMT type.
937 if (inputDescType.getLayout()) {
938 auto argument = mlir::dyn_cast<mlir::BlockArgument>(Val&: input);
939 if (argument) {
940 argument.setType(output.getType());
941 output.replaceAllUsesWith(newValue: argument);
942 if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>(
943 Val: argument.getOwner()->getParentOp())) {
944 auto result = loopOp.getTiedLoopResult(bbArg: argument);
945 result.setType(output.getType());
946 }
947 }
948 }
949
950 // tensor_desc<shape> -> tensor_desc<shape, layout> Type of
951 // conversions. This occurs at the yield op of scf.for body to go back
952 // from SIMT type to original type.
953 if (outputDescType.getLayout())
954 output.replaceAllUsesWith(newValue: input);
955
956 if (op->use_empty())
957 op->erase();
958 return WalkResult::advance();
959 });
960}
961

source code of mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp