//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// similar concept and design as vector unroll patterns, serving as a complement
// to them.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("");
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--filter constraint not met -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects a nativeShape callback to compute the target shape.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

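/// Unrolls a create_nd_tdesc op into one descriptor per tile, shifting the
/// offsets accordingly. A sketch of the rewrite (hypothetical shapes; IR
/// syntax abbreviated):
///
///   %0 = xegpu.create_nd_tdesc %src[0, 0]
///          : memref<16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
///
/// with a target shape of [8, 16] becomes four create_nd_tdesc ops with
/// offsets [0, 0], [0, 16], [8, 0], and [8, 16], each producing a
/// !xegpu.tensor_desc<8x16xf32>, recombined by `unpack`.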
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

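/// Unrolls an update_nd_offset op by packing its tensor_desc operand into
/// tiles and applying the same offset update to every tile; e.g.
/// (illustrative), one update on a 16x32 descriptor with a target shape of
/// [8, 16] becomes four updates on 8x16 descriptors.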
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

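/// Unrolls a prefetch_nd op into one prefetch per tile. Since prefetch
/// produces no results, the original op is simply erased once the per-tile
/// prefetches have been emitted; no `unpack` step is needed.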
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

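/// Unrolls a load_nd op into one load per tile. A sketch of the rewrite
/// (hypothetical shapes; IR syntax abbreviated):
///
///   %v = xegpu.load_nd %td
///          : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>
///
/// with a target shape of [8, 16] becomes four loads producing
/// vector<8x16xf32> values, which `unpack` recombines into the original
/// vector<16x32xf32>.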
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

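/// Unrolls a store_nd op by packing both the stored vector and the
/// tensor_desc into matching tiles and emitting one store per tile pair;
/// like prefetch, the original op is erased rather than replaced since it
/// produces no results.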
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

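/// Unrolls a dpas op over an [M, K, N] target shape: A is tiled into M x K
/// blocks, B into K x N blocks, and the accumulator/result into M x N
/// blocks. For each (m, n) tile, the K tiles are reduced sequentially by
/// feeding each partial dpas result into the next dpas as its accumulator.
/// For example (hypothetical sizes), a 16x32 * 32x32 dpas with target shape
/// [8, 16, 16] expands into a 2x2x2 loop nest of 8x16 * 16x16 dpas ops.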
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if any operand has an invalid blocking size (empty)
    // or if every operand's shape already matches its blocking size
    // (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // Initialize with the accumulator.

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

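/// A minimal usage sketch (hypothetical driver code; the setter names are
/// assumed to follow the UnrollOptions declaration in
/// XeGPU/Transforms/Transforms.h, and `getNativeShape` is an assumed
/// client-provided helper):
///
///   xegpu::UnrollOptions options;
///   options.setNativeShapeFn(
///       [](Operation *op) { return getNativeShape(op); });
///   RewritePatternSet patterns(&getContext());
///   xegpu::populateXeGPUUnrollPatterns(patterns, options);
///   (void)applyPatternsGreedily(getOperation(), std::move(patterns));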
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
      patterns.getContext(), options);
}