//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// similar concept and design to the vector unroll patterns, serving as a
// complement to them.
//
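// For example (shapes here are illustrative), with a native tile shape of
// 8x16, an xegpu.load_nd producing vector<16x32xf32> is rewritten into four
// loads of vector<8x16xf32> whose results are recombined into the original
// shape via vector.insert_strided_slice.
//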
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("");
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--the filter constraint failed -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the nativeShape callback to be provided.");
    return options.nativeShape(op);
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
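  /// For example (shapes illustrative), four vector<8x16xf32> pieces are
  /// reassembled into one vector<16x32xf32>, while four
  /// !xegpu.tensor_desc<8x16xf32> values are recombined into a single
  /// !xegpu.tensor_desc<16x32xf32> through a tagged
  /// unrealized_conversion_cast.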
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
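  /// This is the inverse of `unpack`: one vector<16x32xf32> value (shapes
  /// illustrative) is split into four vector<8x16xf32> pieces, and one
  /// !xegpu.tensor_desc<16x32xf32> into four !xegpu.tensor_desc<8x16xf32>
  /// values via a tagged unrealized_conversion_cast.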
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
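  // These attributes tag the temporary unrealized_conversion_cast ops created
  // by pack/unpack so that a later cleanup can recognize matching pairs and
  // fold them away once all producers and consumers are unrolled.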
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

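    // Add a constant `b` to an OpFoldResult `a`, folding to a constant index
    // op when `a` is statically known.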
    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

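    // Block `val` into tiles of `blockSize`, returning the tiles as
    // individual values; returns `val` itself (as a single-element vector)
    // when its shape already equals one block.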
    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if every operand has an invalid blocking size (empty)
    // or if the original shape matches the blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

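    // Blocked matmul over the tile grid: each output tile accumulates along
    // the K dimension,
    //   C[i][j] = acc[i][j] + sum over k of A[i][k] * B[k][j],
    // with the running sum threaded through the dpas accumulator operand.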
    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
    VectorType indiceVecTy = indiceVec.getType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetIndiceShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    // The index vector is one dimension lower than tdescTy when the chunk
    // size is larger than 1.
    if (originalChunkSize > 1)
      targetIndiceShape.pop_back();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
    SmallVector<Type> convertedIndiceTypes =
        getUnrolledTypes(indiceVecTy, targetIndiceShape);
    SmallVector<Value> convertedIndiceVec =
        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);

    SmallVector<Value> newOps;

    // More indices are needed when chunkSize > 1, since one large load at a
    // given address may be broken into multiple smaller loads.
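    // E.g., with originalChunkSize == 8 and blockedChunkSize == 2, each
    // original index yields 4 descriptors whose indices are offset by 0, 2,
    // 4, and 6 along the chunk dimension.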
    if (originalChunkSize > 1) {
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      for (auto [indice, indiceType] :
           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
        for (int64_t i = 0; i < numNewChunks; ++i) {
          // Compute the offset.
          Value inc = rewriter.create<arith::ConstantIndexOp>(
              loc, i * blockedChunkSize);
          Value incVec =
              rewriter.create<vector::BroadcastOp>(loc, indiceType, inc);
          Value offsetIndice =
              rewriter.create<arith::AddIOp>(loc, indice, incVec);

          auto newOp = rewriter.create<xegpu::CreateDescOp>(
              loc, newTdescTy, op.getSource(), offsetIndice);

          newOps.push_back(newOp);
        }
      }
    } else {
      for (auto indice : convertedIndiceVec) {
        auto newOp = rewriter.create<xegpu::CreateDescOp>(
            loc, newTdescTy, op.getSource(), indice);
        newOps.push_back(newOp);
      }
    }

    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;

      // The mask is reused across the chunk_size dimension.
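      // E.g., with numNewChunks == 4, each unrolled mask value guards all
      // four smaller accesses split from one original chunked access.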
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);

      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
          loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(),
          op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> targetMaskShape(*targetShape);
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();

    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes;
    SmallVector<Value> convertedMasks;

    if (originalChunkSize > 1) {
      targetMaskShape.pop_back();
      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);

      // The mask is reused across the chunk_size dimension.
      for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter))
        convertedMasks.append(numNewChunks, mask);
    } else {
      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
                            loc, rewriter);
    }

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

    for (size_t i = 0; i < convertedValues.size(); ++i) {
      Value v = convertedValues[i];
      Value t = convertedTdescs[i];
      Value m = op.getMask() ? convertedMasks[i] : nullptr;
      rewriter.create<xegpu::StoreScatterOp>(loc, v, t, m, op.getL1HintAttr(),
                                             op.getL2HintAttr(),
                                             op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    if (!tdescTy.isScattered())
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
    VectorType offsetVecTy = offsetVec.getType();
    SmallVector<Type> convertedOffsetTypes;
    SmallVector<Value> convertedOffsetVec;
    SmallVector<Value> newOps;
    int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
    if (originalChunkSize > 1) {
      auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);

      int64_t blockedChunkSize = targetShape->back();
      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
      // The offset is reused across the chunk_size dimension.
      for (auto offset : pack(offsetVec, convertedOffsetTypes,
                              targetOffsetShape, loc, rewriter))
        convertedOffsetVec.append(numNewChunks, offset);

    } else {
      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
      convertedOffsetVec =
          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
    }

    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
      auto newOp =
          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

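/// A minimal usage sketch (assuming the `UnrollOptions` setters mirror those
/// of vector::UnrollVectorOptions, e.g. `setNativeShapeFn`; the shape below
/// is illustrative):
///   RewritePatternSet patterns(ctx);
///   xegpu::UnrollOptions options;
///   options.setNativeShapeFn(
///       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
///         return SmallVector<int64_t>{8, 16};
///       });
///   xegpu::populateXeGPUUnrollPatterns(patterns, options);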
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
                                                       options);
}