//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;

/// HW dependent constants.
/// TODO: These constants should be queried from the target information.
constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
/// If DPAS A or B operands have low-precision element types, they must be
/// packed according to the following sizes.
constexpr unsigned packedSizeInBitsForDefault =
    16; // Minimum packing size per register for DPAS A.
constexpr unsigned packedSizeInBitsForDpasB =
    32; // Minimum packing size per register for DPAS B.
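// For example, with these constants a 16-bit (f16/bf16) DPAS B operand is
// packed in pairs (32 / 16 == 2 elements per lane entry), while an 8-bit DPAS
// A operand is likewise packed in pairs (16 / 8 == 2).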

namespace {

//===----------------------------------------------------------------------===//
// Layout
//===----------------------------------------------------------------------===//

/// Helper class to store the ND layout of lanes within a subgroup and data
/// owned by each lane.
struct Layout {
  SmallVector<int64_t, 3> layout;
  Layout() = default;
  Layout(std::initializer_list<int64_t> list) : layout(list) {}
  void print(llvm::raw_ostream &os) const;
  size_t size() const { return layout.size(); }
  int64_t operator[](size_t idx) const;
};

void Layout::print(llvm::raw_ostream &os) const {
  os << llvm::interleaved_array(layout);
}

int64_t Layout::operator[](size_t idx) const {
  assert(idx < layout.size() && "Index out of bounds.");
  return layout[idx];
}

/// LaneLayout represents the logical layout of lanes within a subgroup when it
/// accesses some value. LaneData represents the logical layout of data owned by
/// each work item.
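/// For example, with a 16-lane subgroup, lane_layout = [1, 16] arranges the
/// lanes along the second dimension of a 2D value, and lane_data = [1, 2]
/// means each lane owns contiguous 1x2 chunks of that value's elements.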
using LaneLayout = Layout;
using LaneData = Layout;

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//

/// Helper class for tracking the analysis state of an MLIR value. For layout
/// propagation, the analysis state is simply the lane_layout and lane_data of
/// each value. The purpose of this analysis is to propagate some unique layout
/// for each value in the program starting from a set of anchor operations
/// (like DPAS, StoreNd, etc.).
///
/// Given this, LayoutInfo satisfies the following properties:
///  1) A LayoutInfo value can be in one of two states - `assigned` or `not
///     assigned`.
///  2) Two LayoutInfo values are equal if they are both assigned or both not
///     assigned. The concrete value of the assigned state does not matter.
///  3) The meet operator works as follows:
///     - If the current state is assigned, return the current state. (a unique
///       layout is already assigned, so don't change it)
///     - Otherwise, return the other state.
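///
/// For example, meeting an unassigned LayoutInfo with one holding
/// lane_layout = [1, 16], lane_data = [1, 1] yields that assigned layout,
/// while meeting two assigned LayoutInfos keeps the current one unchanged.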

struct LayoutInfo {
private:
  LaneLayout laneLayout;
  LaneData laneData;

public:
  LayoutInfo() = default;
  LayoutInfo(const LaneLayout &layout, const LaneData &data)
      : laneLayout(layout), laneData(data) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const {
    return laneLayout.size() > 0 && laneData.size() > 0;
  }

  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;

  const LaneLayout &getLayout() const { return laneLayout; }
  const LaneData &getData() const { return laneData; }
  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
};

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << "lane_layout: ";
    laneLayout.print(os);
    os << ", lane_data: ";
    laneData.print(os);
  } else {
    os << "Not assigned.";
  }
}

LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

/// Since this is a backward analysis, join method is not used.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}

/// Get the transposed layout according to the given permutation.
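/// For example, transposing lane_layout = [1, 16], lane_data = [1, 2] with
/// permutation {1, 0} yields lane_layout = [16, 1], lane_data = [2, 1].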
LayoutInfo
LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  LaneLayout newLayout;
  LaneData newData;
  for (int64_t idx : permutation) {
    newLayout.layout.push_back(laneLayout.layout[idx]);
    newData.layout.push_back(laneData.layout[idx]);
  }
  return LayoutInfo(newLayout, newData);
}

//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
  using Lattice::Lattice;
};

/// Helper Functions to get default layouts. A `default layout` is a layout that
/// is assigned to a value when the layout is not fixed by some anchor operation
/// (like DPAS).

/// Helper Function to get the default layout for uniform values like constants.
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
}

/// Helper to get the default layout for a vector type.
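/// For example, a vector<8x16xi8> gets lane_layout = [1, 16] and
/// lane_data = [1, 2], since 8-bit elements are packed in pairs to reach the
/// 16-bit minimum packing size.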
static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return default layout for 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultLayoutInfo(1);
  // Packing factor is determined by the element type bitwidth.
  int packingFactor = 1;
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  if (bitwidth < packedSizeInBitsForDefault)
    packingFactor = packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(LaneLayout({1, subgroupSize}),
                    LaneData({1, packingFactor}));
}

/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
///   `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
///   `packedSizeInBitsForDpasB`
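/// For example, a vector<16x16xf16> B operand (operandNum == 1) gets
/// lane_layout = [1, 16] and lane_data = [2, 1], i.e. two 16-bit elements are
/// packed along the row dimension (VNNI format).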
static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
                                              unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  LaneLayout layout({1, subgroupSize});
  // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
  // must have the VNNI format.
  if (operandNum == 1 &&
      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
    LaneData data(
        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
    return LayoutInfo(layout, data);
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultLayoutInfo(vectorTy);
}

//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for the operands of DPAS,
/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
/// of this analysis is to propagate those known layouts to all their producers
/// and (other) consumers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
} // namespace

LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      // No need to propagate the layout to operands in CreateNdDescOp because
      // they are scalars (offsets, sizes, etc.).
      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      // All other ops.
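      // For example, for an elementwise op such as arith.addf, the layout
      // assigned to its result is simply propagated to each of its vector
      // operands.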
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *r : results) {
          for (LayoutInfoLattice *operand : operands) {
            // Propagate the layout of the result to the operand.
            if (r->getValue().isAssigned())
              meet(operand, *r);
          }
        }
      });
  // Add a dependency from each result to program point after the operation.
  for (const LayoutInfoLattice *r : results) {
    addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
  }
  return success();
}

void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Here we assign the default layout to the tensor descriptor operand of
  // prefetch.
  auto tdescTy = prefetch.getTensorDescType();
  auto prefetchLayout = getDefaultLayoutInfo(
      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}

void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  assert(resultLayout.getLayout().size() == 1 &&
         "Expected 1D layout for reduction result.");
  // Given that the result is 1D, the layout of the operand should be 2D with
  // default layout.
  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // Accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}

/// Propagate the layout of the result tensor to the source tensor descriptor in
/// UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  propagateIfChanged(operands[0],
                     operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(operands[1],
                     operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(operands[2],
                       operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
  }
}

/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands) {
    propagateIfChanged(operand, operand->meet(storeLayout));
  }
}

/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has the transpose effect. However, at the stage of this analysis
  // this effect is not expected and should be abstracted away. Emit a warning.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}

/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout =
      resultLayout.getTransposedLayout(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}

/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit width of the source and result types.
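/// For example, for a bitcast from vector<8x32xi16> to vector<8x16xi32> whose
/// result has lane_data = [1, 1], the source operand gets lane_data = [1, 2],
/// since each 32-bit result element maps to two 16-bit source elements.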
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();

  // LaneLayout does not change.
  const LaneLayout &newLaneLayout = resultLayout.getLayout();
  const LaneData &currData = resultLayout.getData();
  LaneData newLaneData;
  // It's a widening bitcast.
  if (inElemTyBitWidth < outElemTyBitWidth) {
    int ratio = outElemTyBitWidth / inElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] * ratio})
                      : LaneData({currData[0] * ratio, 1});
  } else {
    // It's a narrowing bitcast.
    int ratio = inElemTyBitWidth / outElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] / ratio})
                      : LaneData({currData[0] / ratio, 1});
  }

  propagateIfChanged(operands[0],
                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
}

/// Propagate the layout of the result to the tensor descriptor and mask
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;

  LayoutInfo tensorDescLayout = valueLayout;
  if (load.getTranspose()) {
    // LoadGatherOp has the transpose effect. However, at the stage of this
    // analysis this effect is not expected and should be abstracted away. Emit
    // a warning.
    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // Mask operand should have 1D default layout.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
  // Propagate the new layout to the mask operand.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}

/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  // For offset operand propagate 1D default layout.
  LayoutInfo layout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, and mask operands in the
/// StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Currently, for 2D StoreScatterOp we expect that the height dimension of
  // the tensor descriptor is equal to the subgroup size. This is ensured by
  // the op verifier.
  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
  if (tdescShape.size() > 1)
    assert(
        tdescShape[0] == subgroupSize &&
        "Expected the first dimension of 2D tensor descriptor to be equal to "
        "subgroup size.");

  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
  LayoutInfo storeScatterLayout = valueLayout;
  if (storeScatter.getTranspose()) {
    // StoreScatterOp allows the transpose effect. However, at the stage of
    // this analysis this effect is not expected and should be abstracted away.
    // Emit a warning.
    storeScatter.emitWarning("Transpose effect is not expected for "
                             "StoreScatterOp at LayoutInfoPropagation stage.");
    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // Propagate the value layout.
  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
  // Propagate the tensor descriptor layout.
  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
  // Use default 1D layout for mask operand.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}

namespace {

//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    solver.load<DeadCodeAnalysis>();
    solver.load<SparseConstantPropagation>();
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
} // namespace

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}

void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout : ";
      layout.print(os);
      os << "\n";
    }
    // Function ops
    funcOp.walk([&](Operation *op) {
      // Skip ops that do not have results
      if (op->getResults().empty())
        return;
      os << "op : ";
      // For control-flow ops, print the op name only.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        os << op->getName();
      else
        op->print(os);
      os << "\n";
      // Print the layout for each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
      funcOps.push_back(funcOp);
    }
    // Collect all GpuFuncOps in the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
        funcOps.push_back(gpuFuncOp);
      }
    }
  }
  // Print the analysis result for each function.
  for (FunctionOpInterface funcOp : funcOps) {
    printFunctionResult(funcOp);
  }
}

namespace {

//===----------------------------------------------------------------------===//
// LayoutAttrAssignment
//===----------------------------------------------------------------------===//

/// This class is responsible for assigning the layout attributes to the ops and
/// their users based on the layout propagation analysis result.
class LayoutAttrAssignment {
public:
  LayoutAttrAssignment(Operation *top,
                       function_ref<LayoutInfo(Value)> getLayout)
      : getAnalysisResult(getLayout), top(top) {}

  LogicalResult run();

private:
  LogicalResult assign(Operation *op);
  void assignToUsers(Value v, xegpu::LayoutAttr layout);
  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
  LogicalResult resolveConflicts();
  // Callable to get the layout of a value based on the layout propagation
  // analysis.
  function_ref<LayoutInfo(Value)> getAnalysisResult;
  Operation *top;
};

} // namespace

/// Helper to assign the layout attribute to the users of the value.
void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
  for (OpOperand &user : v.getUses()) {
    Operation *owner = user.getOwner();
    std::string attrName = xegpu::getLayoutName(user);
    owner->setAttr(attrName, layout);
  }
}

/// Convert the layout assigned to a value to xegpu::LayoutAttr.
xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
  LayoutInfo layout = getAnalysisResult(v);
  if (!layout.isAssigned())
    return {};
  SmallVector<int, 2> laneLayout, laneData;
  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                             layout.getDataAsArrayRef())) {
    laneLayout.push_back(static_cast<int>(layout));
    laneData.push_back(static_cast<int>(data));
  }
  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
}

/// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
/// based on the layout propagation analysis result.
LogicalResult LayoutAttrAssignment::assign(Operation *op) {
  // For function ops, propagate the function argument layout to the users.
  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
    for (BlockArgument arg : func.getArguments()) {
      xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
      if (layoutInfo) {
        assignToUsers(arg, layoutInfo);
      }
    }
    return success();
  }
  // If no results, move on.
  if (op->getNumResults() == 0)
    return success();
  // If all the results are scalars, move on.
  if (llvm::all_of(op->getResultTypes(),
                   [](Type t) { return t.isIntOrIndexOrFloat(); }))
    return success();
  // If the op has more than one result and at least one result is a tensor
  // descriptor, exit. This case is not supported yet.
  // TODO: Support this case.
  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
        return isa<xegpu::TensorDescType>(t);
      })) {
    LLVM_DEBUG(
        DBGS() << op->getName()
               << " op has more than one result and at least one is a tensor "
                  "descriptor. This case is not handled.\n");
    return failure();
  }
  // If the result is a tensor descriptor, attach the layout to the tensor
  // descriptor itself.
  if (auto tensorDescTy =
          dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
    if (!layoutInfo) {
      LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
      return failure();
    }

    // Clone the op, attach the layout to the result tensor descriptor, and
    // remove the original op.
    OpBuilder builder(op);
    Operation *newOp = builder.clone(*op);
    auto newTensorDescTy = xegpu::TensorDescType::get(
        tensorDescTy.getContext(), tensorDescTy.getShape(),
        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
    newOp->getResult(0).setType(newTensorDescTy);
    op->replaceAllUsesWith(newOp->getResults());
    op->erase();
    return success();
  }
  // Otherwise simply attach the layout to the op itself.
  for (auto r : op->getOpResults()) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
    if (layoutInfo) {
      std::string attrName = xegpu::getLayoutName(r);
      op->setAttr(attrName, layoutInfo);
      // Attach the layout attribute to the users of the result.
      assignToUsers(r, layoutInfo);
    }
  }
  return success();
}

/// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
LogicalResult LayoutAttrAssignment::run() {
  auto walkResult = top->walk([&](Operation *op) {
    if (failed(assign(op)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return failure();

  return resolveConflicts();
}

/// TODO: Implement the layout conflict resolution. This must ensure mainly two
/// things:
/// 1) Is a given layout supported by the op? (need to query the target
///    HW info). Otherwise can we achieve this layout using a layout conversion?
/// 2) Do all the operands have the required layout? If not, can it
///    be resolved using a layout conversion?
LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }

namespace {

//===----------------------------------------------------------------------===//
// SIMT Distribution Patterns
//===----------------------------------------------------------------------===//

/// Helper function to get distributed vector type for a source vector type
/// according to the lane_layout. We simply divide each dimension of tensor
/// descriptor shape by corresponding lane_layout dimension. If
/// array_length > 1, that is appended to the front of the distributed shape.
/// NOTE: This is the vector type that will be returned by the
/// gpu.warp_execute_on_lane0 op.
///
/// Examples:
/// | original vector shape | lane_layout | distributed vector shape |
/// |-----------------------|-------------|--------------------------|
/// | 32x16                 | [1, 16]     | 32x1                     |
/// | 32x16                 | [2, 8]      | 16x2                     |
/// | 2x32x16               | [1, 16]     | 2x32x1                   |
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();

  auto laneLayout = layout.getLaneLayout().asArrayRef();
  assert(originalType.getShape().size() >= laneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  SmallVector<int64_t> distributedShape(originalType.getShape());
  // Only distribute the last `laneLayout.size()` dimensions. The remaining
  // dimensions are not distributed.
  unsigned distributionStart = originalType.getRank() - laneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart) {
      continue;
    }
    // Check if the dimension can be distributed evenly.
    if (dim % laneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / laneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}

/// Helper function to resolve types if the distributed type out of
/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
/// Example 1:
///   distributed type: vector<8x1xf32>
///   expected type: vector<8xf32>
///   resolved using,
///   %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
/// Example 2:
///   distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
///   expected type: xegpu.tensor_desc<8x16xf32>
///   resolved using,
///   %0 = unrealized_conversion_cast %1 :
///     xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
///     xegpu.tensor_desc<8x16xf32>
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If orig and expected types are the same, return orig.
  if (orig.getType() == expected)
    return orig;
  // If orig is a vector type, create a shape cast op to reconcile the types.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // If orig is a tensor descriptor type, create an unrealized conversion cast
  // op to reconcile the types.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                              expected, orig);
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}

/// Helper function to filter out the temporary layout attributes attached
/// during the layout assignment process. These are not needed after going to
/// SIMT.
static SmallVector<NamedAttribute>
removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> newAttrs;
  for (NamedAttribute attr : attrs) {
    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
      newAttrs.push_back(attr);
  }
  return newAttrs;
}

/// Helper function to check if the layout is packed. Layout is packed if it is
/// 2D and lane_data[0] != 1 (data packed from col dimension).
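/// For example, lane_data = [2, 1] (the DPAS B VNNI layout) is packed, while
/// lane_data = [1, 2] or [1, 1] is not.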
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
  if (layout == xegpu::LayoutAttr())
    return false;
  DenseI32ArrayAttr laneData = layout.getLaneData();
  if (!laneData || laneData.size() != 2)
    return false;
  return laneData.asArrayRef()[0] != 1;
}

/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
/// contained within a WarpExecuteOnLane0Op.
/// Example:
///
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   ...
///   ...
///   gpu.return %result: vector<8x16xf32>
/// }
/// ```
/// To
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   %laneid = gpu.lane_id : index
///   %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
///     ...
///     ...
///     gpu.yield %result: vector<8x16xf32>
///   }
///   return %0
/// }
/// ```
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function already moved inside a warp_execute_on_lane0, skip.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature.
    auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
    // Create a WarpExecuteOnLane0Op with same arguments and results as the
    // original gpuFuncOp.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = rewriter.create<gpu::LaneIdOp>(
        newGpuFunc.getLoc(), rewriter.getIndexType(),
        /** upperBound = **/ mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
        laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    auto origRetunOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origRetunOp);
    rewriter.create<gpu::YieldOp>(origRetunOp.getLoc(),
                                  origRetunOp.getOperands());
    rewriter.eraseOp(origRetunOp);
    // Move the original function body to the WarpExecuteOnLane0Op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
    rewriter.setInsertionPointAfter(warpOp);
    rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};

/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
/// still contain the original op that will not be used by the yield op (and
/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
/// arguments. Tensor descriptor shape is not distributed because it is a
/// uniform value across all work items within the subgroup. However, the
/// layout information is dropped in the new tensor descriptor type.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///       (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %td = xegpu.create_nd_tdesc %arg0[0, 0]
///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///     vector.yield %td
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
///     ...
///     %dead = xegpu.create_nd_tdesc %arg0[0, 0]
///       : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///     vector.yield %arg0, %dead
///   }
///   %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
///     -> !xegpu.tensor_desc<4x8xf32>
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldValues;
    SmallVector<Type> newYieldTypes;

    for (Value operand : descOp->getOperands()) {
      newYieldValues.push_back(operand);
      newYieldTypes.push_back(operand.getType());
    }
    rewriter.setInsertionPoint(subgroupOp);
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
        /* new yielded types = */ newYieldTypes, newRetIndices);

    SmallVector<Value> newDescOperands;
    for (size_t i : newRetIndices) {
      newDescOperands.push_back(newWarpOp.getResult(i));
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                        // does not contain layout info.
    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};

/// Distribute a store_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
/// through the warp op interface they would be propagated as returned values.
/// Source vector is distributed based on lane layout. Appropriate cast ops are
/// inserted if the distributed types do not match expected xegpu SIMT types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   gpu.warp_execute_on_lane_0(%laneid) -> () {
///     ...
///     xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///       !xegpu.tensor_desc<4x8xf32, #layout0>) {
///     gpu.yield %arg0, %arg1: vector<4x8xf32>,
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
///   %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
///   %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   xegpu.store_nd %0, %1: vector<4xf32>, !xegpu.tensor_desc<4x8xf32>
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /* new yielded values = */
        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
        /* new yielded types = */
        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
        newRetIndices);
    // Create a new store op outside the warp op with the distributed vector
    // type. Tensor descriptor is not distributed.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // For the value operand, there can be a mismatch between the vector type
    // distributed by the warp op and (xegpu-specific) distributed type
    // supported by the store op. Type mismatch must be resolved using an
    // appropriate cast op.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // For the tensor descriptor operand, the layout attribute is dropped after
    // distribution. Types need to be resolved in this case as well.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));

    rewriter.create<xegpu::StoreNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
    rewriter.eraseOp(storeOp);
    return success();
  }
};

/// Distribute a load_nd op feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later). The yield op will
/// bypass the load's arguments. Only the loaded vector is distributed
/// according to lane layout; the tensor descriptor type is not
/// distributed. Appropriate cast ops are inserted if the distributed types do
/// not match expected xegpu SIMT types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>) {
///     ...
///     %ld = xegpu.load_nd %arg0, %arg1
///       : !xegpu.tensor_desc<4x8xf32, #layout0> -> vector<4x8xf32>
///     gpu.yield %ld
///   }
/// ```
/// To
/// ```
///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///       !xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %dead = xegpu.load_nd %arg0
///       : !xegpu.tensor_desc<4x8xf32, #layout0> -> vector<4x8xf32>
///     gpu.yield %dead, %arg0
///   }
///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32>
///   %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32>
/// ```
struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::LoadNd op");

    auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
    xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          loadOp, "the source tensor descriptor lacks layout attribute");

    unsigned operandIdx = operand->getOperandNumber();
    VectorType distributedTypeByWarpOp =
        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /* new yielded values = */ loadOp.getTensorDesc(),
        /* new yielded types = */ tensorDescTy, newRetIndices);

    // Create a new load op outside the warp op with the distributed vector
    // type.
    rewriter.setInsertionPointAfter(newWarpOp);
    FailureOr<VectorType> loadNdDistValueTyOrFailure =
        xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
    xegpu::TensorDescType distributedTensorDescTy =
        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
                                                  // descriptor type does not
                                                  // contain layout info.
    auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
        newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                             distributedTensorDescTy, rewriter),
        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
    // Set the packed attribute if the layout requires it.
    newLoadOp.setPacked(hasPackedLayout(layout));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // There can be a conflict between the vector type distributed by the
    // warp op and (xegpu-specific) distributed type supported by the load
    // op. Resolve these mismatches by inserting a cast.
    Value tyResolvedVal = resolveDistributedTy(
        newLoadOp.getResult(), distributedTypeByWarpOp, rewriter);
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
    return success();
  }
};
1284
1285/// Distribute a dpas op feeding into vector.yield op for the enclosing
1286/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
1287/// The warp op will still contain the original op that will not be used by
1288/// the yield op (and should be cleaned up later). The yield op will
1289/// bypass the dpas's arguments. Appropriate cast ops are inserted if the
1290/// distributed types does not match expected xegpu SIMT types.
1291/// Example:
1292/// ```
1293/// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
1294/// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]>
1295/// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]>
1296/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
1297/// (vector<8x1xf32>) {
1298/// ...
1299/// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> ->
1300/// vector<8x16xf32>
1301/// gpu.yield %dpas
1302/// }
1303/// ```
1304/// To
1305/// ```
1306/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>,
1307/// vector<8x1xf16>, vector<16x1xf16>) {
1308/// ...
1309/// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16>
1310/// -> vector<8x16xf32>
1311/// gpu.yield %dead, %arg0, %arg1
1312/// }
1313/// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16>
1314/// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16>
1315/// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> ->
1316/// vector<8xf32>
1317/// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32>
1318/// ```
struct DpasDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(subgroupOp,
                                         "warp result is not a xegpu::Dpas op");

    auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
    unsigned operandIdx = operand->getOperandNumber();
    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));

    xegpu::LayoutAttr layoutA =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
    xegpu::LayoutAttr layoutB =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
    xegpu::LayoutAttr layoutOut =
        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
    if (!layoutA || !layoutB || !layoutOut)
      return rewriter.notifyMatchFailure(
          dpasOp,
          "the xegpu::Dpas op lacks layout attribute for A, B or output");

    FailureOr<VectorType> distLhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
    FailureOr<VectorType> distRhsTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType());
    FailureOr<VectorType> distResultTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType());
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");

    llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(),
                                               dpasOp.getRhs()};
    llvm::SmallVector<Type, 3> newYieldTypes{
        distLhsTypeByWarpOpOrFailure.value(),
        distRhsTypeByWarpOpOrFailure.value()};
    // Dpas acc operand is optional.
    if (dpasOp.getAcc()) {
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
    }
    // Create a new warp op without the dpas.
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);

    FailureOr<VectorType> expectedDistLhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
    FailureOr<VectorType> expectedDistRhsTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
    FailureOr<VectorType> expectedDistResultTyOrFailure =
        xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;

    // Resolve the distributed types with the original types.
    newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
    newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
    VectorType distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);

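    // Map each newly yielded warp result to the SIMT operand type the dpas
    // op expects, inserting a resolving cast wherever the two types differ.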
    for (unsigned i = 0; i < newRetIndices.size(); i++) {
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
    }
    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    // Resolve the output type.
    newDpasOp = resolveDistributedTy(
        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
    return success();
  }
};

/// Sink an update_nd_offset op feeding into yield op of an enclosing
/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
/// original op that will not be used by the yield op (and should be cleaned
/// up later). The yield op will bypass the updateOp's arguments. The tensor
/// descriptor type is not distributed. Appropriate cast ops are inserted if
/// the distributed types do not match the expected xegpu SIMT types.
/// Example:
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///     ...
///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///     gpu.yield %update
///   }
///   ...
/// ```
/// To
/// ```
///   %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (
///     !xegpu.tensor_desc<4x8xf32, #layout0>,
///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
///     ...
///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///     gpu.yield %dead, %arg0, %c32, %c16
///   }
///   %0 = builtin.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
///     !xegpu.tensor_desc<4x8xf32>
///   ...
/// ```
struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
    unsigned operandIdx = operand->getOperandNumber();
    // The new update op does not have a layout attribute.
    xegpu::TensorDescType newTensorDescTy =
        updateOp.getTensorDescType().dropLayouts();

    SmallVector<Value, 3> newYieldValues;
    SmallVector<Type, 3> newYieldTypes;
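    // Yield all operands of the update op from the warp op. The tensor
    // descriptor operand is yielded with its layout dropped; the remaining
    // (offset) operands keep their original types.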
    for (Value operand : updateOp->getOperands()) {
      newYieldValues.push_back(operand);
      if (isa<xegpu::TensorDescType>(operand.getType())) {
        newYieldTypes.push_back(newTensorDescTy);
      } else {
        newYieldTypes.push_back(operand.getType());
      }
    }
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newUpdateOperands;
    for (size_t i : newRetIndices) {
      // For the tensor descriptor operand, the layout attribute is dropped
      // after distribution. Types need to be resolved in this case.
      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
        newUpdateOperands.push_back(resolveDistributedTy(
            newWarpOp.getResult(i), newTensorDescTy, rewriter));
      } else {
        newUpdateOperands.push_back(newWarpOp.getResult(i));
      }
    }
    // Create a new update op outside the warp op.
    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
    return success();
  }
};

/// Distribute a prefetch_nd op at the end of enclosing
/// `gpu.warp_execute_on_lane_0`. If the prefetch's arguments are passed
/// through the warp op interface, they are propagated as returned values of
/// the warp op. The tensor descriptor shape is not distributed because it is
/// a uniform value across all work items within the subgroup. Appropriate
/// cast ops are inserted if the distributed types do not match the expected
/// xegpu SIMT types.
///
/// Example:
///
/// ```
///   #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
///   gpu.warp_execute_on_lane_0(%laneid) -> () {
///     ...
///     xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
/// ```
/// To
/// ```
///   %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
///     !xegpu.tensor_desc<4x8xf32, #layout0>) {
///     gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
///   }
///   %1 = builtin.unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
///     #layout0> -> !xegpu.tensor_desc<4x8xf32>
///   xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
/// ```
struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
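    // The prefetch op produces no results, so it cannot be reached through
    // the warp op's yielded values. Match it only when it is the last
    // operation before the terminator of the warp region.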
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
    if (!prefetchOp)
      return failure();
    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");

    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
    // Create a new prefetch op outside the warp op with the updated tensor
    // descriptor type. The source tensor descriptor requires type resolution.
    xegpu::TensorDescType newTensorDescTy =
        prefetchOp.getTensorDescType().dropLayouts();
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
    rewriter.create<xegpu::PrefetchNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
    rewriter.eraseOp(prefetchOp);
    return success();
  }
};

} // namespace

namespace {
struct XeGPUSubgroupDistributePass final
    : public xegpu::impl::XeGPUSubgroupDistributeBase<
          XeGPUSubgroupDistributePass> {
  XeGPUSubgroupDistributePass() = default;
  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
      default;
  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
      : XeGPUSubgroupDistributeBase(options) {}
  void runOnOperation() override;
};
} // namespace

void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
               UpdateNdOffsetDistribution>(patterns.getContext());
}

void XeGPUSubgroupDistributePass::runOnOperation() {
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // Print the analysis result and exit (for testing purposes).
  if (printOnly) {
    auto &os = llvm::outs();
    analysis.printAnalysisResult(os);
    return;
  }
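  // Helper for querying the propagated layout of any value; consumed by the
  // layout assignment step below.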
  auto getPropagatedLayout = [&](Value val) {
    return analysis.getLayoutInfo(val);
  };

  // Assign xegpu::LayoutAttr to all ops and their users based on the layout
  // propagation analysis result.
  LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout);
  if (failed(layoutAssignment.run())) {
    signalPassFailure();
    return;
  }

  // Move all operations of a GPU function inside the
  // gpu.warp_execute_on_lane_0 operation.
  {
    RewritePatternSet patterns(&getContext());
    patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
      return;
    }
    // At this point, we have moved the entire function body inside the warpOp.
    // Now move any scalar uniform code outside of the warpOp (like GPU index
    // ops, scalar constants, etc.). This will simplify the later lowering and
    // avoid custom patterns for these ops.
    getOperation()->walk([&](Operation *op) {
      if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
        vector::moveScalarUniformCode(warpOp);
      }
    });
  }
  // Finally, do the SIMD to SIMT distribution.
  RewritePatternSet patterns(&getContext());
  xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // TODO: distributionFn and shuffleFn are not used at this point.
  auto distributionFn = [](Value val) {
    VectorType vecType = dyn_cast<VectorType>(val.getType());
    int64_t vecRank = vecType ? vecType.getRank() : 0;
    OpBuilder builder(val.getContext());
    if (vecRank == 0)
      return AffineMap::get(val.getContext());
    return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
  };
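  // The shuffle function is required by the vector distribution API but is
  // not exercised by these patterns (see the TODO above), so it simply
  // returns a null Value.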
  auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                      int64_t warpSz) { return Value(); };
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionFn, shuffleFn);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
    signalPassFailure();
    return;
  }
}