//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-blocking"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
namespace {

// Resolve the unrealized conversion cast ops generated by the SCF structural
// type conversion. A cast has one of two forms: an N:1 vector cast or a 1:N
// vector cast. vector::insert_strided_slice ops are used to rewrite the first
// case, and vector::extract_strided_slice ops are used to rewrite the second
// case.
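// For example, an N:1 cast assembling four vector<16x16xf16> values into one
// vector<32x32xf16> becomes a chain of insert_strided_slice ops, while the
// reverse 1:N cast becomes four extract_strided_slice ops.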
static void
resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  auto hasIdenticalVectorTypes = [](ValueRange values) {
    auto types = values.getTypes();
    return llvm::all_of(types, [&](Type type) {
      return isa<VectorType>(type) && type == types.front();
    });
  };

  // We are only interested in the case where all inputs and all outputs have
  // identical VectorTypes.
  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
    LDBG("skip unrealized conversion cast op not emulating pack/unpack.");
    return;
  }

  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  OpBuilder builder(castOp);
  if (inputs.size() > 1 && outputs.size() == 1) {
    // The castOp is emulating an unpack op.
    ArrayRef<int64_t> shape = outputTy.getShape();
    Value result = xegpu::createVectorWithShapeFromValues(
        builder, castOp.getLoc(), inputs, shape);
    castOp->replaceAllUsesWith(ValueRange(result));
    castOp->erase();
  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
    // The castOp is emulating a pack op.
    ArrayRef<int64_t> tileShape = outputTy.getShape();
    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
        builder, castOp.getLoc(), inputs[0], tileShape);
    castOp->replaceAllUsesWith(results);
    castOp->erase();
  }
}

//===------------------------------------------------------------------------===//
// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
// to partition operations that process large shapes into multiple operations
// on smaller shapes, as specified by the inst_data in the layout attribute.
// This enables each resulting operation to be efficiently mapped to a hardware
// instruction.
//===------------------------------------------------------------------------===//
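// For example, an op producing a vector<32x32xf16> whose layout specifies
// inst_data = [16, 16] is rewritten into four ops, each producing a
// vector<16x16xf16>.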

class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the tile shape for a given OpOperand or OpResult by examining the
  // corresponding layout attribute. If the layout is missing or is not a
  // subgroup-level layout, returns std::nullopt.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the tile shape for a given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Determine if the operation requires unrolling. Return false if all
  // operands and results have tile shapes identical to their original types;
  // otherwise, return true.
  bool needsUnroll(Operation *op) const;
};
} // namespace

template <typename T, typename>
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
  Value value;
  if constexpr (std::is_same_v<T, OpOperand>)
    value = operandOrResult.get();
  else
    value = (Value)operandOrResult;

  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

    if (auto type = dyn_cast<ShapedType>(value.getType()))
      return llvm::to_vector(type.getShape());
  }
  LDBG("failed to getTileShape for: " << value);
  return std::nullopt;
}

std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp>(op))
    return getTileShape(op->getOpOperand(1));

  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));

    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;

    // Semantic check for A and B: their reduction dimensions must match.
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;

    // Semantic check for C: its tile must match the shape implied by A and B.
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }

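    // The returned DPAS tile shape is {m, k, n}, where A is m x k, B is
    // k x n, and C (if present) is m x n.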
    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }

  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
    return getTileShape(op->getOpResult(0));

  if (isa<vector::MultiDimReductionOp>(op))
    return getTileShape(op->getOpOperand(0));

  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
    return getTileShape(op->getOpResult(0));

  return std::nullopt;
}

bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  // Skip the op if any of its operands or results has a workgroup-level
  // layout.
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
        return layout && layout.isWgLayout();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
        return layout && layout.isWgLayout();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG("skip unrolling for op with workgroup level layout: " << *op);
    return false;
  }

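  // A value is unrollable if it is a tensor descriptor whose layout carries
  // inst_data, or a shaped value whose shape differs from the tile shape.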
  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
    Type valTy = value.getType();
    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
      xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
      return layout && layout.getInstData();
    }
    auto shapedType = dyn_cast<ShapedType>(valTy);
    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
  };

  bool hasUnrollableOperands =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
      });
  bool hasUnrollableResults =
      llvm::any_of(op->getOpResults(), [&](OpResult result) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
        return tileShape.has_value() && isUnrollable(result, *tileShape);
      });
  return hasUnrollableOperands || hasUnrollableResults;
}

void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *op = getOperation();

  // Cache the LayoutAttr of each operand in the owner op's DictionaryAttr.
  // This ensures that the LayoutAttr remains accessible even if the defining
  // operation is replaced.
  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });

  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                 xegpu::LayoutAttr layout) {
    int count = 1;
    SmallVector<int64_t> tileShape(shape);
    if (layout && layout.getInstData()) {
      DenseI32ArrayAttr instData = layout.getInstData();
      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
      count = computeProduct(shape) / computeProduct(tileShape);
    }
    return std::make_pair(tileShape, count);
  };
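  // For example, shape = [32, 32] with inst_data = [16, 16] yields
  // tileShape = [16, 16] and count = 4.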

  // Perform type conversion for SCF control flow ops.
  TypeConverter converter;
  converter.addConversion([](Type type) -> Type { return type; });
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        auto layout =
            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
        if (layout && layout.isWgLayout())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        auto newTy = VectorType::get(subShape, elemTy);
        result.append(count, newTy);
        return success();
      });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout = type.getLayoutAttr();
        if (layout && layout.isWgLayout())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropInstData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });
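  // For example, a tensor_desc<32x32xf16> whose layout carries
  // inst_data = [16, 16] converts to four tensor_desc<16x16xf16> values, with
  // inst_data dropped from the resulting layout.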

  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);

  xegpu::UnrollOptions options;
  options.setFilterConstraint(
      [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });

  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
    Type elemTy = type.getElementType();
    Type newTy;

    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
      newTy = xegpu::TensorDescType::get(
          ctx, tileShape, elemTy, tdescTy.getEncoding(),
          tdescTy.getLayoutAttr().dropInstData());
    else
      newTy = type.clone(tileShape, elemTy);

    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
    return SmallVector<Type>(computeProduct(*ratio), newTy);
  });

  RewritePatternSet patterns(ctx);

  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);

  populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(op, std::move(patterns));

  op->walk([](Operation *op) {
    // Remove the layout attributes cached per operand.
    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getLayoutName(opr);
      if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
        op->removeAttr(name);
    }

    // Update the layout attributes per result.
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
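        // Loop-like ops were already handled by the SCF structural type
        // conversion, so no layout is re-attached to their results.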
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setLayoutAttr(result, layout.dropInstData());
      }
    }

    // Resolve unrealized conversion cast ops emulating pack/unpack.
    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);
  });
}