//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// similar concept and design to the vector unroll patterns, serving as a
// complement to them.
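//
// Each pattern queries a per-op native (target) shape, replays the op once
// per tile, and stitches the pieces back together with the `pack`/`unpack`
// helpers below: vector values go through extract_strided_slice /
// insert_strided_slice, while tensor descriptors are temporarily routed
// through tagged unrealized_conversion_cast ops that are expected to be
// resolved by a later blocking step.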
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("");
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--failed filter constraint -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the native shape callback to be provided.");
    return options.nativeShape(op);
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
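  /// Illustrative example (assuming blockSize = [8, 16]): four
  /// vector<8x16xf32> values are reassembled into a single vector<16x32xf32>
  /// by inserting them at offsets [0, 0], [0, 16], [8, 0], and [8, 16].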
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs,
                                                    shape);
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
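  /// Illustrative example (assuming blockSize = [8, 16]): a single
  /// vector<16x32xf32> is split into four vector<8x16xf32> values via
  /// vector.extract_strided_slice at offsets [0, 0], [0, 16], [8, 0], and
  /// [8, 16].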
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                     blockSize);
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

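/// Unrolls a CreateNdDescOp into one descriptor per tile. Illustrative
/// example, assuming a native (target) shape of [8, 16]:
///
///   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<16x32xf32>
///        -> !xegpu.tensor_desc<16x32xf32>
///
/// becomes four create_nd_tdesc ops with offsets [0, 0], [0, 16], [8, 0],
/// and [8, 16], each producing a !xegpu.tensor_desc<8x16xf32>; `unpack` then
/// recombines the results into a value of the original descriptor type.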
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

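/// Unrolls an UpdateNdOffsetOp by `pack`ing the tensor descriptor into tiles,
/// applying the same offset update to each tile, and `unpack`ing the updated
/// descriptors back into a value of the original type.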
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

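/// Unrolls a PrefetchNdOp by issuing one prefetch per unrolled tensor
/// descriptor. Prefetch produces no results, so the original op is erased.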
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

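/// Unrolls a LoadNdOp into one load per tile. Illustrative example, assuming
/// a native (target) shape of [8, 16]:
///
///   %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x32xf32>
///        -> vector<16x32xf32>
///
/// becomes four load_nd ops producing vector<8x16xf32> values, which are
/// reassembled into a vector<16x32xf32> via insert_strided_slice.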
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

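/// Unrolls a StoreNdOp by splitting both the stored vector (via
/// extract_strided_slice) and the tensor descriptor into matching tiles and
/// emitting one store per pair; the original store is then erased.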
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

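/// Unrolls a DpasOp over an [M, K, N] target shape. Illustrative example,
/// assuming a target shape of [8, 16, 16]: a dpas of vector<16x32xf16> x
/// vector<32x16xf16> is decomposed into a 2x1 grid of output tiles, where
/// each tile accumulates kIters = 2 partial products A[i, k] x B[k, j] by
/// threading the previous result through the accumulator operand.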
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expect every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if any operand has an invalid blocking size (empty)
    // or if every operand's shape already matches its blocking size
    // (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
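    // Walk the output tiles in (i, j) order; for each tile, chain the k
    // partial products through tmpC so the accumulator threads through the
    // sequence of dpas ops.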
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

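// Illustrative usage sketch (the option setter and driver names here are
// assumptions, not guaranteed by this file): configure UnrollOptions with a
// per-op native shape callback, then run the patterns greedily.
//
//   xegpu::UnrollOptions options;
//   options.setNativeShapeFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         if (isa<xegpu::DpasOp>(op))
//           return SmallVector<int64_t>{8, 16, 16}; // M, K, N
//         return std::nullopt;
//       });
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUUnrollPatterns(patterns, options);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));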
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
      patterns.getContext(), options);
}