| 1 | //===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 9 | #include "mlir/Dialect/GPU/Utils/DistributionUtils.h" |
| 10 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 11 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 12 | #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" |
| 13 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
| 14 | #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" |
| 15 | #include "mlir/Dialect/XeGPU/Transforms/Passes.h" |
| 16 | #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" |
| 17 | #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
| 18 | #include "mlir/IR/AffineMap.h" |
| 19 | #include "mlir/IR/Attributes.h" |
| 20 | #include "mlir/IR/Builders.h" |
| 21 | #include "mlir/IR/BuiltinAttributes.h" |
| 22 | #include "mlir/IR/BuiltinOps.h" |
| 23 | #include "mlir/IR/BuiltinTypes.h" |
| 24 | #include "mlir/IR/Operation.h" |
| 25 | #include "mlir/IR/PatternMatch.h" |
| 26 | #include "mlir/IR/TypeRange.h" |
| 27 | #include "mlir/IR/Value.h" |
| 28 | #include "mlir/IR/Visitors.h" |
| 29 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 30 | #include "mlir/Support/LLVM.h" |
| 31 | #include "mlir/Transforms/DialectConversion.h" |
| 32 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 33 | #include "mlir/Transforms/InliningUtils.h" |
| 34 | #include "llvm/ADT/ArrayRef.h" |
| 35 | #include "llvm/ADT/STLExtras.h" |
| 36 | #include "llvm/ADT/SmallVector.h" |
| 37 | |
| 38 | namespace mlir { |
| 39 | namespace xegpu { |
| 40 | #define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE |
| 41 | #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" |
| 42 | } // namespace xegpu |
| 43 | } // namespace mlir |
| 44 | |
| 45 | #define DEBUG_TYPE "xegpu-subgroup-distribute" |
| 46 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 47 | |
| 48 | using namespace mlir; |
| 49 | |
// Attribute name used to tag UnrealizedConversionCastOps that are inserted to
// resolve SIMT type mismatches, so they can be recognized (and cleaned up)
// later in the pass.
static const char *const resolveSIMTTypeMismatch =
    "resolve_simt_type_mismatch";
| 54 | |
| 55 | namespace { |
| 56 | |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | // SIMT Distribution Patterns |
| 59 | //===----------------------------------------------------------------------===// |
| 60 | |
| 61 | /// Helper function to get distributed vector type for a source vector type |
| 62 | /// according to the lane_layout. We simply divide each dimension of tensor |
| 63 | /// descriptor shape by corresponding lane_layout dimension. If |
| 64 | /// array_length > 1, that is appended to the front of the ditributed shape. |
| 65 | /// NOTE: This is the vector type that will be returned by the |
| 66 | /// gpu.warp_execute_on_lane0 op. |
| 67 | /// |
| 68 | /// Examples: |
| 69 | /// | original vector shape | lane_layout | distributed vector shape | |
| 70 | /// |-----------------------|-------------|--------------------------| |
| 71 | /// | 32x16 | [1, 16] | 32x1 | |
| 72 | /// | 32x16 | [2, 8] | 16x2 | |
| 73 | /// | 2x32x16 | [1, 16] | 2x32x1 | |
| 74 | static FailureOr<VectorType> |
| 75 | getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout, |
| 76 | VectorType originalType) { |
| 77 | if (!layout) |
| 78 | return failure(); |
| 79 | |
| 80 | auto laneLayout = layout.getLaneLayout().asArrayRef(); |
| 81 | assert(originalType.getShape().size() >= laneLayout.size() && |
| 82 | "Rank of the original vector type should be greater or equal to the " |
| 83 | "size of the lane layout to distribute the vector type." ); |
| 84 | SmallVector<int64_t> distributedShape(originalType.getShape()); |
| 85 | // Only distribute the last `laneLayout.size()` dimensions. The remaining |
| 86 | // dimensions are not distributed. |
| 87 | unsigned distributionStart = originalType.getRank() - laneLayout.size(); |
| 88 | for (auto [i, dim] : llvm::enumerate(First: originalType.getShape())) { |
| 89 | if (i < distributionStart) |
| 90 | continue; |
| 91 | |
| 92 | // Check if the dimension can be distributed evenly. |
| 93 | if (dim % laneLayout[i - distributionStart] != 0) |
| 94 | return failure(); |
| 95 | distributedShape[i] = dim / laneLayout[i - distributionStart]; |
| 96 | } |
| 97 | return VectorType::get(shape: distributedShape, elementType: originalType.getElementType()); |
| 98 | } |
| 99 | |
| 100 | /// Helper function to resolve types if the distributed type out of |
| 101 | /// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type. |
| 102 | /// Example 1: |
| 103 | /// distributed type: vector<8x1xf32> |
| 104 | /// expected type: vector<8xf32> |
| 105 | /// resolved using, |
| 106 | /// %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32> |
| 107 | /// Example 2: |
| 108 | /// distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>> |
| 109 | /// expected type: xegpu.tensor_desc<8x16xf32> |
| 110 | /// resolved using, |
| 111 | /// %0 = unrealized_conversion_cast %1 : |
| 112 | /// xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> -> |
| 113 | /// xegpu.tensor_desc<8x16xf32> |
| 114 | template <typename T> |
| 115 | static Value resolveDistributedTy(Value orig, T expected, |
| 116 | PatternRewriter &rewriter) { |
| 117 | // If orig and expected types are the same, return orig. |
| 118 | if (orig.getType() == expected) |
| 119 | return orig; |
| 120 | // If orig is a vector type, create a shape cast op to reconcile the types. |
| 121 | if (isa<VectorType>(Val: orig.getType())) { |
| 122 | auto castOp = |
| 123 | rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig); |
| 124 | return castOp.getResult(); |
| 125 | } |
| 126 | // If orig is a tensor descriptor type, create an unrealized conversion cast |
| 127 | // op to reconcile the types. |
| 128 | if (isa<xegpu::TensorDescType>(Val: orig.getType())) { |
| 129 | auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(), |
| 130 | expected, orig); |
| 131 | castOp->setAttr(resolveSIMTTypeMismatch, rewriter.getUnitAttr()); |
| 132 | return castOp.getResult(0); |
| 133 | } |
| 134 | llvm_unreachable("Unsupported type for reconciliation" ); |
| 135 | return orig; |
| 136 | } |
| 137 | |
| 138 | /// Helper function to filter out the temporary layout attributes attached |
| 139 | /// during the layout assignment process. These are not needed after going to |
| 140 | /// SIMT. |
| 141 | static SmallVector<NamedAttribute> |
| 142 | removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) { |
| 143 | SmallVector<NamedAttribute> newAttrs; |
| 144 | for (NamedAttribute attr : attrs) { |
| 145 | if (!isa<xegpu::LayoutAttr>(Val: attr.getValue())) |
| 146 | newAttrs.push_back(Elt: attr); |
| 147 | } |
| 148 | return newAttrs; |
| 149 | } |
| 150 | |
| 151 | /// Helper function to check if the layout is packed. Layout is packed if it is |
| 152 | /// 2D and lane_data[0] != 1 (data packed from col dimension). |
| 153 | static bool hasPackedLayout(xegpu::LayoutAttr layout) { |
| 154 | if (layout == xegpu::LayoutAttr()) |
| 155 | return false; |
| 156 | DenseI32ArrayAttr laneData = layout.getLaneData(); |
| 157 | if (!laneData || laneData.size() != 2) |
| 158 | return false; |
| 159 | return laneData.asArrayRef()[0] != 1; |
| 160 | } |
| 161 | |
| 162 | /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body |
| 163 | /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is |
| 164 | /// contained within a WarpExecuteOnLane0Op. |
| 165 | /// Example: |
| 166 | /// |
| 167 | /// ``` |
| 168 | /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> { |
| 169 | /// ... |
| 170 | /// ... |
| 171 | /// gpu.return %result: vector<8x16xf32> |
| 172 | /// } |
| 173 | /// ``` |
| 174 | /// To |
| 175 | /// ``` |
| 176 | /// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> { |
| 177 | /// %laneid = gpu.lane_id : index |
| 178 | /// %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> { |
| 179 | /// ... |
| 180 | /// ... |
| 181 | /// gpu.yield %result: vector<8x16xf32> |
| 182 | /// } |
| 183 | /// return %0 |
| 184 | /// } |
| 185 | struct MoveFuncBodyToWarpExecuteOnLane0 |
| 186 | : public OpRewritePattern<gpu::GPUFuncOp> { |
| 187 | using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern; |
| 188 | LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, |
| 189 | PatternRewriter &rewriter) const override { |
| 190 | // If the function only contains a single void return, skip. |
| 191 | if (llvm::all_of(Range: gpuFuncOp.getBody().getOps(), P: [](Operation &op) { |
| 192 | return isa<gpu::ReturnOp>(Val: op) && !op.getNumOperands(); |
| 193 | })) |
| 194 | return failure(); |
| 195 | // If the function already moved inside a warp_execute_on_lane0, skip. |
| 196 | if (llvm::any_of(Range: gpuFuncOp.getBody().getOps(), P: [](Operation &op) { |
| 197 | return isa<gpu::WarpExecuteOnLane0Op>(Val: op); |
| 198 | })) |
| 199 | return failure(); |
| 200 | // Create a new function with the same signature. |
| 201 | auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>( |
| 202 | location: gpuFuncOp.getLoc(), args: gpuFuncOp.getName(), args: gpuFuncOp.getFunctionType()); |
| 203 | // Create a WarpExecuteOnLane0Op with same arguments and results as the |
| 204 | // original gpuFuncOp. |
| 205 | rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front()); |
| 206 | auto laneId = rewriter.create<gpu::LaneIdOp>( |
| 207 | location: newGpuFunc.getLoc(), args: rewriter.getIndexType(), |
| 208 | /** upperBound = **/ args: mlir::IntegerAttr()); |
| 209 | ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults(); |
| 210 | auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>( |
| 211 | location: laneId.getLoc(), args&: gpuFuncResultType, args&: laneId, |
| 212 | args: xegpu::targetinfo::subgroupSize, args: newGpuFunc.getArguments(), |
| 213 | args: newGpuFunc.getArgumentTypes()); |
| 214 | Block &warpBodyBlock = warpOp.getBodyRegion().front(); |
| 215 | // Replace the ReturnOp of the original gpu function with a YieldOp. |
| 216 | auto origRetunOp = |
| 217 | cast<gpu::ReturnOp>(Val: gpuFuncOp.getBlocks().back().getTerminator()); |
| 218 | rewriter.setInsertionPointAfter(origRetunOp); |
| 219 | rewriter.create<gpu::YieldOp>(location: origRetunOp.getLoc(), |
| 220 | args: origRetunOp.getOperands()); |
| 221 | rewriter.eraseOp(op: origRetunOp); |
| 222 | // Move the original function body to the WarpExecuteOnLane0Op body. |
| 223 | rewriter.inlineRegionBefore(region&: gpuFuncOp.getBody(), parent&: warpOp.getBodyRegion(), |
| 224 | before: warpOp.getBodyRegion().begin()); |
| 225 | rewriter.eraseBlock(block: &warpBodyBlock); |
| 226 | // Insert a new ReturnOp after the WarpExecuteOnLane0Op. |
| 227 | rewriter.setInsertionPointAfter(warpOp); |
| 228 | rewriter.create<gpu::ReturnOp>(location: newGpuFunc.getLoc(), args: warpOp.getResults()); |
| 229 | rewriter.replaceOp(op: gpuFuncOp, newOp: newGpuFunc); |
| 230 | return success(); |
| 231 | } |
| 232 | }; |
| 233 | |
| 234 | /// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing |
| 235 | /// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will |
| 236 | /// still contain the original op that will not be used by the yield op (and |
| 237 | /// should be cleaned up later). The yield op will bypass the create_nd_tdesc's |
| 238 | /// arguments. Tensor descriptor shape is not distributed because it is a |
| 239 | /// uniform value across all work items within the subgroup. However, the |
| 240 | /// layout information is dropped in the new tensor descriptor type. |
| 241 | /// |
| 242 | /// Example: |
| 243 | /// |
| 244 | /// ``` |
| 245 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
| 246 | /// %r = gpu.warp_execute_on_lane_0(%laneid) -> |
| 247 | /// (!xegpu.tensor_desc<4x8xf32, #layout0>) { |
| 248 | /// ... |
| 249 | /// %td = xegpu.create_nd_tdesc %arg0[0, 0] |
| 250 | /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0> |
| 251 | /// vector.yield %td |
| 252 | /// } |
| 253 | /// ``` |
| 254 | /// To |
| 255 | /// ``` |
| 256 | /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) { |
| 257 | /// ... |
| 258 | /// %dead = xegpu.create_nd_tdesc %arg0[0, 0] |
| 259 | /// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0> |
| 260 | /// vector.yield %arg0, %dead |
| 261 | /// } |
| 262 | /// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32> |
| 263 | /// -> !xegpu.tensor_desc<4x8xf32> |
| 264 | /// |
| 265 | /// ``` |
| 266 | struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern { |
| 267 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
| 268 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
| 269 | PatternRewriter &rewriter) const override { |
| 270 | OpOperand *operand = |
| 271 | getWarpResult(warpOp: subgroupOp, fn: llvm::IsaPred<xegpu::CreateNdDescOp>); |
| 272 | if (!operand) |
| 273 | return rewriter.notifyMatchFailure( |
| 274 | arg&: subgroupOp, msg: "warp result is not a xegpu::CreateNdDesc op" ); |
| 275 | auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>(); |
| 276 | unsigned operandIdx = operand->getOperandNumber(); |
| 277 | |
| 278 | xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr(); |
| 279 | if (!layout) |
| 280 | return rewriter.notifyMatchFailure( |
| 281 | arg&: descOp, msg: "the tensor descriptor lacks layout attribute" ); |
| 282 | |
| 283 | SmallVector<size_t> newRetIndices; |
| 284 | SmallVector<Value> newYieldValues; |
| 285 | SmallVector<Type> newYieldTypes; |
| 286 | |
| 287 | for (Value operand : descOp->getOperands()) { |
| 288 | newYieldValues.push_back(Elt: operand); |
| 289 | newYieldTypes.push_back(Elt: operand.getType()); |
| 290 | } |
| 291 | rewriter.setInsertionPoint(subgroupOp); |
| 292 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| 293 | rewriter, warpOp: subgroupOp, /* new yieled values = */ newYieldedValues: newYieldValues, |
| 294 | /* new yielded types = */ newReturnTypes: newYieldTypes, indices&: newRetIndices); |
| 295 | |
| 296 | SmallVector<Value> newDescOperands; |
| 297 | for (size_t i : newRetIndices) { |
| 298 | newDescOperands.push_back(Elt: newWarpOp.getResult(i)); |
| 299 | } |
| 300 | rewriter.setInsertionPointAfter(newWarpOp); |
| 301 | xegpu::TensorDescType distributedTensorDescTy = |
| 302 | descOp.getType().dropLayouts(); // Distributed tensor descriptor type |
| 303 | // does not contain layout info. |
| 304 | Value newDescOp = rewriter.create<xegpu::CreateNdDescOp>( |
| 305 | location: newWarpOp.getLoc(), args&: distributedTensorDescTy, args&: newDescOperands, |
| 306 | args: descOp->getAttrs()); |
| 307 | |
| 308 | Value distributedVal = newWarpOp.getResult(i: operandIdx); |
| 309 | // Resolve the distributed type to the expected type. |
| 310 | newDescOp = |
| 311 | resolveDistributedTy(orig: newDescOp, expected: distributedVal.getType(), rewriter); |
| 312 | rewriter.replaceAllUsesWith(from: distributedVal, to: newDescOp); |
| 313 | return success(); |
| 314 | } |
| 315 | }; |
| 316 | |
| 317 | /// Distribute a store_nd op at the end of enclosing |
| 318 | /// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed |
| 319 | /// through the warp op interface they would be propagated as returned values. |
| 320 | /// Source vector is distributed based on lane layout. Appropriate cast ops are |
| 321 | /// inserted if the distributed types does not match expected xegpu SIMT types. |
| 322 | /// |
| 323 | /// Example: |
| 324 | /// |
| 325 | /// ``` |
| 326 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
| 327 | /// gpu.warp_execute_on_lane_0(%laneid) -> () { |
| 328 | /// ... |
| 329 | /// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>, |
| 330 | /// !xegpu.tensor_desc<4x8xf32, #layout0> |
| 331 | /// } |
| 332 | /// ``` |
| 333 | /// To |
| 334 | /// ``` |
| 335 | /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>, |
| 336 | /// !xegpu.tensor_desc<4x8xf32, #layout0>) { |
| 337 | /// gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32, |
| 338 | /// #layout0> |
| 339 | /// } |
| 340 | /// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32> |
| 341 | /// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32, |
| 342 | /// #layout0> |
| 343 | /// -> !xegpu.tensor_desc<4x8xf32> |
| 344 | /// xegpu.store_nd %0, %1: vector<4xf32>, |
| 345 | /// !xegpu.tensor_desc<4x8xf32> |
| 346 | /// |
| 347 | /// ``` |
| 348 | struct StoreNdDistribution final : public gpu::WarpDistributionPattern { |
| 349 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
| 350 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
| 351 | PatternRewriter &rewriter) const override { |
| 352 | auto yield = cast<gpu::YieldOp>( |
| 353 | Val: subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| 354 | Operation *lastNode = yield->getPrevNode(); |
| 355 | auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(Val: lastNode); |
| 356 | if (!storeOp) |
| 357 | return failure(); |
| 358 | |
| 359 | xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType(); |
| 360 | xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr(); |
| 361 | if (!layout) |
| 362 | return rewriter.notifyMatchFailure( |
| 363 | arg&: storeOp, msg: "the source tensor descriptor lacks layout attribute" ); |
| 364 | |
| 365 | FailureOr<VectorType> distributedTypeByWarpOpOrFailure = |
| 366 | getDistVecTypeBasedOnLaneLayout(layout, originalType: storeOp.getValueType()); |
| 367 | if (failed(Result: distributedTypeByWarpOpOrFailure)) |
| 368 | return rewriter.notifyMatchFailure(arg&: storeOp, |
| 369 | msg: "Failed to distribute the type" ); |
| 370 | VectorType distributedTypeByWarpOp = |
| 371 | distributedTypeByWarpOpOrFailure.value(); |
| 372 | |
| 373 | SmallVector<size_t> newRetIndices; |
| 374 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| 375 | rewriter, warpOp: subgroupOp, |
| 376 | /* new yielded values = */ |
| 377 | newYieldedValues: ValueRange{storeOp.getValue(), storeOp.getTensorDesc()}, |
| 378 | /* new yielded types = */ |
| 379 | newReturnTypes: TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()}, |
| 380 | indices&: newRetIndices); |
| 381 | // Create a new store op outside the warp op with the distributed vector |
| 382 | // type. Tensor descriptor is not distributed. |
| 383 | rewriter.setInsertionPointAfter(newWarpOp); |
| 384 | SmallVector<Value> newStoreOperands; |
| 385 | |
| 386 | // For the value operand, there can be a mismatch between the vector type |
| 387 | // distributed by the warp op and (xegpu-specific) distributed type |
| 388 | // supported by the store op. Type mismatch must be resolved using |
| 389 | // appropriate cast op. |
| 390 | FailureOr<VectorType> storeNdDistributedValueTyOrFailure = |
| 391 | xegpu::getDistributedVectorType(tdescTy: storeOp.getTensorDescType()); |
| 392 | if (failed(Result: storeNdDistributedValueTyOrFailure)) |
| 393 | return rewriter.notifyMatchFailure( |
| 394 | arg&: storeOp, msg: "Failed to get distributed vector type for the store op" ); |
| 395 | newStoreOperands.push_back(Elt: resolveDistributedTy( |
| 396 | orig: newWarpOp.getResult(i: newRetIndices[0]), |
| 397 | expected: storeNdDistributedValueTyOrFailure.value(), rewriter)); |
| 398 | // For the tensor descriptor operand, the layout attribute is dropped after |
| 399 | // distribution. Types needs to be resolved in this case also. |
| 400 | xegpu::TensorDescType distributedTensorDescTy = |
| 401 | storeOp.getTensorDescType().dropLayouts(); |
| 402 | newStoreOperands.push_back( |
| 403 | Elt: resolveDistributedTy(orig: newWarpOp.getResult(i: newRetIndices[1]), |
| 404 | expected: distributedTensorDescTy, rewriter)); |
| 405 | |
| 406 | rewriter.create<xegpu::StoreNdOp>( |
| 407 | location: newWarpOp.getLoc(), args: TypeRange{}, args&: newStoreOperands, |
| 408 | args: removeTemporaryLayoutAttributes(attrs: storeOp->getAttrs())); |
| 409 | rewriter.eraseOp(op: storeOp); |
| 410 | return success(); |
| 411 | } |
| 412 | }; |
| 413 | |
| 414 | /// Distribute a load_nd op feeding into vector.yield op for the enclosing |
| 415 | /// `gpu.warp_execute_on_lane_0` and put it after the warp op. |
| 416 | /// The warp op will still contain the original op that will not be used by |
| 417 | /// the yield op (and should be cleaned up later). The yield op will |
| 418 | /// bypass the load's arguments. Only the loaded vector is distributed |
| 419 | /// according to lane layout and, tensor descriptor types is not |
| 420 | /// distributed. Appropriate cast ops are inserted if the distributed types does |
| 421 | /// not match expected xegpu SIMT types. |
| 422 | /// |
| 423 | /// Example: |
| 424 | /// |
| 425 | /// ``` |
| 426 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
| 427 | /// %r = gpu.warp_execute_on_lane_0(%laneid) -> |
| 428 | /// (vector<4x1xf32>) { |
| 429 | /// ... |
| 430 | /// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0> |
| 431 | /// -> |
| 432 | /// vector<4x8xf32> |
| 433 | /// gpu.yield %ld |
| 434 | /// } |
| 435 | /// ``` |
| 436 | /// To |
| 437 | /// ``` |
| 438 | /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>, |
| 439 | /// !xegpu.tensor_desc<4x8xf32, #layout0>) { |
| 440 | /// ... |
| 441 | /// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> -> |
| 442 | /// vector<4x8xf32> gpu.yield %dead, %arg0 |
| 443 | /// } |
| 444 | /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32, |
| 445 | /// #layout0> -> !xegpu.tensor_desc<4x8xf32> |
| 446 | /// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32> |
| 447 | /// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32> |
| 448 | /// |
| 449 | /// ``` |
| 450 | struct LoadNdDistribution final : public gpu::WarpDistributionPattern { |
| 451 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
| 452 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
| 453 | PatternRewriter &rewriter) const override { |
| 454 | OpOperand *operand = |
| 455 | getWarpResult(warpOp: subgroupOp, fn: llvm::IsaPred<xegpu::LoadNdOp>); |
| 456 | if (!operand) |
| 457 | return rewriter.notifyMatchFailure( |
| 458 | arg&: subgroupOp, msg: "warp result is not a xegpu::LoadNd op" ); |
| 459 | // Make sure the load op is the last operation in the warp op body. This |
| 460 | // ensure that load op is not sinked earlier violating any barrier |
| 461 | // synchronizations. |
| 462 | auto yield = cast<gpu::YieldOp>( |
| 463 | Val: subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| 464 | Operation *lastNode = yield->getPrevNode(); |
| 465 | if (!dyn_cast_or_null<xegpu::LoadNdOp>(Val: lastNode)) |
| 466 | return failure(); |
| 467 | |
| 468 | auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>(); |
| 469 | xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType(); |
| 470 | xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr(); |
| 471 | if (!layout) |
| 472 | return rewriter.notifyMatchFailure( |
| 473 | arg&: loadOp, msg: "the source tensor descriptor lacks layout attribute" ); |
| 474 | |
| 475 | unsigned operandIdx = operand->getOperandNumber(); |
| 476 | VectorType distributedTypeByWarpOp = |
| 477 | cast<VectorType>(Val: subgroupOp.getResult(i: operandIdx).getType()); |
| 478 | |
| 479 | SmallVector<size_t> newRetIndices; |
| 480 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| 481 | rewriter, warpOp: subgroupOp, |
| 482 | /* new yielded values = */ newYieldedValues: loadOp.getTensorDesc(), |
| 483 | /* new yielded types = */ newReturnTypes: tensorDescTy, indices&: newRetIndices); |
| 484 | |
| 485 | // Create a new load op outside the warp op with the distributed vector |
| 486 | // type. |
| 487 | rewriter.setInsertionPointAfter(newWarpOp); |
| 488 | FailureOr<VectorType> loadNdDistValueTyOrFailure = |
| 489 | xegpu::getDistributedVectorType(tdescTy: loadOp.getTensorDescType()); |
| 490 | if (failed(Result: loadNdDistValueTyOrFailure)) |
| 491 | return rewriter.notifyMatchFailure( |
| 492 | arg&: loadOp, msg: "Failed to get distributed vector type for the load op" ); |
| 493 | xegpu::TensorDescType distributedTensorDescTy = |
| 494 | loadOp.getTensorDescType().dropLayouts(); // Distributed tensor |
| 495 | // descriptor type does not |
| 496 | // contain layout info. |
| 497 | auto newLoadOp = rewriter.create<xegpu::LoadNdOp>( |
| 498 | location: newWarpOp.getLoc(), args&: loadNdDistValueTyOrFailure.value(), |
| 499 | args: resolveDistributedTy(orig: newWarpOp->getResult(idx: newRetIndices[0]), |
| 500 | expected: distributedTensorDescTy, rewriter), |
| 501 | args: removeTemporaryLayoutAttributes(attrs: loadOp->getAttrs())); |
| 502 | // Set the packed attribute if the layout requires it. |
| 503 | newLoadOp.setPacked(hasPackedLayout(layout)); |
| 504 | Value distributedVal = newWarpOp.getResult(i: operandIdx); |
| 505 | // There can be a conflict between the vector type distributed by the |
| 506 | // warp op and (xegpu-specific) distributed type supported by the load |
| 507 | // op. Resolve these mismatches by inserting a cast. |
| 508 | Value tyResolvedVal = resolveDistributedTy( |
| 509 | orig: newLoadOp.getResult(), expected: distributedTypeByWarpOp, rewriter); |
| 510 | rewriter.replaceAllUsesWith(from: distributedVal, to: tyResolvedVal); |
| 511 | return success(); |
| 512 | } |
| 513 | }; |
| 514 | |
| 515 | /// Distribute a dpas op feeding into vector.yield op for the enclosing |
| 516 | /// `gpu.warp_execute_on_lane_0` and put it after the warp op. |
| 517 | /// The warp op will still contain the original op that will not be used by |
| 518 | /// the yield op (and should be cleaned up later). The yield op will |
| 519 | /// bypass the dpas's arguments. Appropriate cast ops are inserted if the |
| 520 | /// distributed types does not match expected xegpu SIMT types. |
| 521 | /// Example: |
| 522 | /// ``` |
| 523 | /// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]> |
| 524 | /// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]> |
| 525 | /// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]> |
| 526 | /// %r = gpu.warp_execute_on_lane_0(%laneid) -> |
| 527 | /// (vector<8x1xf32>) { |
| 528 | /// ... |
| 529 | /// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> -> |
| 530 | /// vector<8x16xf32> |
| 531 | /// gpu.yield %dpas |
| 532 | /// } |
| 533 | /// ``` |
| 534 | /// To |
| 535 | /// ``` |
| 536 | /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>, |
| 537 | /// vector<8x1xf16>, vector<16x1xf16>) { |
| 538 | /// ... |
| 539 | /// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> |
| 540 | /// -> vector<8x16xf32> |
| 541 | /// gpu.yield %dead, %arg0, %arg1 |
| 542 | /// } |
| 543 | /// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16> |
| 544 | /// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16> |
| 545 | /// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> -> |
| 546 | /// vector<8xf32> |
| 547 | /// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32> |
| 548 | /// ``` |
| 549 | struct DpasDistribution final : public gpu::WarpDistributionPattern { |
| 550 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
| 551 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
| 552 | PatternRewriter &rewriter) const override { |
| 553 | OpOperand *operand = |
| 554 | getWarpResult(warpOp: subgroupOp, fn: llvm::IsaPred<xegpu::DpasOp>); |
| 555 | if (!operand) |
| 556 | return rewriter.notifyMatchFailure(arg&: subgroupOp, |
| 557 | msg: "warp result is not a xegpu::Dpas op" ); |
| 558 | |
| 559 | auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>(); |
| 560 | unsigned operandIdx = operand->getOperandNumber(); |
| 561 | std::string layoutAName = xegpu::getLayoutName(operand: dpasOp->getOpOperand(idx: 0)); |
| 562 | std::string layoutBName = xegpu::getLayoutName(operand: dpasOp->getOpOperand(idx: 1)); |
| 563 | std::string layoutCName = xegpu::getLayoutName(result: dpasOp->getOpResult(idx: 0)); |
| 564 | |
| 565 | xegpu::LayoutAttr layoutA = |
| 566 | dpasOp->getAttrOfType<xegpu::LayoutAttr>(name: layoutAName); |
| 567 | xegpu::LayoutAttr layoutB = |
| 568 | dpasOp->getAttrOfType<xegpu::LayoutAttr>(name: layoutBName); |
| 569 | xegpu::LayoutAttr layoutOut = |
| 570 | dpasOp->getAttrOfType<xegpu::LayoutAttr>(name: layoutCName); |
| 571 | if (!layoutA || !layoutB || !layoutOut) |
| 572 | return rewriter.notifyMatchFailure( |
| 573 | arg&: dpasOp, |
| 574 | msg: "the xegpu::Dpas op lacks layout attribute for A, B or output" ); |
| 575 | |
| 576 | FailureOr<VectorType> distLhsTypeByWarpOpOrFailure = |
| 577 | getDistVecTypeBasedOnLaneLayout(layout: layoutA, originalType: dpasOp.getLhsType()); |
| 578 | FailureOr<VectorType> distRhsTypeByWarpOpOrFailure = |
| 579 | getDistVecTypeBasedOnLaneLayout(layout: layoutB, originalType: dpasOp.getRhsType()); |
| 580 | FailureOr<VectorType> distResultTypeByWarpOpOrFailure = |
| 581 | getDistVecTypeBasedOnLaneLayout(layout: layoutOut, originalType: dpasOp.getResultType()); |
| 582 | if (failed(Result: distLhsTypeByWarpOpOrFailure) || |
| 583 | failed(Result: distRhsTypeByWarpOpOrFailure) || |
| 584 | failed(Result: distResultTypeByWarpOpOrFailure)) |
| 585 | return rewriter.notifyMatchFailure( |
| 586 | arg&: dpasOp, |
| 587 | msg: "Failed to distribute the A, B or output types in xegpu::Dpas op" ); |
| 588 | |
| 589 | llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(), |
| 590 | dpasOp.getRhs()}; |
| 591 | llvm::SmallVector<Type, 3> newYieldTypes{ |
| 592 | distLhsTypeByWarpOpOrFailure.value(), |
| 593 | distRhsTypeByWarpOpOrFailure.value()}; |
| 594 | // Dpas acc operand is optional. |
| 595 | if (dpasOp.getAcc()) { |
| 596 | newYieldValues.push_back(Elt: dpasOp.getAcc()); |
| 597 | newYieldTypes.push_back(Elt: distResultTypeByWarpOpOrFailure.value()); |
| 598 | } |
| 599 | // Create a new warp op without the dpas. |
| 600 | SmallVector<size_t> newRetIndices; |
| 601 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| 602 | rewriter, warpOp: subgroupOp, newYieldedValues: newYieldValues, newReturnTypes: newYieldTypes, indices&: newRetIndices); |
| 603 | |
| 604 | FailureOr<VectorType> expectedDistLhsTyOrFailure = |
| 605 | xegpu::getDistributedVectorType(originalType: dpasOp.getLhsType(), layout: layoutA); |
| 606 | FailureOr<VectorType> expectedDistRhsTyOrFailure = |
| 607 | xegpu::getDistributedVectorType(originalType: dpasOp.getRhsType(), layout: layoutB); |
| 608 | FailureOr<VectorType> expectedDistResultTyOrFailure = |
| 609 | xegpu::getDistributedVectorType(originalType: dpasOp.getResultType(), layout: layoutOut); |
| 610 | if (failed(Result: expectedDistLhsTyOrFailure) || |
| 611 | failed(Result: expectedDistRhsTyOrFailure) || |
| 612 | failed(Result: expectedDistResultTyOrFailure)) |
| 613 | return rewriter.notifyMatchFailure( |
| 614 | arg&: dpasOp, |
| 615 | msg: "Failed to get distributed vector type for the dpas operands." ); |
| 616 | // Create a new dpas op outside the warp op. |
| 617 | rewriter.setInsertionPointAfter(newWarpOp); |
| 618 | SmallVector<Value> newDpasOperands; |
| 619 | SmallVector<VectorType> newDpasOperandExpectedTypes; |
| 620 | |
| 621 | // Resolve the distributed types with the original types. |
| 622 | newDpasOperandExpectedTypes.push_back(Elt: expectedDistLhsTyOrFailure.value()); |
| 623 | newDpasOperandExpectedTypes.push_back(Elt: expectedDistRhsTyOrFailure.value()); |
| 624 | VectorType distributedResultTy = expectedDistResultTyOrFailure.value(); |
| 625 | if (dpasOp.getAcc()) |
| 626 | newDpasOperandExpectedTypes.push_back(Elt: distributedResultTy); |
| 627 | |
| 628 | for (unsigned i = 0; i < newRetIndices.size(); i++) { |
| 629 | newDpasOperands.push_back( |
| 630 | Elt: resolveDistributedTy(orig: newWarpOp.getResult(i: newRetIndices[i]), |
| 631 | expected: newDpasOperandExpectedTypes[i], rewriter)); |
| 632 | } |
| 633 | Value newDpasOp = rewriter.create<xegpu::DpasOp>( |
| 634 | location: newWarpOp->getLoc(), args&: distributedResultTy, args&: newDpasOperands, |
| 635 | args: removeTemporaryLayoutAttributes(attrs: dpasOp->getAttrs())); |
| 636 | Value distributedVal = newWarpOp.getResult(i: operandIdx); |
| 637 | // Resolve the output type. |
| 638 | newDpasOp = resolveDistributedTy( |
| 639 | orig: newDpasOp, expected: distResultTypeByWarpOpOrFailure.value(), rewriter); |
| 640 | rewriter.replaceAllUsesWith(from: distributedVal, to: newDpasOp); |
| 641 | return success(); |
| 642 | } |
| 643 | }; |
| 644 | |
| 645 | /// Sink an update_nd_offset op feeding into yield op of an enclosing |
| 646 | /// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the |
| 647 | /// original op that will not be used by the yield op (and should be cleaned |
| 648 | /// up later). The yield op will bypass the updateOp's arguments. The tensor |
/// descriptor type is not distributed. Appropriate cast ops are inserted if
/// the distributed types do not match expected xegpu SIMT types.
| 651 | /// Example: |
| 652 | /// ``` |
| 653 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
| 654 | /// %r = gpu.warp_execute_on_lane_0(%laneid) -> |
| 655 | /// (!xegpu.tensor_desc<4x8xf32, #layout0>) { |
| 656 | /// ... |
| 657 | /// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]: |
| 658 | /// !xegpu.tensor_desc<4x8xf32, #layout0> |
| 659 | /// gpu.yield %update |
| 660 | /// } |
| 661 | /// ... |
| 662 | /// ``` |
| 663 | /// To |
| 664 | /// ``` |
/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (
/// !xegpu.tensor_desc<4x8xf32, #layout0>,
/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
/// ...
/// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
/// !xegpu.tensor_desc<4x8xf32, #layout0>
/// gpu.yield %dead, %arg0, %c32, %c16
| 672 | /// } |
| 673 | /// %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32, |
| 674 | /// #layout0> -> !xegpu.tensor_desc<4x8xf32> |
| 675 | /// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]: |
| 676 | /// !xegpu.tensor_desc<4x8xf32> |
| 677 | /// ... |
| 678 | /// ``` |
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  // Matches a warp op that yields the result of an xegpu::UpdateNdOffsetOp
  // and recreates the update op after the warp op. The tensor descriptor
  // type itself is not distributed; only its layout attribute is dropped.
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp-op result whose producer is an update_nd_offset op.
    OpOperand *operand =
        getWarpResult(warpOp: subgroupOp, fn: llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          arg&: subgroupOp, msg: "warp result is not a xegpu::UpdateNdOffset op" );
    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // new update op does not have layout attribute.
    xegpu::TensorDescType newTensorDescTy =
        updateOp.getTensorDescType().dropLayouts();

    // Yield every operand of the update op from the warp op so it becomes
    // available outside the region. Tensor descriptor operands are yielded
    // with the layout-free type; all other operands keep their original type.
    SmallVector<Value, 3> newYieldValues;
    SmallVector<Type, 3> newYieldTypes;
    for (Value operand : updateOp->getOperands()) {
      newYieldValues.push_back(Elt: operand);
      if (isa<xegpu::TensorDescType>(Val: operand.getType())) {
        newYieldTypes.push_back(Elt: newTensorDescTy);
      } else {
        newYieldTypes.push_back(Elt: operand.getType());
      }
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp: subgroupOp, newYieldedValues: newYieldValues, newReturnTypes: newYieldTypes, indices&: newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Collect the new warp-op results (in the same order they were yielded)
    // to use as operands of the recreated update op.
    SmallVector<Value> newUpdateOperands;
    for (size_t i : newRetIndices) {
      // For the tensor descriptor operand, the layout attribute is dropped
      // after distribution. Types needs to be resolved in this case.
      if (isa<xegpu::TensorDescType>(Val: newWarpOp.getResult(i).getType())) {
        newUpdateOperands.push_back(Elt: resolveDistributedTy(
            orig: newWarpOp.getResult(i), expected: newTensorDescTy, rewriter));
      } else {
        newUpdateOperands.push_back(Elt: newWarpOp.getResult(i));
      }
    }
    // Create a new update op outside the warp op.
    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
        location: newWarpOp.getLoc(), args&: newTensorDescTy, args&: newUpdateOperands,
        args: removeTemporaryLayoutAttributes(attrs: updateOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(i: operandIdx);
    // Resolve the distributed type with the original type.
    newUpdateOp =
        resolveDistributedTy(orig: newUpdateOp, expected: distributedVal.getType(), rewriter);
    rewriter.replaceAllUsesWith(from: distributedVal, to: newUpdateOp);
    return success();
  }
};
| 731 | |
| 732 | /// Distribute a prefetch_nd op at the end of enclosing |
| 733 | /// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed |
| 734 | /// through the warp op interface they would be propagated as returned values. |
| 735 | /// Tensor descriptor shape is not distributed because it is a uniform value |
/// across all work items within the subgroup. Appropriate cast ops are inserted
/// if the distributed types do not match expected xegpu SIMT types.
| 738 | /// |
| 739 | /// Example: |
| 740 | /// |
| 741 | /// ``` |
| 742 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
| 743 | /// gpu.warp_execute_on_lane_0(%laneid) -> () { |
| 744 | /// ... |
| 745 | /// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0> |
| 746 | /// } |
| 747 | /// ``` |
| 748 | /// To |
| 749 | /// ``` |
| 750 | /// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> ( |
| 751 | /// !xegpu.tensor_desc<4x8xf32, #layout0>) { |
| 752 | /// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> |
| 753 | /// } |
| 754 | /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32, |
| 755 | /// #layout0> -> !xegpu.tensor_desc<4x8xf32> |
| 756 | /// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32> |
| 757 | /// |
| 758 | /// ``` |
| 759 | struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern { |
| 760 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
| 761 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
| 762 | PatternRewriter &rewriter) const override { |
| 763 | auto yield = cast<gpu::YieldOp>( |
| 764 | Val: subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| 765 | Operation *lastNode = yield->getPrevNode(); |
| 766 | auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(Val: lastNode); |
| 767 | if (!prefetchOp) |
| 768 | return failure(); |
| 769 | xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr(); |
| 770 | if (!layout) |
| 771 | return rewriter.notifyMatchFailure( |
| 772 | arg&: prefetchOp, msg: "the source tensor descriptor lacks layout attribute" ); |
| 773 | |
| 774 | SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()}; |
| 775 | SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()}; |
| 776 | SmallVector<size_t> newRetIndices; |
| 777 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| 778 | rewriter, warpOp: subgroupOp, newYieldedValues: newYieldValues, newReturnTypes: newYieldTypes, indices&: newRetIndices); |
| 779 | // Create a new prefetch op outside the warp op with updated tensor |
| 780 | // descriptor type. Source tensor descriptor require type resolution. |
| 781 | xegpu::TensorDescType newTensorDescTy = |
| 782 | prefetchOp.getTensorDescType().dropLayouts(); |
| 783 | rewriter.setInsertionPointAfter(newWarpOp); |
| 784 | SmallVector<Value> newPrefetchOperands = {resolveDistributedTy( |
| 785 | orig: newWarpOp.getResult(i: newRetIndices[0]), expected: newTensorDescTy, rewriter)}; |
| 786 | rewriter.create<xegpu::PrefetchNdOp>( |
| 787 | location: newWarpOp.getLoc(), args: TypeRange{}, args&: newPrefetchOperands, |
| 788 | args: removeTemporaryLayoutAttributes(attrs: prefetchOp->getAttrs())); |
| 789 | rewriter.eraseOp(op: prefetchOp); |
| 790 | return success(); |
| 791 | } |
| 792 | }; |
| 793 | |
| 794 | /// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0` |
| 795 | /// region. This will simply move the barrier op outside of the warp op. |
| 796 | struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { |
| 797 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
| 798 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
| 799 | PatternRewriter &rewriter) const override { |
| 800 | auto yield = cast<gpu::YieldOp>( |
| 801 | Val: subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| 802 | Operation *lastNode = yield->getPrevNode(); |
| 803 | // The last node must be a gpu::BarrierOp. |
| 804 | auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(Val: lastNode); |
| 805 | if (!barrierOp) |
| 806 | return failure(); |
| 807 | // Move the barrier op outside of the warp op. |
| 808 | rewriter.setInsertionPointAfter(subgroupOp); |
| 809 | rewriter.create<gpu::BarrierOp>( |
| 810 | location: barrierOp.getLoc(), args: barrierOp->getResultTypes(), |
| 811 | args: barrierOp->getOperands(), args: barrierOp->getAttrs()); |
| 812 | rewriter.eraseOp(op: barrierOp); |
| 813 | return success(); |
| 814 | } |
| 815 | }; |
| 816 | |
| 817 | } // namespace |
| 818 | |
| 819 | namespace { |
/// Pass driver: attaches layout attributes to vector operands, moves GPU
/// function bodies into gpu.warp_execute_on_lane_0 regions, and applies the
/// subgroup-to-workitem distribution patterns (see runOnOperation).
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  void runOnOperation() override;
};
| 825 | } // namespace |
| 826 | |
| 827 | void xegpu::populateXeGPUSubgroupDistributePatterns( |
| 828 | RewritePatternSet &patterns) { |
| 829 | patterns.add<CreateNdDescDistribution, StoreNdDistribution, |
| 830 | LoadNdDistribution, DpasDistribution, PrefetchNdDistribution, |
| 831 | UpdateNdOffsetDistribution, GpuBarrierDistribution>( |
| 832 | arg: patterns.getContext()); |
| 833 | } |
| 834 | |
| 835 | void XeGPUSubgroupDistributePass::runOnOperation() { |
| 836 | // Step 1: Attach layouts to op operands. |
| 837 | // TODO: Following assumptions are made: |
| 838 | // 1) It is assumed that there are no layout conflicts. |
| 839 | // 2) Any existing layout attributes attached to the operands are ignored. |
| 840 | Operation *op = getOperation(); |
| 841 | op->walk(callback: [&](Operation *op) { |
| 842 | for (OpOperand &operand : op->getOpOperands()) { |
| 843 | // Layouts are needed for vector type only. |
| 844 | if (!isa<VectorType>(Val: operand.get().getType())) |
| 845 | continue; |
| 846 | |
| 847 | xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr: operand); |
| 848 | if (!layout) { |
| 849 | op->emitError(message: "Could not find layout attribute for operand " ) |
| 850 | << operand.getOperandNumber() << " of operation " << op->getName(); |
| 851 | signalPassFailure(); |
| 852 | return; |
| 853 | } |
| 854 | xegpu::setLayoutAttr(operandOrResult: operand, layout); |
| 855 | } |
| 856 | }); |
| 857 | // Step 2: Move all operations of a GPU function inside |
| 858 | // gpu.warp_execute_on_lane_0 operation. |
| 859 | { |
| 860 | RewritePatternSet patterns(&getContext()); |
| 861 | patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(arg: &getContext()); |
| 862 | |
| 863 | if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) { |
| 864 | signalPassFailure(); |
| 865 | return; |
| 866 | } |
| 867 | // At this point, we have moved the entire function body inside the |
| 868 | // warpOp. Now move any scalar uniform code outside of the warpOp (like |
| 869 | // GPU index ops, scalar constants, etc.). This will simplify the |
| 870 | // later lowering and avoid custom patterns for these ops. |
| 871 | getOperation()->walk(callback: [&](Operation *op) { |
| 872 | if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(Val: op)) |
| 873 | vector::moveScalarUniformCode(op: warpOp); |
| 874 | }); |
| 875 | } |
| 876 | // Step 3: Apply subgroup to workitem distribution patterns. |
| 877 | RewritePatternSet patterns(&getContext()); |
| 878 | xegpu::populateXeGPUSubgroupDistributePatterns(patterns); |
| 879 | // distributionFn is used by vector distribution patterns to determine the |
| 880 | // distributed vector type for a given vector value. In XeGPU subgroup |
| 881 | // distribution context, we compute this based on lane layout. |
| 882 | auto distributionFn = [](Value val) { |
| 883 | VectorType vecType = dyn_cast<VectorType>(Val: val.getType()); |
| 884 | int64_t vecRank = vecType ? vecType.getRank() : 0; |
| 885 | if (vecRank == 0) |
| 886 | return AffineMap::get(context: val.getContext()); |
| 887 | // Get the layout of the vector type. |
| 888 | xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value: val); |
| 889 | // If no layout is specified, assume the inner most dimension is distributed |
| 890 | // for now. |
| 891 | if (!layout) |
| 892 | return AffineMap::getMultiDimMapWithTargets( |
| 893 | numDims: vecRank, targets: {static_cast<unsigned int>(vecRank - 1)}, context: val.getContext()); |
| 894 | SmallVector<unsigned int> distributedDims; |
| 895 | // Get the distributed dimensions based on the layout. |
| 896 | ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef(); |
| 897 | for (unsigned i = 0; i < laneLayout.size(); ++i) { |
| 898 | if (laneLayout[i] > 1) |
| 899 | distributedDims.push_back(Elt: i); |
| 900 | } |
| 901 | return AffineMap::getMultiDimMapWithTargets(numDims: vecRank, targets: distributedDims, |
| 902 | context: val.getContext()); |
| 903 | }; |
| 904 | // TODO: shuffleFn is not used. |
| 905 | auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, |
| 906 | int64_t warpSz) { return Value(); }; |
| 907 | vector::populatePropagateWarpVectorDistributionPatterns( |
| 908 | pattern&: patterns, distributionMapFn: distributionFn, warpShuffleFromIdxFn: shuffleFn); |
| 909 | if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) { |
| 910 | signalPassFailure(); |
| 911 | return; |
| 912 | } |
| 913 | |
| 914 | // Step 4: Finllay, clean up UnrealizedConversionCastOps that were inserted |
| 915 | // due to tensor desc type mismatches created by using upstream distribution |
| 916 | // patterns (scf.for) |
| 917 | getOperation()->walk(callback: [&](mlir::UnrealizedConversionCastOp op) { |
| 918 | // We are only interested in UnrealizedConversionCastOps there were added |
| 919 | // for resolving SIMT type mismatches. |
| 920 | if (!op->getAttr(name: resolveSIMTTypeMismatch)) |
| 921 | return WalkResult::skip(); |
| 922 | |
| 923 | Value input = op.getOperand(i: 0); |
| 924 | Value output = op.getResult(i: 0); |
| 925 | |
| 926 | // Both input and output must have tensor descriptor types. |
| 927 | xegpu::TensorDescType inputDescType = |
| 928 | mlir::dyn_cast<xegpu::TensorDescType>(Val: input.getType()); |
| 929 | xegpu::TensorDescType outputDescType = |
| 930 | mlir::dyn_cast<xegpu::TensorDescType>(Val: output.getType()); |
| 931 | assert(inputDescType && outputDescType && |
| 932 | "Unrealized conversion cast must have tensor descriptor types" ); |
| 933 | |
| 934 | // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions. |
| 935 | // This occurs iside scf.for body to resolve the block argument type to |
| 936 | // SIMT type. |
| 937 | if (inputDescType.getLayout()) { |
| 938 | auto argument = mlir::dyn_cast<mlir::BlockArgument>(Val&: input); |
| 939 | if (argument) { |
| 940 | argument.setType(output.getType()); |
| 941 | output.replaceAllUsesWith(newValue: argument); |
| 942 | if (auto loopOp = mlir::dyn_cast<mlir::LoopLikeOpInterface>( |
| 943 | Val: argument.getOwner()->getParentOp())) { |
| 944 | auto result = loopOp.getTiedLoopResult(bbArg: argument); |
| 945 | result.setType(output.getType()); |
| 946 | } |
| 947 | } |
| 948 | } |
| 949 | |
| 950 | // tensor_desc<shape> -> tensor_desc<shape, layout> Type of |
| 951 | // conversions. This occurs at the yield op of scf.for body to go back |
| 952 | // from SIMT type to original type. |
| 953 | if (outputDescType.getLayout()) |
| 954 | output.replaceAllUsesWith(newValue: input); |
| 955 | |
| 956 | if (op->use_empty()) |
| 957 | op->erase(); |
| 958 | return WalkResult::advance(); |
| 959 | }); |
| 960 | } |
| 961 | |