//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// similar concept and design as vector unroll patterns, serving as a complement
// to them.
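//
// As a rough illustration (the shapes and exact op syntax below are
// hypothetical, chosen only to convey the idea): an op working on a large
// shape, e.g.
//
//   %v = xegpu.load_nd %tdesc : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16>
//
// is rewritten, given a native tile shape of 8x16, into four loads over
// 8x16 tiles whose results are recombined into the original
// vector<16x32xf16>; vectors are stitched back together with
// insert_strided_slice, and tensor descriptors with
// unrealized_conversion_cast.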
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("");
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--filter constraint not met -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects a nativeShape callback to compute the native shape");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
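  /// For example (hypothetical shapes; the actual tiles depend on the unroll
  /// options): unpacking two vector<8x16xf16> values into a vector<8x32xf16>
  /// destination inserts them at offsets [0, 0] and [0, 16], while unpacking
  /// into a TensorDescType destination is modeled as a single
  /// unrealized_conversion_cast carrying the blocking attributes defined
  /// below.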
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
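  /// This is the inverse of `unpack`: e.g. (hypothetical shapes) packing a
  /// vector<8x32xf16> value with a block size of 8x16 extracts two
  /// vector<8x16xf16> slices, while packing a TensorDescType value is
  /// modeled as a single unrealized_conversion_cast carrying the blocking
  /// attributes defined below.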
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

/// Unrolls xegpu.create_nd_tdesc into multiple descriptors of the target
/// (native) shape, one per tile of the original descriptor's shape.
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
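    // For example (hypothetical case): a rank-2 descriptor created from a
    // 3-D memref with offsets [%b, 0, 0] only has its last two offsets
    // advanced per tile; the leading offset %b is carried over unchanged.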
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

/// Unrolls xegpu.update_nd_offset by applying the same offset update to each
/// unrolled tensor descriptor tile.
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

/// Unrolls xegpu.prefetch_nd by emitting one prefetch per unrolled tensor
/// descriptor tile.
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

/// Unrolls xegpu.load_nd into one load per tile and recombines the loaded
/// tiles into the original result vector.
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

/// Unrolls xegpu.store_nd by splitting both the stored value and the tensor
/// descriptor into tiles and emitting one store per tile.
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

/// Unrolls xegpu.dpas by blocking its operands with an (M, K, N) native shape
/// and emitting a grid of smaller dpas ops accumulated over the K dimension.
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2-D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};
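    // For illustration (hypothetical sizes): with A and B both vector<32x32>
    // and a native shape of M=8, K=16, N=16, A is blocked into a 4x2 grid of
    // 8x16 tiles and B into a 2x2 grid of 16x16 tiles; the loops below then
    // emit a 4x2 grid of 8x16 result tiles, each accumulated over the two
    // partial dpas ops along K via `tmpC`.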

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if any operand failed to be blocked (empty), or if
    // every operand's shape already matches its blocking size (size == 1), in
    // which case there is nothing to unroll.
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

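/// A minimal usage sketch (the exact UnrollOptions setter names and the
/// pattern driver call are assumptions; consult Transforms.h for the actual
/// API):
///
///   xegpu::UnrollOptions options;
///   options.setNativeShapeFn(
///       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
///         // Return the per-op native tile shape, or std::nullopt to skip.
///         return SmallVector<int64_t>{8, 16};
///       });
///   RewritePatternSet patterns(ctx);
///   xegpu::populateXeGPUUnrollPatterns(patterns, options);
///   (void)applyPatternsGreedily(funcOp, std::move(patterns));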
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
      patterns.getContext(), options);
}