//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

namespace {

/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with the
/// appropriate values for the subgroup.
/// It uses round-robin assignment to distribute the work among the subgroups.
/// For example, the following create_nd_tdesc operation:
///
///   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
///     -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
///        sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
///
/// is converted to 9 subgroup-level operations based on the sg_layout and
/// sg_data:
///
///   %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32>
///     -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
///        lane_data = [1, 1]>>
///
/// The sg_layout and sg_data attributes are dropped after the pass as they are
/// no longer needed.
///
/// 24x24 matrix distribution example:
///   sg_layout = [4, 4], sg_data = [2, 2]
///   Each 8x8 block within the 24x24 matrix is called a distribution unit:
///   dist_unit_shape = [8, 8]  (dist_unit_shape[i] = sg_layout[i] * sg_data[i])
///
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |   <- 3 distribution units across
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |   <- 3 distribution units down
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |
///   +-----+-----+-----+
///
/// Each 8x8 distribution unit is further subdivided among the subgroups:
///   +-----------------+
///   | 2x2 2x2 2x2 2x2 |   <- 4 subgroups across (each handles 2 columns)
///   | 2x2 2x2 2x2 2x2 |   <- 4 subgroups down (each handles 2 rows)
///   | 2x2 2x2 2x2 2x2 |
///   | 2x2 2x2 2x2 2x2 |
///   +-----------------+
///
/// Since the 24x24 matrix is divided into 8x8 distribution units, there are
/// 9 distribution units (3x3) in total, hence the 9 subgroup-level operations.
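///
/// As an illustration of the indexing used below (values follow from the
/// 24x24 example above): the linear subgroup id is delinearized over
/// sg_layout = [4, 4] in row-major order, so subgroup id 6 maps to
/// sg_ids = [1, 2]; its local offset within a distribution unit is
/// sg_ids * sg_data = [2, 4], and one create_nd_tdesc is emitted for each of
/// the 9 distribution units.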

/// The pass currently has the entire distribution logic in the
/// WgToSgCreateNdOp pattern; all the other patterns simply follow the 1:N
/// mapping it establishes.
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
/// ops in the pass.
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  // Calculate the global offsets for each subgroup's tile.
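  // For each distributed (trailing) dimension d, this computes
  //   globalOffsets[d] = originalOffsets[d] +
  //       (localOffset[d] + distUnitBaseAddr[d]) % distUnitShape[d]
  // where localOffset is the subgroup's offset within a distribution unit and
  // distUnitBaseAddr is the base of the distribution unit currently being
  // visited; offsets in any leading dimensions are left unchanged.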
  SmallVector<OpFoldResult>
  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
                         const SmallVector<OpFoldResult> &originalOffsets,
                         const SmallVector<Value> &localOffset,
                         const SmallVector<int64_t> &distUnitBaseAddr,
                         const SmallVector<int64_t> &distUnitShape) const {
    assert(localOffset.size() == distUnitBaseAddr.size() &&
           "localOffset and distUnitBaseAddr must have the same rank");

    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                            originalOffsets.end());
    size_t rank = localOffset.size();
    for (size_t i = 0; i < rank; ++i) {
      size_t dimIdx = originalOffsets.size() - rank + i;
      Value constOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
      Value offset =
          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
      Value modValue =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
      Value offsetMod =
          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
      Value origOffset = getValueOrCreateConstantIndexOp(
          rewriter, loc, originalOffsets[dimIdx]);
      Value globalOffset =
          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
      globalOffsets[dimIdx] = globalOffset;
    }

    return globalOffsets;
  }

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout)
      return failure();
    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    // sgLayout must be present for workgroup-level distribution.
    SmallVector<int64_t> sgLayout;
    if (auto sgLayoutAttr = layout.getSgLayout())
      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
    else
      return rewriter.notifyMatchFailure(
          op, "sgLayout attribute is required in layout");

    SmallVector<int64_t> sgShape;
    if (auto sgDataAttr = layout.getSgData()) {
      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
    } else {
      assert(wgShape.size() == sgLayout.size() &&
             "sgLayout and wgShape must have the same rank");
      sgShape.reserve(wgShape.size());
      for (size_t i = 0; i < wgShape.size(); ++i) {
        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
        sgShape.push_back(wgShape[i] / sgLayout[i]);
      }
    }

    // TODO: Handle the order attribute.
    // Get the subgroup id.
    auto linearSgId =
        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);

    // Create index constants for the sg_layout and sg_data dimensions.
    SmallVector<Value> sgLayoutDim(sgLayout.size());
    SmallVector<Value> sgDataDim(sgShape.size());

    for (size_t i = 0; i < sgLayout.size(); i++) {
      sgLayoutDim[i] =
          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
    }

    auto deLinearizeSgId =
        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
    if (failed(deLinearizeSgId))
      return failure();
    SmallVector<Value> sgIds = *deLinearizeSgId;
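    // sgIds now holds the subgroup's coordinates in the sg_layout grid,
    // delinearized in row-major order (last dimension varies fastest). For
    // example, with sg_layout = [4, 4] a linear subgroup id of 6 yields
    // sgIds = [1, 2].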

    // Calculate distribution unit shape and local offsets for subgroup
    SmallVector<int64_t> distUnitShape(sgLayout.size());
    SmallVector<Value> localOffset(sgLayout.size());
    for (size_t i = 0; i < sgLayout.size(); i++) {
      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
      localOffset[i] =
          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
    }
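    // For the 24x24 example in the header comment this gives
    // distUnitShape = [min(4 * 2, 24), min(4 * 2, 24)] = [8, 8] and, for
    // sgIds = [1, 2], a local offset of [1 * 2, 2 * 2] = [2, 4].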

    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();

    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());
    SmallVector<Value> newCreateNdOps;
    for (SmallVector<int64_t> distUnitBaseAddr :
         StaticTileOffsetRange(wgShape, distUnitShape)) {
      SmallVector<OpFoldResult> globalOffsets =
          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
                                 distUnitBaseAddr, distUnitShape);

      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newCreateNdOps.push_back(newCreateNdOp);
    }

    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};

/// This pattern transforms the LoadNdOp to load subgroup data.
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
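    // adaptor.getTensorDesc() carries the N subgroup-level descriptors
    // produced by the 1:N conversion of the defining op (e.g. by
    // WgToSgCreateNdOp). Emit one LoadNdOp per descriptor, shaping each
    // result vector after that descriptor's subgroup tile.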
    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
                                                        src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};

/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
/// It creates one StoreNdOp per subgroup tensor descriptor, storing the
/// corresponding value fragment.
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
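    // Both the value and the tensor_desc operands were converted 1:N; zipping
    // them relies on the two ranges having matching fragments in matching
    // order, so each value fragment is stored through its own subgroup
    // descriptor.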
    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp for each of the new
/// subgroup tensor descriptors.
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
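    // The same dynamic and constant offsets are applied to every
    // subgroup-level descriptor fragment.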
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }

    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};

/// This pattern transforms the DpasOp to work at subgroup level.
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout =
        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    if (!originalLayout)
      return failure();

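    // Pair every LHS fragment with every RHS fragment: an MxK LHS fragment
    // combined with a KxN RHS fragment yields one MxN subgroup-level DpasOp.
    // When an accumulator is present, its fragments are consumed in the same
    // (LHS-major) iteration order.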
    SmallVector<Value> newDpasOps;
    size_t i = 0;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = rewriter.create<xegpu::DpasOp>(
            loc, resTy, operands,
            llvm::ArrayRef<NamedAttribute>(
                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};

/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto src : adaptor.getTensorDesc())
      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
                                           op->getAttrs());
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
      patterns.getContext());
}
} // namespace xegpu
} // namespace mlir

namespace {
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
    return !layout || layout.getSgLayout() == nullptr;
  };
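  // An op is legal once its layout no longer carries sg_layout. The patterns
  // above emit subgroup-level ops whose layouts come from
  // dropSgLayoutAndData(), so converted ops are not matched again.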

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    return isLegal(layout);
  });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}
