1//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
10
11#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
12#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
13#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
14#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
15#include "mlir/Interfaces/LoopLikeInterface.h"
16#include "mlir/Pass/PassManager.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/STLExtras.h"
20
21namespace mlir {
22namespace xegpu {
23#define GEN_PASS_DEF_XEGPUBLOCKING
24#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
25} // namespace xegpu
26} // namespace mlir
27
28#define DEBUG_TYPE "xegpu-blocking"
29#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
30#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
31
32using namespace mlir;
33
34namespace {
35
// Resolve the unrealized conversion cast ops generated when doing SCF
// structural type conversion. The cast comes in two forms: an N:1 vector
// cast and a 1:N vector cast. vector::insert_strided_slice ops are used
// for the first case, and vector::extract_strided_slice ops are used for
// the second case.
41static void
42resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
43 ValueRange inputs = castOp.getInputs();
44 ValueRange outputs = castOp.getOutputs();
45
46 auto hasIdenticalVectorTypes = [](ValueRange values) {
47 auto types = values.getTypes();
48 return llvm::all_of(Range&: types, P: [&](Type type) {
49 return isa<VectorType>(Val: type) && type == types.front();
50 });
51 };
52
53 // We only interest in the case where all inputs and outputs have the
54 // identical VectorTypes
55 if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
56 LDBG("skip unrealized conversion cast op not emulating pack/unpack.");
57 return;
58 }
59
60 VectorType outputTy = dyn_cast<VectorType>(Val: outputs[0].getType());
61 OpBuilder builder(castOp);
62 if (inputs.size() > 1 && outputs.size() == 1) {
63 // the castOp is emulating an unpack op
64 ArrayRef<int64_t> shape = outputTy.getShape();
65 Value result = xegpu::createVectorWithShapeFromValues(
66 builder, loc: castOp.getLoc(), values: inputs, shape);
67 castOp->replaceAllUsesWith(values: ValueRange(result));
68 castOp->erase();
69 } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
70 // the castOp is emulating a pack op
71 ArrayRef<int64_t> tileShape = outputTy.getShape();
72 SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
73 builder, loc: castOp.getLoc(), value: inputs[0], shape: tileShape);
74 castOp->replaceAllUsesWith(values&: results);
75 castOp->erase();
76 }
77}
78
79//===------------------------------------------------------------------------===//
80// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
81// to partition operations that process large shapes into multiple operations on
82// smaller shapes, as specified by the inst_data in the layout attribute. This
83// enables each resulting operation to be efficiently mapped to a hardware
84// instruction.
85//===------------------------------------------------------------------------===//
86
87class XeGPUBlockingPass final
88 : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
89public:
90 void runOnOperation() override;
91
92private:
93 // Get the tile shape for a given OpOperand or OpResult by examining the
94 // corresponding layout attribute. If layout is not present or is not a
95 // subgroup level layout, it returns std::nullopt.
96 template <typename T,
97 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
98 std::is_same_v<T, OpResult>>>
99 std::optional<SmallVector<int64_t>>
100 getTileShape(const T &operandOrResult) const;
101
102 // Get the tile shape for a given operation.
103 std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
104
105 // Determine if the operation requires unrolling. Return false if all operands
106 // and results have tile shapes identical to their original types. Otherwise,
107 // return true.
108 bool needsUnroll(Operation *op) const;
109};
110} // namespace
111
112template <typename T, typename>
113std::optional<SmallVector<int64_t>>
114XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
115 Value value;
116 if constexpr (std::is_same_v<T, OpOperand>)
117 value = operandOrResult.get();
118 else
119 value = (Value)operandOrResult;
120
121 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
122 if (layout && layout.isSgLayout()) {
123 if (auto inst_data = layout.getInstData())
124 return llvm::to_vector_of<int64_t>(Range: inst_data.asArrayRef());
125
126 if (auto type = dyn_cast<ShapedType>(Val: value.getType()))
127 return llvm::to_vector(Range: type.getShape());
128 }
129 LDBG("failed to getTileShape for: " << value);
130 return std::nullopt;
131}
132
133std::optional<SmallVector<int64_t>>
134XeGPUBlockingPass::getTileShape(Operation *op) const {
135 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
136 xegpu::UpdateOffsetOp>(Val: op))
137 return getTileShape(operandOrResult: op->getOpResult(idx: 0));
138 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
139 xegpu::LoadGatherOp>(Val: op))
140 return getTileShape(operandOrResult: op->getOpOperand(idx: 0));
141 if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(Val: op))
142 return getTileShape(operandOrResult: op->getOpOperand(idx: 1));
143
144 if (isa<xegpu::DpasOp>(Val: op)) {
145 std::optional<SmallVector<int64_t>> aTile =
146 getTileShape(operandOrResult: op->getOpOperand(idx: 0));
147 std::optional<SmallVector<int64_t>> bTile =
148 getTileShape(operandOrResult: op->getOpOperand(idx: 1));
149
150 if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
151 return std::nullopt;
152
153 // semantic check for A and B
154 if ((*aTile)[1] != (*bTile)[0])
155 return std::nullopt;
156
157 // semantic check for C
158 if (op->getNumOperands() == 3) {
159 std::optional<SmallVector<int64_t>> cTile =
160 getTileShape(operandOrResult: op->getOpOperand(idx: 2));
161 int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
162 if (!cTile || !llvm::equal(LRange&: *cTile, RRange&: expectedCTile))
163 return std::nullopt;
164 }
165
166 return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
167 }
168
169 if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
170 return getTileShape(operandOrResult: op->getOpResult(idx: 0));
171
172 if (isa<vector::MultiDimReductionOp>(Val: op))
173 return getTileShape(operandOrResult: op->getOpOperand(idx: 0));
174
175 if (isa<vector::TransposeOp, vector::BroadcastOp>(Val: op))
176 return getTileShape(operandOrResult: op->getOpResult(idx: 0));
177
178 return std::nullopt;
179}
180
181bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
182 // skip the op if any of its operands or results has workgroup level layouts
183 bool hasWgLayoutOperands =
184 llvm::any_of(Range: op->getOpOperands(), P: [](OpOperand &opr) {
185 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
186 return layout && layout.isWgLayout();
187 });
188 bool hasWgLayoutResults =
189 llvm::any_of(Range: op->getOpResults(), P: [](OpResult result) {
190 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value: result);
191 return layout && layout.isWgLayout();
192 });
193 if (hasWgLayoutOperands || hasWgLayoutResults) {
194 LDBG("skip unrolling for op with workgroup level layout: " << *op);
195 return false;
196 }
197
198 auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
199 Type valTy = value.getType();
200 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(Val&: valTy)) {
201 xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
202 return layout && layout.getInstData();
203 }
204 auto shapedType = dyn_cast<ShapedType>(Val&: valTy);
205 return shapedType && !llvm::equal(LRange&: tileShape, RRange: shapedType.getShape());
206 };
207
208 bool hasUnrollableOperands =
209 llvm::any_of(Range: op->getOpOperands(), P: [&](OpOperand &opr) {
210 std::optional<SmallVector<int64_t>> tileShape = getTileShape(operandOrResult: opr);
211 return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
212 });
213 bool hasUnrollableResults =
214 llvm::any_of(Range: op->getOpResults(), P: [&](OpResult result) {
215 std::optional<SmallVector<int64_t>> tileShape = getTileShape(operandOrResult: result);
216 return tileShape.has_value() && isUnrollable(result, *tileShape);
217 });
218 return hasUnrollableOperands || hasUnrollableResults;
219}
220
221void XeGPUBlockingPass::runOnOperation() {
222 MLIRContext *ctx = &getContext();
223 Operation *op = getOperation();
224
225 // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
226 // This ensures that the LayoutAttr remains accessible even if the defining
227 // operation is replaced.
228 xegpu::setLayoutAttrs(op, getLayoutImpl: [](Value v) { return xegpu::getLayoutAttr(value: v); });
229
230 auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
231 xegpu::LayoutAttr layout) {
232 int count = 1;
233 SmallVector<int64_t> tileShape(shape);
234 if (layout && layout.getInstData()) {
235 DenseI32ArrayAttr instData = layout.getInstData();
236 tileShape = llvm::to_vector_of<int64_t>(Range: instData.asArrayRef());
237 count = computeProduct(basis: shape) / computeProduct(basis: tileShape);
238 }
239 return std::make_pair(x&: tileShape, y&: count);
240 };
241
242 // Perform type conversion for SCF control folow ops
243 TypeConverter converter;
244 converter.addConversion(callback: [](Type type) -> Type { return type; });
245 converter.addConversion(
246 callback: [&](RankedTensorType type,
247 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
248 Type elemTy = type.getElementType();
249 ArrayRef<int64_t> shape = type.getShape();
250
251 auto layout =
252 llvm::dyn_cast_if_present<xegpu::LayoutAttr>(Val: type.getEncoding());
253 if (layout && layout.isWgLayout())
254 return failure();
255
256 int count;
257 SmallVector<int64_t> subShape;
258 std::tie(args&: subShape, args&: count) = getTileShapeAndCount(shape, layout);
259 auto newTy = VectorType::get(shape: subShape, elementType: elemTy);
260 result.append(NumInputs: count, Elt: newTy);
261 return success();
262 });
263 converter.addConversion(
264 callback: [&](xegpu::TensorDescType type,
265 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
266 Type elemTy = type.getElementType();
267 ArrayRef<int64_t> shape = type.getShape();
268
269 xegpu::LayoutAttr layout = type.getLayoutAttr();
270 if (layout && layout.isWgLayout())
271 return failure();
272
273 int count;
274 SmallVector<int64_t> subShape;
275 std::tie(args&: subShape, args&: count) = getTileShapeAndCount(shape, layout);
276
277 if (layout)
278 layout = layout.dropInstData();
279
280 auto newTy = xegpu::TensorDescType::get(
281 context: type.getContext(), shape: subShape, elementType: elemTy, encoding: type.getEncoding(), layout);
282 result.append(NumInputs: count, Elt: newTy);
283 return success();
284 });
285
286 xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
287
288 xegpu::UnrollOptions options;
289 options.setFilterConstraint(
290 [&](Operation *op) -> LogicalResult { return success(IsSuccess: needsUnroll(op)); });
291
292 options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
293
294 options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
295 Type elemTy = type.getElementType();
296 Type newTy;
297
298 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(Val&: type)) {
299
300 Attribute encoding = tdescTy.getEncoding();
301 // If the encoding is a ScatterTensorDescAttr, we need to
302 // potentially adjust the chunk size based on the inst_data.
303 if (tdescTy.isScattered()) {
304 int64_t chunkSize = tdescTy.getChunkSizeAsInt();
305
306 if (chunkSize > 1) {
307 int64_t blockedChunkSize = chunkSize;
308 auto instData = tdescTy.getLayoutAttr().getInstData();
309 if (!instData.empty())
310 blockedChunkSize = instData.asArrayRef().back();
311
312 // To create a new attribute with a different chunk_size:
313 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
314 context: ctx, memory_space: tdescTy.getMemorySpace(), chunk_size: blockedChunkSize);
315
316 encoding = newEncoding;
317 }
318 }
319
320 newTy =
321 xegpu::TensorDescType::get(context: ctx, shape: tileShape, elementType: elemTy, encoding,
322 layout: tdescTy.getLayoutAttr().dropInstData());
323 } else {
324 newTy = type.clone(shape: tileShape, elementType: elemTy);
325 }
326
327 std::optional<SmallVector<int64_t>> ratio =
328 computeShapeRatio(shape: type.getShape(), subShape: tileShape);
329 assert(ratio && "The shape of the type must be a multiple of tileShape.");
330 return SmallVector<Type>(computeProduct(basis: *ratio), newTy);
331 });
332
333 RewritePatternSet patterns(ctx);
334
335 vector::UnrollVectorOptions vectorOptions;
336 vectorOptions.setNativeShapeFn(options.nativeShape);
337
338 populateXeGPUUnrollPatterns(patterns, options);
339 vector::populateVectorUnrollPatterns(patterns, options: vectorOptions);
340
341 (void)applyPatternsGreedily(op, patterns: std::move(patterns));
342
343 op->walk(callback: [](Operation *op) {
344 // Remove the layout attributes cached per operands.
345 for (OpOperand &opr : op->getOpOperands()) {
346 std::string name = xegpu::getLayoutName(operand: opr);
347 if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
348 op->removeAttr(name);
349 }
350
351 // Update the layout attributes per result.
352 for (OpResult result : op->getOpResults()) {
353 std::string name = xegpu::getLayoutName(result);
354 if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
355 op->removeAttr(name);
356 if (!isa<LoopLikeOpInterface>(Val: op))
357 xegpu::setLayoutAttr(operandOrResult: result, layout: layout.dropInstData());
358 }
359 }
360
361 // Resolve unrealized conversion cast ops emulating pack/unpack
362 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(Val: op))
363 resolveUnrealizedConversionCastOp(castOp);
364 });
365}
366

source code of mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp