//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

namespace {

/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
/// It uses round-robin assignment to distribute the work to the subgroups.
/// For example, the following create_nd_tdesc operation:
///   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
///     -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
///          sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
/// is converted to 9 subgroup-level operations based on the sg_layout &
/// sg_data:
///   %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
///     !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
///       lane_data = [1, 1]>>
///
/// The sg_layout and sg_data attributes are dropped after the pass as they are
/// no longer needed.
///
/// 24x24 matrix distribution example:
///   sg_layout = [4, 4], sg_data = [2, 2]
///   Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
///   dist_unit_shape[i] = sg_layout[i] * sg_data[i] => [8, 8]
///
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |   <- 3 tiles across
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |   <- 3 tiles down
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |
///   +-----+-----+-----+
///
/// Each 8x8 tile is further subdivided among subgroups:
///   +-----+-----+-----+-----+
///   | 2x2 | 2x2 | 2x2 | 2x2 |   <- 4 subgroups across (each handles 2 columns)
///   | 2x2 | 2x2 | 2x2 | 2x2 |   <- 4 subgroups down (each handles 2 rows)
///   | 2x2 | 2x2 | 2x2 | 2x2 |
///   | 2x2 | 2x2 | 2x2 | 2x2 |
///   +-----+-----+-----+-----+
///
/// Since the 24x24 matrix is divided into 8x8 distribution units, there will
/// be 9 distribution units (3x3) in total. Hence the 9 subgroup-level
/// operations.
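///
/// For example (a worked case of the layout above), the subgroup with linear
/// id 5 delinearizes to [1, 1] in sg_layout, so its local offset within every
/// 8x8 distribution unit is [1, 1] * sg_data = [2, 2], and it creates one 2x2
/// tensor_desc per distribution unit.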

/// The pass currently keeps the entire distribution logic in the
/// WgToSgCreateNdOp pattern; the patterns for all other ops simply follow the
/// subgroup tensor descriptors it produces.
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
/// ops in the pass.
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  // Calculate offset for each subgroup.
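  // For each of the trailing dims, the code below computes
  //   globalOffset[d] = originalOffset[d] +
  //                     (localOffset[d] + distUnitBaseAddr[d]) % distUnitShape[d]
  // where the modulo wraps the subgroup's local offset back into the
  // distribution unit (the round-robin case where sg_layout * sg_data exceeds
  // the workgroup shape in a dimension).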
  SmallVector<OpFoldResult>
  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
                         const SmallVector<OpFoldResult> &originalOffsets,
                         const SmallVector<Value> &localOffset,
                         const SmallVector<int64_t> &distUnitBaseAddr,
                         const SmallVector<int64_t> &distUnitShape) const {
    assert(localOffset.size() == distUnitBaseAddr.size() &&
           "localOffset and distUnitBaseAddr must have the same rank");

    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                            originalOffsets.end());
    size_t rank = localOffset.size();
    for (size_t i = 0; i < rank; ++i) {
      size_t dimIdx = originalOffsets.size() - rank + i;
      Value constOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
      Value offset =
          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
      Value modValue =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
      Value offsetMod =
          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
      Value origOffset = getValueOrCreateConstantIndexOp(
          rewriter, loc, originalOffsets[dimIdx]);
      Value globalOffset =
          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
      globalOffsets[dimIdx] = globalOffset;
    }

    return globalOffsets;
  }

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout)
      return failure();
    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    // sgLayout must be present for workgroup-level distribution.
    SmallVector<int64_t> sgLayout;
    if (auto sgLayoutAttr = layout.getSgLayout())
      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
    else
      return rewriter.notifyMatchFailure(
          op, "sgLayout attribute is required in layout");

    SmallVector<int64_t> sgShape;
    if (auto sgDataAttr = layout.getSgData()) {
      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
    } else {
      assert(wgShape.size() == sgLayout.size() &&
             "sgLayout and wgShape must have the same rank");
      sgShape.reserve(wgShape.size());
      for (size_t i = 0; i < wgShape.size(); ++i) {
        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
        sgShape.push_back(wgShape[i] / sgLayout[i]);
      }
    }

    // TODO: Handle order attribute
    // Get the subgroup ID
    auto linearSgId =
        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);

    // Create constants for layout dimensions
    SmallVector<Value> sgLayoutDim(sgLayout.size());
    SmallVector<Value> sgDataDim(sgShape.size());

    for (size_t i = 0; i < sgLayout.size(); i++) {
      sgLayoutDim[i] =
          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
    }

    auto deLinearizeSgId =
        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
    if (failed(deLinearizeSgId))
      return failure();
    SmallVector<Value> sgIds = *deLinearizeSgId;
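    // e.g. with sg_layout = [4, 4], the linear subgroup id 6 delinearizes
    // (outermost dim first) to sgIds = [1, 2].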

    // Calculate distribution unit shape and local offsets for subgroup
    SmallVector<int64_t> distUnitShape(sgLayout.size());
    SmallVector<Value> localOffset(sgLayout.size());
    for (size_t i = 0; i < sgLayout.size(); i++) {
      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
      localOffset[i] =
          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
    }
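    // e.g. for the 24x24 example in the pattern comment: distUnitShape is
    // [8, 8] and, for sgIds = [1, 2], localOffset is [2, 4].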

    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();

    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());
    SmallVector<Value> newCreateNdOps;
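    // Each subgroup materializes one tensor_desc per distribution unit;
    // StaticTileOffsetRange enumerates the distribution-unit base offsets
    // within wgShape (9 units in the 24x24 example).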
    for (SmallVector<int64_t> distUnitBaseAddr :
         StaticTileOffsetRange(wgShape, distUnitShape)) {
      SmallVector<OpFoldResult> globalOffsets =
          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
                                 distUnitBaseAddr, distUnitShape);

      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newCreateNdOps.push_back(newCreateNdOp);
    }

    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};

/// This pattern transforms the LoadNdOp to load subgroup data.
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy =
          VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
                                                        src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};

/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
/// It creates a StoreNdOp to store the updated values to the new subgroup src
/// tensor descriptors.
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the offsets
/// of the new subgroup src tensor descriptors.
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }

    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};

/// This pattern transforms the DpasOp to work at subgroup level.
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout =
        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    if (!originalLayout)
      return failure();

    SmallVector<Value> newDpasOps;
    size_t i = 0;
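    // Pair every LHS fragment with every RHS fragment; accumulator fragments
    // (when present) are consumed in the same order in which the result
    // fragments are produced.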
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = rewriter.create<xegpu::DpasOp>(
            loc, resTy, operands,
            llvm::ArrayRef<NamedAttribute>(
                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};

/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto src : adaptor.getTensorDesc())
      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
                                           op->getAttrs());
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
      patterns.getContext());
}
} // namespace xegpu
} // namespace mlir

namespace {
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

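  // An op is legal (already at subgroup level) once its layout is either
  // absent or no longer carries sg_layout; the patterns above drop
  // sg_layout/sg_data when they rewrite to subgroup-level ops.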
  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
    return !layout || layout.getSgLayout() == nullptr;
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    return isLegal(layout);
  });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}