//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;

namespace {

/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
/// It uses round-robin assignment to distribute the work to the subgroups.
/// The following create_nd_tdesc operation:
///   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
///     -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
///        sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
/// is converted to 9 subgroup-level operations based on the sg_layout and
/// sg_data:
///   %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32>
///     -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
///        lane_data = [1, 1]>>
///
/// The sg_layout and sg_data attributes are dropped after the pass as they are
/// no longer needed.
///
/// 24x24 matrix distribution example:
/// sg_layout = [4, 4], sg_data = [2, 2]
/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
///
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |   <- 3 tiles across
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |   <- 3 tiles down
///   +-----+-----+-----+
///   | 8x8 | 8x8 | 8x8 |
///   +-----+-----+-----+
///
/// Each 8x8 tile is further subdivided among the subgroups:
///   +-----------------+
///   | 2x2 2x2 2x2 2x2 |   <- 4 subgroups across (each handles 2 columns)
///   | 2x2 2x2 2x2 2x2 |   <- 4 subgroups down (each handles 2 rows)
///   | 2x2 2x2 2x2 2x2 |
///   | 2x2 2x2 2x2 2x2 |
///   +-----------------+
///
/// Since the 24x24 matrix is divided into 8x8 distribution units, there are
/// 9 distribution units (3x3) in total. Hence the 9 subgroup-level operations.
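///
/// As a concrete illustration (following the example above, not a normative
/// part of the algorithm): the subgroup with linear id 6 in
/// sg_layout = [4, 4] delinearizes to sgIds = [1, 2], so its local 2x2 block
/// starts at offset [2, 4] inside every 8x8 distribution unit, and the
/// pattern emits one create_nd_tdesc for each of the 9 distribution units on
/// its behalf.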

/// The pass currently has the entire distribution logic in the
/// WgToSgCreateNdOp pattern, and all the other ops just follow.
/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
/// ops in the pass.
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  // Calculate offsets for each subgroup
  SmallVector<OpFoldResult>
  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
                         const SmallVector<OpFoldResult> &originalOffsets,
                         const SmallVector<Value> &localOffset,
                         const SmallVector<int64_t> &distUnitBaseAddr,
                         const SmallVector<int64_t> &distUnitShape) const {
    assert(localOffset.size() == distUnitBaseAddr.size() &&
           "localOffset and distUnitBaseAddr must have the same rank");

    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                            originalOffsets.end());
    size_t rank = localOffset.size();
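    // For each distributed dimension d the loop below materializes
    //   globalOffsets[d] = originalOffsets[d] +
    //       (localOffset[d] + distUnitBaseAddr[d]) % distUnitShape[d]
    // using index/arith ops so that constant operands can still fold.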
    for (size_t i = 0; i < rank; ++i) {
      size_t dimIdx = originalOffsets.size() - rank + i;
      Value constOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
      Value offset =
          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
      Value modValue =
          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
      Value offsetMod =
          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
      Value origOffset = getValueOrCreateConstantIndexOp(
          rewriter, loc, originalOffsets[dimIdx]);
      Value globalOffset =
          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
      globalOffsets[dimIdx] = globalOffset;
    }

    return globalOffsets;
  }

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout)
      return failure();
    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    // sgLayout must be present for workgroup-level distribution.
    SmallVector<int64_t> sgLayout;
    if (auto sgLayoutAttr = layout.getSgLayout())
      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
    else
      return rewriter.notifyMatchFailure(
          op, "sgLayout attribute is required in layout");

    SmallVector<int64_t> sgShape;
    if (auto sgDataAttr = layout.getSgData()) {
      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
    } else {
      assert(wgShape.size() == sgLayout.size() &&
             "sgLayout and wgShape must have the same rank");
      sgShape.reserve(wgShape.size());
      for (size_t i = 0; i < wgShape.size(); ++i) {
        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
        sgShape.push_back(wgShape[i] / sgLayout[i]);
      }
    }

    // TODO: Handle order attribute
    // Get the subgroup ID
    auto linearSgId =
        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);

    // Create constants for layout dimensions
    SmallVector<Value> sgLayoutDim(sgLayout.size());
    SmallVector<Value> sgDataDim(sgShape.size());

    for (size_t i = 0; i < sgLayout.size(); i++) {
      sgLayoutDim[i] =
          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
    }

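    // Delinearize the linear subgroup id into a multi-dimensional id over
    // sg_layout (the last dimension varies fastest). For example, with
    // sg_layout = [4, 4] a linear id of 6 yields sgIds = [1, 2].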
    auto deLinearizeSgId =
        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
    if (failed(deLinearizeSgId))
      return failure();
    SmallVector<Value> sgIds = *deLinearizeSgId;

    // Calculate distribution unit shape and local offsets for subgroup
    SmallVector<int64_t> distUnitShape(sgLayout.size());
    SmallVector<Value> localOffset(sgLayout.size());
    for (size_t i = 0; i < sgLayout.size(); i++) {
      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
      localOffset[i] =
          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
    }

    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();

    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());
    SmallVector<Value> newCreateNdOps;
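    // One subgroup-level descriptor is created per distribution unit. For the
    // 24x24 example above with dist_unit_shape = [8, 8], the range below
    // yields the nine base offsets {0, 8, 16} x {0, 8, 16}.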
    for (SmallVector<int64_t> distUnitBaseAddr :
         StaticTileOffsetRange(wgShape, distUnitShape)) {
      SmallVector<OpFoldResult> globalOffsets =
          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
                                 distUnitBaseAddr, distUnitShape);

      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newCreateNdOps.push_back(newCreateNdOp);
    }

    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};

/// This pattern transforms the LoadNdOp to load subgroup data.
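/// The result type of each new load is derived from the subgroup descriptor
/// shape, e.g. loading through a !xegpu.tensor_desc<2x2xf32, ...> source
/// produces a vector<2x2xf32> result.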
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy =
          VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
                                                        src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};

/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
/// It creates a StoreNdOp to store the updated values to the new subgroup
/// src tensor descriptors.
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
/// offsets of the new subgroup src tensor descriptors.
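/// The original (constant) offsets are reused unchanged for every subgroup
/// descriptor, since each subgroup advances its own tile by the same amount.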
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }

    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};

/// This pattern transforms the DpasOp to work at subgroup level.
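/// Each new dpas consumes one A fragment and one B fragment: with the lhs
/// distributed into M pieces and the rhs into N pieces, the pattern emits
/// M * N subgroup-level dpas ops whose result shape is
/// aVecShape[0] x bVecShape[1].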
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout =
        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    if (!originalLayout)
      return failure();

    SmallVector<Value> newDpasOps;
    size_t i = 0;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = rewriter.create<xegpu::DpasOp>(
            loc, resTy, operands,
            llvm::ArrayRef<NamedAttribute>(
                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};

/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
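/// Prefetch produces no results, so the op is simply replicated: one
/// PrefetchNdOp is emitted per subgroup descriptor and the original op is
/// erased.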
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    for (auto src : adaptor.getTensorDesc())
      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
                                           op->getAttrs());
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
      patterns.getContext());
}
} // namespace xegpu
} // namespace mlir

namespace {
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };

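  // An op is already subgroup-level, and therefore legal, if it either has no
  // layout at all or its layout no longer carries sg_layout.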
  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
    return !layout || layout.getSgLayout() == nullptr;
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    return isLegal(layout);
  });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}