1//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
11#include "mlir/Analysis/DataFlow/Utils.h"
12#include "mlir/Analysis/DataFlowFramework.h"
13#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
17#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
18#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
19#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
20#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinAttributes.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/Value.h"
26#include "mlir/IR/Visitors.h"
27#include "mlir/Interfaces/ControlFlowInterfaces.h"
28#include "mlir/Interfaces/FunctionInterfaces.h"
29#include "mlir/Support/LLVM.h"
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Casting.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/InterleavedRange.h"
37#include "llvm/Support/LogicalResult.h"
38#include "llvm/Support/raw_ostream.h"
39
40namespace mlir {
41namespace xegpu {
42#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
43#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
44} // namespace xegpu
45} // namespace mlir
46
47#define DEBUG_TYPE "xegpu-propagate-layout"
48#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
49
50using namespace mlir;
51using namespace mlir::dataflow;
52
53namespace {
54
55//===----------------------------------------------------------------------===//
56// Layout
57//===----------------------------------------------------------------------===//
58
/// Helper class to store the ND layout of lanes within a subgroup and data
/// owned by each lane.
struct Layout {
  // Per-dimension sizes; inline capacity of 3 covers the common 1D-3D cases.
  SmallVector<int64_t, 3> layout;
  Layout() = default;
  Layout(std::initializer_list<int64_t> list) : layout(list) {}
  /// Prints the layout as a bracketed, comma-separated list, e.g. "[1, 16]".
  void print(llvm::raw_ostream &os) const;
  /// Number of dimensions. A size of 0 means "no layout".
  size_t size() const { return layout.size(); }
};
68
69void Layout::print(llvm::raw_ostream &os) const {
70 os << llvm::interleaved_array(R: layout);
71}
72
/// LaneLayout represents the logical layout of lanes within a subgroup when it
/// accesses some value. LaneData represents the logical layout of data owned
/// by each work item. Both are plain ND size lists; the aliases exist only to
/// make signatures self-documenting.
using LaneLayout = Layout;
using LaneData = Layout;
78
79//===----------------------------------------------------------------------===//
80// LayoutInfo
81//===----------------------------------------------------------------------===//
82
83/// Helper class for tracking the analysis state of an mlir value. For layout
84/// propagation, the analysis state is simply the lane_layout and lane_data of
85/// each value. Purpose of this analysis to propagate some unique layout for
86/// each value in the program starting from a set of anchor operations (like
87/// DPAS, StoreNd, etc.).
88///
89/// Given this, LayoutInfo satisifies the following properties:
90/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
91/// assigned`.
92/// 2) Two LayoutInfo values are equal if they are both assigned or
93/// both not assigned. The concrete value of assigned state does not matter.
94/// 3) The meet operator works as follows:
95/// - If current state is assigned, return the current state. (already
96/// a unique layout is assigned. don't change it)
97/// - Otherwise, return the other state.
98
99struct LayoutInfo {
100private:
101 LaneLayout laneLayout;
102 LaneData laneData;
103 xegpu::LayoutAttr layoutAttr;
104
105public:
106 LayoutInfo() = default;
107 LayoutInfo(const LaneLayout &layout, const LaneData &data)
108 : laneLayout(layout), laneData(data) {}
109
110 // Two lattice values are equal if they have `some` layout. The actual
111 // content of the layout does not matter.
112 bool operator==(const LayoutInfo &other) const {
113 return this->isAssigned() == other.isAssigned();
114 }
115
116 static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
117
118 static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
119
120 void print(raw_ostream &os) const;
121
122 bool isAssigned() const {
123 return laneLayout.size() > 0 && laneData.size() > 0;
124 }
125
126 LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
127
128 const LaneLayout &getLayout() const { return laneLayout; }
129 const LaneData &getData() const { return laneData; }
130 ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
131 ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
132};
133
134void LayoutInfo::print(raw_ostream &os) const {
135 if (isAssigned()) {
136 os << "lane_layout: ";
137 laneLayout.print(os);
138 os << ", lane_data: ";
139 laneData.print(os);
140 } else {
141 os << "Not assigned.";
142 }
143}
144
145LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
146 if (!lhs.isAssigned())
147 return rhs;
148 return lhs;
149}
150
/// Since this is a backward analysis, join method is not used. Reaching this
/// function indicates a bug in the analysis driver.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}
155
156/// Get the transposed layout according to the given permutation.
157LayoutInfo
158LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
159 if (!isAssigned())
160 return {};
161 LaneLayout newLayout;
162 LaneData newData;
163 for (int64_t idx : permutation) {
164 newLayout.layout.push_back(Elt: laneLayout.layout[idx]);
165 newData.layout.push_back(Elt: laneData.layout[idx]);
166 }
167 return LayoutInfo(newLayout, newData);
168}
169
170//===----------------------------------------------------------------------===//
171// LayoutInfoLattice
172//===----------------------------------------------------------------------===//
173
/// Lattice holding the LayoutInfo for each value. The TypeID macro is
/// required so the dataflow framework can identify this state kind.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
  using Lattice::Lattice;
};
179
180/// Helper Functions to get default layouts. A `default layout` is a layout that
181/// is assigned to a value when the layout is not fixed by some anchor operation
182/// (like DPAS).
183
184/// Helper Function to get the default layout for uniform values like constants.
185/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
186/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
187static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
188 assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
189 if (rank == 1)
190 return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
191 LaneData({1}));
192 return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
193 LaneData({1, 1}));
194}
195
196/// Helper to get the default layout for a vector type.
197static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
198 // Expecting a 1D or 2D vector.
199 assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
200 "Expected 1D or 2D vector.");
201 // Expecting int or float element type.
202 assert(vectorTy.getElementType().isIntOrFloat() &&
203 "Expected int or float element type.");
204 // If the rank is 1, then return default layout for 1D vector.
205 if (vectorTy.getRank() == 1)
206 return getDefaultSIMTLayoutInfo(rank: 1);
207 // Packing factor is determined by the element type bitwidth.
208 int packingFactor = 1;
209 unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
210 if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
211 packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
212 return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
213 LaneData({1, packingFactor}));
214}
215
/// Helper to get the default layout for a tensor descriptor type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
  // Expecting a 1D or 2D tensor descriptor.
  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
         "Expected 1D or 2D TensorDesc.");
  // Expecting int or float element type.
  assert(tdescTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return default layout for 1D vector.
  if (tdescTy.getRank() == 1)
    return getDefaultSIMTLayoutInfo(1);
  // Packing factor is determined by the element type bitwidth.
  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();

  // Scattered descriptors place lanes along the first dimension
  // ([subgroupSize, 1]) and pack sub-word elements (relative to
  // packedSizeInBitsForGatherScatter) along the second.
  if (tdescTy.isScattered()) {
    int packingFactor =
        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
            : 1;
    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
                      LaneData({1, packingFactor}));
  }

  // Blocked (non-scattered) descriptors place lanes along the second
  // dimension ([1, subgroupSize]), packing relative to
  // packedSizeInBitsForDefault.
  int packingFactor =
      (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
          ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
          : 1;
  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                    LaneData({1, packingFactor}));
}
246
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
/// `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
/// `packedSizeInBitsForDpasB`
/// `operandNum` identifies the DPAS operand: 0 = A, 1 = B, 2 = accumulator C.
static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
                                                  unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
  // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
  // must have the VNNI format.
  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
                             xegpu::targetinfo::packedSizeInBitsForDpasB) {
    LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
                       elementTy.getIntOrFloatBitWidth(),
                   1});
    return LayoutInfo(layout, data);
  }
  // Otherwise (A, C, or wide-element B), return the default layout for the
  // vector type.
  return getDefaultSIMTLayoutInfo(vectorTy);
}
271
272//===----------------------------------------------------------------------===//
273// LayoutInfoPropagation
274//===----------------------------------------------------------------------===//
275
276/// Backward data flow analysis to propagate the lane_layout and lane_data of
277/// each value in the program. Currently, the layouts for operands DPAS,
278/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
279/// this analysis is to propagate those known layouts to all their producers and
280/// (other) consumers.
281class LayoutInfoPropagation
282 : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
283private:
284 void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
285 ArrayRef<const LayoutInfoLattice *> results);
286
287 void visitStoreNdOp(xegpu::StoreNdOp store,
288 ArrayRef<LayoutInfoLattice *> operands,
289 ArrayRef<const LayoutInfoLattice *> results);
290
291 void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
292 ArrayRef<LayoutInfoLattice *> operands,
293 ArrayRef<const LayoutInfoLattice *> results);
294
295 void visitLoadNdOp(xegpu::LoadNdOp load,
296 ArrayRef<LayoutInfoLattice *> operands,
297 ArrayRef<const LayoutInfoLattice *> results);
298
299 void visitLoadGatherOp(xegpu::LoadGatherOp load,
300 ArrayRef<LayoutInfoLattice *> operands,
301 ArrayRef<const LayoutInfoLattice *> results);
302
303 void visitTransposeOp(vector::TransposeOp transpose,
304 ArrayRef<LayoutInfoLattice *> operands,
305 ArrayRef<const LayoutInfoLattice *> results);
306
307 void visitVectorBitcastOp(vector::BitCastOp bitcast,
308 ArrayRef<LayoutInfoLattice *> operands,
309 ArrayRef<const LayoutInfoLattice *> results);
310
311 void visitCreateDescOp(xegpu::CreateDescOp createDesc,
312 ArrayRef<LayoutInfoLattice *> operands,
313 ArrayRef<const LayoutInfoLattice *> results);
314
315 void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
316 ArrayRef<LayoutInfoLattice *> operands,
317 ArrayRef<const LayoutInfoLattice *> results);
318
319 void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
320 ArrayRef<LayoutInfoLattice *> operands,
321 ArrayRef<const LayoutInfoLattice *> results);
322
323 void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
324 ArrayRef<LayoutInfoLattice *> operands,
325 ArrayRef<const LayoutInfoLattice *> results);
326
327public:
328 LayoutInfoPropagation(DataFlowSolver &solver,
329 SymbolTableCollection &symbolTable)
330 : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
331 using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
332
333 LogicalResult
334 visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
335 ArrayRef<const LayoutInfoLattice *> results) override;
336
337 void visitBranchOperand(OpOperand &operand) override {};
338
339 void visitCallOperand(OpOperand &operand) override {};
340
341 void visitExternalCall(CallOpInterface call,
342 ArrayRef<LayoutInfoLattice *> operands,
343 ArrayRef<const LayoutInfoLattice *> results) override {
344 };
345
346 void setToExitState(LayoutInfoLattice *lattice) override {
347 (void)lattice->meet(rhs: LayoutInfo());
348 }
349};
350} // namespace
351
/// Central dispatch of the backward transfer function. Ops with dedicated
/// visitors get op-specific handling; every other op generically copies each
/// assigned result layout onto its vector/tensor-descriptor operands.
LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      // All other ops.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *resultInfo : results) {
          // Unassigned results carry no information to propagate.
          if (!resultInfo->getValue().isAssigned())
            continue;
          for (auto [operandInfo, operand] :
               llvm::zip(operands, op->getOpOperands())) {
            // If the operand type is not a vector or tensor descriptor, skip
            // it.
            if (!isa<xegpu::TensorDescType, VectorType>(
                    operand.get().getType()))
              continue;
            // Propagate the result layout to the operand.
            meet(operandInfo, *resultInfo);
          }
        }
      });

  // Propagation itself never fails.
  return success();
}
406
407void LayoutInfoPropagation::visitPrefetchNdOp(
408 xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
409 ArrayRef<const LayoutInfoLattice *> results) {
410 // Here we assign the default layout to the tensor descriptor operand of
411 // prefetch.
412 auto tdescTy = prefetch.getTensorDescType();
413 auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
414 // Propagate the layout to the source tensor descriptor.
415 propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: prefetchLayout));
416}
417
/// Set layouts for vector.multi_reduction: the source gets the default 2D
/// layout and the accumulator mirrors the (1D) result layout.
void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
  if (!resultTy || resultTy.getRank() != 1) {
    reduction.emitWarning("Expecting output type to be 1D vector.");
    return;
  }
  // Given that the result is 1D, the layout of the operand should be 2D with
  // default layout.
  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // Accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
439
440/// Propagate the layout of the result tensor to the source tensor descriptor in
441/// UpdateNdOffsetOp.
442void LayoutInfoPropagation::visitUpdateNdOffsetOp(
443 xegpu::UpdateNdOffsetOp updateNdOffset,
444 ArrayRef<LayoutInfoLattice *> operands,
445 ArrayRef<const LayoutInfoLattice *> results) {
446 // The layout of the result must be present.
447 LayoutInfo resultLayout = results[0]->getValue();
448 if (!resultLayout.isAssigned())
449 return;
450 // Propagate the layout to the source operand.
451 propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: resultLayout));
452}
453
/// Set the layouts for DPAS A, B, and C operands. DPAS is an anchor op: the
/// operand layouts are fixed by the hardware requirements, not by the result.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  // A (operand #0) and B (operand #1) get their anchor-imposed layouts.
  propagateIfChanged(
      operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(
      operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
  // The accumulator C (operand #2) is optional.
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(
        operands[2],
        operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
  }
}
471
472/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
473void LayoutInfoPropagation::visitStoreNdOp(
474 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
475 ArrayRef<const LayoutInfoLattice *> results) {
476 LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(vectorTy: store.getValueType());
477 // Both operands should have the same layout
478 for (LayoutInfoLattice *operand : operands)
479 propagateIfChanged(state: operand, changed: operand->meet(rhs: storeLayout));
480}
481
/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has the transpose effect. However, at the stage of this analysis
  // this effect is not expected and should be abstracted away. Emit a
  // warning. (The descriptor still gets the correspondingly transposed
  // layout so propagation stays consistent.)
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
503
504/// For vector::TransposeOp, the layout of the result is transposed and
505/// propagated to the operand.
506void LayoutInfoPropagation::visitTransposeOp(
507 vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
508 ArrayRef<const LayoutInfoLattice *> results) {
509 // Need the layout of transpose result to propagate to the operands.
510 LayoutInfo resultLayout = results[0]->getValue();
511 if (!resultLayout.isAssigned())
512 return;
513 LayoutInfo newLayout =
514 resultLayout.getTransposedLayout(permutation: transpose.getPermutation());
515 // Propagate the new layout to the vector operand.
516 propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: newLayout));
517}
518
/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit width of the source and result types. Currently only
/// same-width bitcasts are supported, so the layout passes through unchanged.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();

  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
  // a warning and return.
  if (inElemTyBitWidth != outElemTyBitWidth) {
    bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
                        "layout propagation stage.");
    return;
  }

  // Same-width bitcast: the layout carries over to the source unchanged.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
543
544/// Propagate the layout of the result to the tensor descriptor and mask
545/// operands in LoadGatherOp.
546void LayoutInfoPropagation::visitLoadGatherOp(
547 xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
548 ArrayRef<const LayoutInfoLattice *> results) {
549 // The layout is strictly determined by the tensor descriptor type.
550 LayoutInfo layout = getDefaultSIMTLayoutInfo(tdescTy: load.getTensorDescType());
551
552 // Mask operand should have 1D default layout.
553 LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(rank: 1);
554
555 // Propagate the new layout to the tensor descriptor operand.
556 propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: layout));
557 // Propagate the new layout to the mask operand.
558 propagateIfChanged(state: operands[1], changed: operands[1]->meet(rhs: maskLayout));
559}
560
561/// Propagate the layout of the descriptor to the vector offset operand in
562/// CreateDescOp.
563void LayoutInfoPropagation::visitCreateDescOp(
564 xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
565 ArrayRef<const LayoutInfoLattice *> results) {
566 LayoutInfo descLayout = results[0]->getValue();
567 // Need the layout of the descriptor to propagate to the operands.
568 if (!descLayout.isAssigned())
569 return;
570 // For offset operand propagate 1D default layout.
571 LayoutInfo layout = getDefaultSIMTLayoutInfo(rank: 1);
572 propagateIfChanged(state: operands[1], changed: operands[1]->meet(rhs: layout));
573}
574
/// Set the layout for the value, tensor descriptor, and mask operands in the
/// StoreScatterOp. Like StoreNd, this is an anchor op: the layout comes from
/// the descriptor type, not from any result.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Currently, for 2D StoreScatterOp we expect that the height dimension of
  // the tensor descriptor is equal to the subgroup size. This is ensured by
  // the op verifier.
  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
  if (tdescShape.size() > 1)
    assert(
        tdescShape[0] == xegpu::targetinfo::subgroupSize &&
        "Expected the first dimension of 2D tensor descriptor to be equal to "
        "subgroup size.");

  LayoutInfo layout =
      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());

  // Propagate the value layout.
  propagateIfChanged(operands[0], operands[0]->meet(layout));
  // Propagate the tensor descriptor layout.
  propagateIfChanged(operands[1], operands[1]->meet(layout));
  // Use default 1D layout for mask operand.
  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
601
602namespace {
603//===----------------------------------------------------------------------===//
604// RunLayoutInfoPropagation
605//===----------------------------------------------------------------------===//
606
/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  /// Loads the baseline analyses plus LayoutInfoPropagation into the solver
  /// and eagerly runs them over `op` at construction time. The run's result
  /// is intentionally discarded; queries fall back to "not assigned".
  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    loadBaselineAnalyses(solver);
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  /// Returns the computed layout for `val`, or an unassigned LayoutInfo when
  /// the solver has no state for it.
  LayoutInfo getLayoutInfo(Value val);

  /// Dumps the per-function analysis results for debugging.
  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
627} // namespace
628
629LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
630 auto *state = solver.lookupState<LayoutInfoLattice>(anchor: val);
631 if (!state)
632 return {};
633 return state->getValue();
634}
635
// Print the analysis result for debugging purposes. Emits, per function, the
// layout of each block argument and of each op result.
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Function ops
    funcOp.walk([&](Operation *op) {
      // Skip ops that do not have results
      if (op->getResults().empty())
        return;
      os << "op    : ";
      // For control-flow ops, print the op name only (printing the full op
      // would recursively print the whole nested region).
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        os << op->getName();
      else
        op->print(os);
      os << "\n";
      // Print the layout for each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    // Top-level functions of the module.
    for (auto funcOp : modOp.getOps<FunctionOpInterface>())
      funcOps.push_back(funcOp);

    // Collect all GpuFuncOps in the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
        funcOps.push_back(gpuFuncOp);
    }
  }
  // Print the analysis result for each function.
  for (FunctionOpInterface funcOp : funcOps)
    printFunctionResult(funcOp);
}
685
686using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
687/// Update an operation with the layout of its results. If the result type is a
688/// vector type, a temporary layout attribute is added to the operation. If the
689/// result type is a tensor descriptor type, the type is updated with the layout
690/// attribute. The users of the result are also updated with the layout
691/// attribute.
692static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
693 GetLayoutFnTy getLayoutOfValue) {
694 // Region ops (like scf.for) are already handled by the updateControlFlowOps.
695 if (mlir::isa<mlir::RegionBranchOpInterface>(Val: op))
696 return success();
697
698 // Iterate over all the results.
699 for (OpResult result : op->getResults()) {
700 Type resultType = result.getType();
701 // Layouts are needed only for vector and tensor descriptor types.
702 if (!isa<VectorType, xegpu::TensorDescType>(Val: resultType))
703 continue;
704 // If the result has no layout but has users, emit a warning and continue.
705 xegpu::LayoutAttr layout = getLayoutOfValue(result);
706 if (!layout && result.getNumUses() > 0) {
707 op->emitWarning(message: "op has users but no layout assigned for its result");
708 continue;
709 }
710 // If the result is a tensor descriptor type, update the tensor desc type
711 // with layout.
712 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(Val&: resultType)) {
713 auto typeWithLayout = xegpu::TensorDescType::get(
714 context: tensorDescTy.getContext(), shape: tensorDescTy.getShape(),
715 elementType: tensorDescTy.getElementType(), encoding: tensorDescTy.getEncoding(), layout);
716 result.setType(typeWithLayout);
717 continue;
718 }
719 // If the result is a vector type, add a temporary layout attribute to the
720 // op.
721 xegpu::setLayoutAttr(operandOrResult: result, layout);
722 }
723 return success();
724}
725
726/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments, their
728/// types must be updated. Also region op can have results that may not have any
729/// users (e.g. A and B tiles). They are not assigned a layout by layout
730/// analysis because they have no users. However inside the region op
731/// corresponding block arguments for these results do have layouts. Therefore,
732/// in this case we still need to update the result types with the layout
/// attribute. This function updates the internal block arguments and
734/// the result types of the region op with the assigned layouts.
735/// clang-format off
736/// Example: scf.for ... iter_args(...) -> (out types) {
737/// ^bb0(block types):
738/// ...
739/// scf.yield ... : (yield types)
740/// }
741/// clang-format on
742/// In this example, at scf.yield, control-flow can transfer to two successor
743/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
744/// itself (yield the results). So we update both the block arguments of the
745/// successor region (i.e. block types) and the result types of the scf.for op
746/// (i.e. out types). Note that yield types are updated by respective producers
747/// inside bb0.
748static LogicalResult
749updateControlFlowOps(mlir::OpBuilder &builder,
750 mlir::RegionBranchTerminatorOpInterface terminator,
751 GetLayoutFnTy getLayoutOfValue) {
752 // Only process if the terminator is inside a region branch op.
753 if (!mlir::isa<mlir::RegionBranchOpInterface>(Val: terminator->getParentOp()))
754 return success();
755
756 llvm::SmallVector<mlir::RegionSuccessor> successors;
757 llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
758 nullptr);
759 terminator.getSuccessorRegions(operands, regions&: successors);
760
761 for (mlir::RegionSuccessor &successor : successors) {
762 mlir::OperandRange successorOperands =
763 terminator.getSuccessorOperands(point: successor);
764 mlir::ValueRange successorInputs = successor.getSuccessorInputs();
765 for (auto [successorOperand, successorInput] :
766 llvm::zip(t&: successorOperands, u&: successorInputs)) {
767 Type inputType = successorInput.getType();
768 // We only need to operate on tensor descriptor or vector types.
769 if (!isa<xegpu::TensorDescType, VectorType>(Val: inputType))
770 continue;
771 xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
772 xegpu::LayoutAttr successorOperandLayout =
773 getLayoutOfValue(successorOperand);
774
775 // If either of the layouts is not assigned, we cannot proceed.
776 if (!successorOperandLayout) {
777 LLVM_DEBUG(
778 DBGS()
779 << "No layout assigned for forwarded operand in branch terminator: "
780 << successorOperand << "\n");
781 return failure();
782 }
783 // We expect the layouts to match.
784 if (successorInputLayout &&
785 successorInputLayout != successorOperandLayout) {
786 LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
787 "operand forwarded as the argument: "
788 << successorInputLayout << " vs "
789 << successorOperandLayout << "\n");
790 return failure();
791 }
792 // Get tensor descriptor type with the layout.
793 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(Val&: inputType)) {
794 auto newTdescTy = xegpu::TensorDescType::get(
795 context: tdescTy.getContext(), shape: tdescTy.getShape(), elementType: tdescTy.getElementType(),
796 encoding: tdescTy.getEncoding(), layout: successorOperandLayout);
797 successorInput.setType(newTdescTy);
798 continue;
799 }
800 // If the type is a vector type and this region argument is an OpResult,
801 // set the layout attribute on the OpResult.
802 if (auto result = dyn_cast<OpResult>(Val&: successorInput))
803 xegpu::setLayoutAttr(operandOrResult: result, layout: successorOperandLayout);
804 }
805 }
806 return success();
807}
808
809/// Update the function arguments and results with the layouts.
810static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
811 mlir::FunctionOpInterface funcOp,
812 GetLayoutFnTy getLayoutOfValue) {
813 SmallVector<Type> newArgTypes;
814 // Update the function arguments.
815 for (BlockArgument arg : funcOp.getArguments()) {
816 Type argType = arg.getType();
817 newArgTypes.push_back(Elt: argType);
818 if (!isa<VectorType, xegpu::TensorDescType>(Val: argType))
819 continue;
820 xegpu::LayoutAttr layout = getLayoutOfValue(arg);
821 if (!layout) {
822 LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
823 << " but got none.\n");
824 return failure();
825 }
826 if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(Val&: argType)) {
827 auto newTdescTy = xegpu::TensorDescType::get(
828 context: tensorDescTy.getContext(), shape: tensorDescTy.getShape(),
829 elementType: tensorDescTy.getElementType(), encoding: tensorDescTy.getEncoding(), layout);
830 arg.setType(newTdescTy);
831 newArgTypes.back() = newTdescTy;
832 }
833 }
834 // Update the function type with the new argument types.
835 // NOTE: We assume that function results are not expected to have layouts.
836 funcOp.setType(FunctionType::get(context: funcOp.getContext(), inputs: newArgTypes,
837 results: funcOp.getResultTypes()));
838 return success();
839}
840
841namespace {
842struct XeGPUPropagateLayoutPass final
843 : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
844 XeGPUPropagateLayoutPass() = default;
845 XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
846 XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
847 : XeGPUPropagateLayoutBase(options) {}
848 void runOnOperation() override;
849};
850
851} // namespace
852
853void XeGPUPropagateLayoutPass::runOnOperation() {
854 auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
855 // Print the analysis result and exit. (for debugging purposes)
856 if (printOnly) {
857 auto &os = llvm::outs();
858 analysis.printAnalysisResult(os);
859 return;
860 }
861 // Helper to convert LayoutInfo to xegpu::LayoutAttr.
862 auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
863 LayoutInfo layout = analysis.getLayoutInfo(val);
864 if (!layout.isAssigned())
865 return {};
866 return xegpu::LayoutAttr::get(
867 context: val.getContext(), lane_layout: llvm::to_vector_of<int>(Range: layout.getLayoutAsArrayRef()),
868 lane_data: llvm::to_vector_of<int>(Range: layout.getDataAsArrayRef()));
869 };
870
871 mlir::OpBuilder builder(&getContext());
872 Operation *op = getOperation();
873 auto walkResult = op->walk(callback: [&](mlir::Block *block) -> WalkResult {
874 for (mlir::Operation &op : llvm::reverse(C&: block->getOperations())) {
875 LogicalResult r = success();
876 TypeSwitch<Operation *>(&op)
877 .Case<mlir::RegionBranchTerminatorOpInterface>(
878 caseFn: [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
879 r = updateControlFlowOps(builder, terminator: branchTermOp,
880 getLayoutOfValue: getXeGPULayoutForValue);
881 })
882 .Case<mlir::FunctionOpInterface>(
883 caseFn: [&](mlir::FunctionOpInterface funcOp) {
884 r = updateFunctionOpInterface(builder, funcOp,
885 getLayoutOfValue: getXeGPULayoutForValue);
886 })
887 .Default(defaultFn: [&](Operation *op) {
888 r = updateOp(builder, op, getLayoutOfValue: getXeGPULayoutForValue);
889 });
890 if (failed(Result: r)) {
891 op.emitError(message: "Failed to update operation with the layout.");
892 return WalkResult::interrupt();
893 }
894 }
895 return WalkResult::advance();
896 });
897 if (walkResult.wasInterrupted()) {
898 signalPassFailure();
899 return;
900 }
901}
902

// Source file: mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp