//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-subgroup-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;
using namespace mlir::dataflow;

/// HW dependent constants.
/// TODO: These constants should be queried from the target information.
constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
/// If DPAS A or B operands have low precision element types, they must be
/// packed according to the following sizes.
constexpr unsigned packedSizeInBitsForDefault =
    16; // Minimum packing size per register for DPAS A.
constexpr unsigned packedSizeInBitsForDpasB =
    32; // Minimum packing size per register for DPAS B.
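
// For illustration (implied by the constants above, not an exhaustive
// hardware description): with 16-bit elements such as f16, the DPAS B operand
// must be packed so that each lane owns 32 bits, i.e. 2 consecutive values
// along the row dimension (VNNI layout), while the A operand already meets
// the 16-bit minimum and needs no extra packing.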

namespace {

//===----------------------------------------------------------------------===//
// Layout
//===----------------------------------------------------------------------===//

/// Helper class to store the ND layout of lanes within a subgroup and data
/// owned by each lane.
struct Layout {
  SmallVector<int64_t, 3> layout;
  Layout() = default;
  Layout(std::initializer_list<int64_t> list) : layout(list) {}
  void print(llvm::raw_ostream &os) const;
  size_t size() const { return layout.size(); }
  int64_t operator[](size_t idx) const;
};

void Layout::print(llvm::raw_ostream &os) const {
  os << llvm::interleaved_array(layout);
}

int64_t Layout::operator[](size_t idx) const {
  assert(idx < layout.size() && "Index out of bounds.");
  return layout[idx];
}

/// LaneLayout represents the logical layout of lanes within a subgroup when it
/// accesses some value. LaneData represents the logical layout of data owned
/// by each work item.
using LaneLayout = Layout;
using LaneData = Layout;

//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//

/// Helper class for tracking the analysis state of an mlir value. For layout
/// propagation, the analysis state is simply the lane_layout and lane_data of
/// each value. The purpose of this analysis is to propagate some unique layout
/// for each value in the program starting from a set of anchor operations
/// (like DPAS, StoreNd, etc.).
///
/// Given this, LayoutInfo satisfies the following properties:
///  1) A LayoutInfo value can be in one of two states - `assigned` or `not
///  assigned`.
///  2) Two LayoutInfo values are equal if they are both assigned or
///  both not assigned. The concrete value of the assigned state does not
///  matter.
///  3) The meet operator works as follows:
///     - If the current state is assigned, return the current state (a unique
///     layout is already assigned; don't change it).
///     - Otherwise, return the other state.
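///
/// An illustrative sketch of the resulting meet semantics:
///   meet(not assigned, X) == X   (adopt the incoming layout)
///   meet(X, Y)            == X   (an assigned layout is never overwritten)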

struct LayoutInfo {
private:
  LaneLayout laneLayout;
  LaneData laneData;

public:
  LayoutInfo() = default;
  LayoutInfo(const LaneLayout &layout, const LaneData &data)
      : laneLayout(layout), laneData(data) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  bool isAssigned() const {
    return laneLayout.size() > 0 && laneData.size() > 0;
  }

  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;

  const LaneLayout &getLayout() const { return laneLayout; }
  const LaneData &getData() const { return laneData; }
  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
};

void LayoutInfo::print(raw_ostream &os) const {
  if (isAssigned()) {
    os << "lane_layout: ";
    laneLayout.print(os);
    os << ", lane_data: ";
    laneData.print(os);
  } else {
    os << "Not assigned.";
  }
}

LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  if (!lhs.isAssigned())
    return rhs;
  return lhs;
}

/// Since this is a backward analysis, join method is not used.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
  llvm_unreachable("Join should not be triggered by layout propagation.");
}

/// Get the transposed layout according to the given permutation.
LayoutInfo
LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
  if (!isAssigned())
    return {};
  LaneLayout newLayout;
  LaneData newData;
  for (int64_t idx : permutation) {
    newLayout.layout.push_back(laneLayout.layout[idx]);
    newData.layout.push_back(laneData.layout[idx]);
  }
  return LayoutInfo(newLayout, newData);
}

//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//

/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
  using Lattice::Lattice;
};

/// Helper functions to get default layouts. A `default layout` is a layout
/// that is assigned to a value when the layout is not fixed by some anchor
/// operation (like DPAS).

/// Helper function to get the default layout for uniform values like
/// constants.
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
  if (rank == 1)
    return LayoutInfo(LaneLayout({subgroupSize}), LaneData({1}));
  return LayoutInfo(LaneLayout({1, subgroupSize}), LaneData({1, 1}));
}

/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
  // Expecting a 1D or 2D vector.
  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
         "Expected 1D or 2D vector.");
  // Expecting int or float element type.
  assert(vectorTy.getElementType().isIntOrFloat() &&
         "Expected int or float element type.");
  // If the rank is 1, then return default layout for 1D vector.
  if (vectorTy.getRank() == 1)
    return getDefaultLayoutInfo(1);
  // Packing factor is determined by the element type bitwidth.
  int packingFactor = 1;
  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
  if (bitwidth < packedSizeInBitsForDefault)
    packingFactor = packedSizeInBitsForDefault / bitwidth;
  return LayoutInfo(LaneLayout({1, subgroupSize}),
                    LaneData({1, packingFactor}));
}
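
// For illustration, assuming subgroupSize == 16, getDefaultLayoutInfo would
// produce:
//   vector<16x16xf32> -> lane_layout = [1, 16], lane_data = [1, 1]
//   vector<16x16xf16> -> lane_layout = [1, 16], lane_data = [1, 1] (16-bit
//                        already meets the default packing minimum)
//   vector<16x32xi8>  -> lane_layout = [1, 16], lane_data = [1, 2] (two i8
//                        values are packed to reach 16 bits)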

/// Helper function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
/// `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
/// `packedSizeInBitsForDpasB`
static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
                                              unsigned operandNum) {
  Type elementTy = vectorTy.getElementType();
  assert(elementTy.isIntOrFloat() &&
         "Expected int or float type in DPAS operands");
  LaneLayout layout({1, subgroupSize});
  // For B operand, data must be packed in minimum `packedSizeInBitsForDpasB`
  // and must have the VNNI format.
  if (operandNum == 1 &&
      elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
    LaneData data(
        {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
    return LayoutInfo(layout, data);
  }
  // Otherwise, return the default layout for the vector type.
  return getDefaultLayoutInfo(vectorTy);
}
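
// For illustration, an f16 DPAS B operand (operandNum == 1) has 16-bit
// elements, which is below the 32-bit minimum, so it would get
// lane_layout = [1, 16] and lane_data = [2, 1] (VNNI packing along rows).
// An f32 B operand would simply fall through to the default layout.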

//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for the operands of DPAS,
/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose
/// of this analysis is to propagate those known layouts to all their producers
/// and (other) consumers.
class LayoutInfoPropagation
    : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
  void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                   ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreNdOp(xegpu::StoreNdOp store,
                      ArrayRef<LayoutInfoLattice *> operands,
                      ArrayRef<const LayoutInfoLattice *> results);

  void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadNdOp(xegpu::LoadNdOp load,
                     ArrayRef<LayoutInfoLattice *> operands,
                     ArrayRef<const LayoutInfoLattice *> results);

  void visitLoadGatherOp(xegpu::LoadGatherOp load,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitTransposeOp(vector::TransposeOp transpose,
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorBitcastOp(vector::BitCastOp bitcast,
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);

  void visitCreateDescOp(xegpu::CreateDescOp createDesc,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
                             ArrayRef<LayoutInfoLattice *> operands,
                             ArrayRef<const LayoutInfoLattice *> results);

  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);

  void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                   ArrayRef<LayoutInfoLattice *> operands,
                                   ArrayRef<const LayoutInfoLattice *> results);

public:
  LayoutInfoPropagation(DataFlowSolver &solver,
                        SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  LogicalResult
  visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
                 ArrayRef<const LayoutInfoLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override {};

  void visitCallOperand(OpOperand &operand) override {};

  void visitExternalCall(CallOpInterface call,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results) override {
  };

  void setToExitState(LayoutInfoLattice *lattice) override {
    (void)lattice->meet(LayoutInfo());
  }
};
} // namespace

LogicalResult LayoutInfoPropagation::visitOperation(
    Operation *op, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  TypeSwitch<Operation *>(op)
      .Case<xegpu::DpasOp>(
          [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
      .Case<xegpu::StoreNdOp>(
          [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
      .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
        visitStoreScatterOp(storeScatterOp, operands, results);
      })
      .Case<xegpu::LoadNdOp>(
          [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
      .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
        visitLoadGatherOp(loadGatherOp, operands, results);
      })
      .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
        visitCreateDescOp(createDescOp, operands, results);
      })
      .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
        visitPrefetchNdOp(prefetchNdOp, operands, results);
      })
      // No need to propagate the layout to operands in CreateNdDescOp because
      // they are scalars (offsets, sizes, etc.).
      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
      .Case<vector::TransposeOp>([&](auto transposeOp) {
        visitTransposeOp(transposeOp, operands, results);
      })
      .Case<vector::BitCastOp>([&](auto bitcastOp) {
        visitVectorBitcastOp(bitcastOp, operands, results);
      })
      .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
        visitVectorMultiReductionOp(reductionOp, operands, results);
      })
      // All other ops.
      .Default([&](Operation *op) {
        for (const LayoutInfoLattice *r : results) {
          for (LayoutInfoLattice *operand : operands) {
            // Propagate the layout of the result to the operand.
            if (r->getValue().isAssigned())
              meet(operand, *r);
          }
        }
      });
  // Add a dependency from each result to the program point after the
  // operation.
  for (const LayoutInfoLattice *r : results) {
    addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
  }
  return success();
}
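
// For illustration, in the default case above an elementwise op such as
// arith.addf simply receives the layout of its result on every vector
// operand; this is how anchor layouts flow backward through ordinary
// arithmetic.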

void LayoutInfoPropagation::visitPrefetchNdOp(
    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Here we assign the default layout to the tensor descriptor operand of
  // prefetch.
  auto tdescTy = prefetch.getTensorDescType();
  auto prefetchLayout = getDefaultLayoutInfo(
      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
  // Propagate the layout to the source tensor descriptor.
  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}

void LayoutInfoPropagation::visitVectorMultiReductionOp(
    vector::MultiDimReductionOp reduction,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // We only consider 2D -> 1D reductions at this point.
  assert(resultLayout.getLayout().size() == 1 &&
         "Expected 1D layout for reduction result.");
  // Given that the result is 1D, the layout of the operand should be 2D with
  // default layout.
  LayoutInfo operandLayout = getDefaultLayoutInfo(2);
  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
  // Accumulator should have the same layout as the result.
  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
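
// For illustration, for a reduction from vector<16x16xf32> to vector<16xf32>
// whose result was assigned the 1D layout [16]/[1], the source operand would
// receive the 2D default [1, 16]/[1, 1] and the accumulator would mirror the
// result layout.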

/// Propagate the layout of the result tensor to the source tensor descriptor
/// in UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
    xegpu::UpdateNdOffsetOp updateNdOffset,
    ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // The layout of the result must be present.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  // Propagate the layout to the source operand.
  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}

/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
    xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  VectorType aTy = dpas.getLhsType();
  VectorType bTy = dpas.getRhsType();
  propagateIfChanged(operands[0],
                     operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
  propagateIfChanged(operands[1],
                     operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
  if (operands.size() > 2) {
    VectorType cTy = dpas.getAccType();
    propagateIfChanged(operands[2],
                       operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
  }
}
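
// For illustration, for an f16 DPAS (f16 x f16 -> f32) this assigns
// lane_layout = [1, 16] with lane_data = [1, 1] to the A and C operands, and
// lane_data = [2, 1] (VNNI packed) to the B operand.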

/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
    xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
  // Both operands should have the same layout.
  for (LayoutInfoLattice *operand : operands) {
    propagateIfChanged(operand, operand->meet(storeLayout));
  }
}

/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
    xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;
  LayoutInfo tensorDescLayout = valueLayout;
  // LoadNdOp has the transpose effect. However, at the stage of this analysis
  // this effect is not expected and should be abstracted away. Emit a warning.
  if (auto transpose = load.getTranspose()) {
    load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
  }
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}

/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
    vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of transpose result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  LayoutInfo newLayout =
      resultLayout.getTransposedLayout(transpose.getPermutation());
  // Propagate the new layout to the vector operand.
  propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}

/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit width of the source and result types.
void LayoutInfoPropagation::visitVectorBitcastOp(
    vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Need the layout of bitcast result to propagate to the operands.
  LayoutInfo resultLayout = results[0]->getValue();
  if (!resultLayout.isAssigned())
    return;
  int inElemTyBitWidth =
      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
  int outElemTyBitWidth =
      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();

  // LaneLayout does not change.
  const LaneLayout &newLaneLayout = resultLayout.getLayout();
  const LaneData &currData = resultLayout.getData();
  LaneData newLaneData;
  // It's a widening bitcast.
  if (inElemTyBitWidth < outElemTyBitWidth) {
    int ratio = outElemTyBitWidth / inElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] * ratio})
                      : LaneData({currData[0] * ratio, 1});
  } else {
    // It's a narrowing bitcast.
    int ratio = inElemTyBitWidth / outElemTyBitWidth;
    newLaneData = resultLayout.getData()[0] == 1
                      ? LaneData({1, currData[1] / ratio})
                      : LaneData({currData[0] / ratio, 1});
  }

  propagateIfChanged(operands[0],
                     operands[0]->meet(LayoutInfo(newLaneLayout, newLaneData)));
}
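
// For illustration, bitcasting vector<16x16xi16> (source) to vector<16x8xi32>
// (result) is a widening cast with ratio 2, so a result lane_data of [1, 1]
// becomes [1, 2] on the i16 source, keeping the number of bits owned by each
// lane unchanged.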

/// Propagate the layout of the result to the tensor descriptor and mask
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
    xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo valueLayout = results[0]->getValue();
  // Need the layout of the value to propagate to the tensor descriptor.
  if (!valueLayout.isAssigned())
    return;

  LayoutInfo tensorDescLayout = valueLayout;
  if (load.getTranspose()) {
    // LoadGatherOp has the transpose effect. However, at the stage of this
    // analysis this effect is not expected and should be abstracted away. Emit
    // a warning.
    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
                     "LayoutInfoPropagation stage.");
    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // Mask operand should have 1D default layout.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  // Propagate the new layout to the tensor descriptor operand.
  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
  // Propagate the new layout to the mask operand.
  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}

/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
    xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  LayoutInfo descLayout = results[0]->getValue();
  // Need the layout of the descriptor to propagate to the operands.
  if (!descLayout.isAssigned())
    return;
  // For offset operand propagate 1D default layout.
  LayoutInfo layout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[1], operands[1]->meet(layout));
}

/// Set the layout for the value, tensor descriptor, and mask operands in the
/// StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
    xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
    ArrayRef<const LayoutInfoLattice *> results) {
  // Currently, for 2D StoreScatterOp we expect that the height dimension of
  // the tensor descriptor is equal to the subgroup size. This is ensured by
  // the op verifier.
  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
  if (tdescShape.size() > 1)
    assert(
        tdescShape[0] == subgroupSize &&
        "Expected the first dimension of 2D tensor descriptor to be equal to "
        "subgroup size.");

  LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
  LayoutInfo storeScatterLayout = valueLayout;
  if (storeScatter.getTranspose()) {
    // StoreScatterOp allows the transpose effect. However, at the stage of
    // this analysis this effect is not expected and should be abstracted away.
    // Emit a warning.
    storeScatter.emitWarning("Transpose effect is not expected for "
                             "StoreScatterOp at LayoutInfoPropagation stage.");
    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
  }
  // Propagate the value layout.
  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
  // Propagate the tensor descriptor layout.
  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
  // Use default 1D layout for mask operand.
  LayoutInfo maskLayout = getDefaultLayoutInfo(1);
  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}

namespace {

//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//

/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)

  RunLayoutInfoPropagation(Operation *op) : target(op) {
    SymbolTableCollection symbolTable;
    solver.load<DeadCodeAnalysis>();
    solver.load<SparseConstantPropagation>();
    solver.load<LayoutInfoPropagation>(symbolTable);
    (void)solver.initializeAndRun(op);
  }

  LayoutInfo getLayoutInfo(Value val);

  void printAnalysisResult(llvm::raw_ostream &os);

private:
  DataFlowSolver solver;
  const Operation *target;
};
} // namespace

LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
  auto *state = solver.lookupState<LayoutInfoLattice>(val);
  if (!state)
    return {};
  return state->getValue();
}

void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
    os << "function: " << funcOp.getName() << ":\n";
    // Function arguments
    for (BlockArgument arg : funcOp.getArguments()) {
      LayoutInfo layout = getLayoutInfo(arg);
      os << "argument: " << arg << "\n";
      os << "layout  : ";
      layout.print(os);
      os << "\n";
    }
    // Function ops
    funcOp.walk([&](Operation *op) {
      // Skip ops that do not have results
      if (op->getResults().empty())
        return;
      os << "op    : ";
      // For control-flow ops, print the op name only.
      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
        os << op->getName();
      else
        op->print(os);
      os << "\n";
      // Print the layout for each result.
      for (auto [i, r] : llvm::enumerate(op->getResults())) {
        LayoutInfo layout = getLayoutInfo(r);
        os << "layout for result #" << i << ": ";
        layout.print(os);
        os << "\n";
      }
    });
  };

  SmallVector<FunctionOpInterface> funcOps;
  if (auto modOp = dyn_cast<ModuleOp>(target)) {
    for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
      funcOps.push_back(funcOp);
    }
    // Collect all GpuFuncOps in the module.
    for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
      for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
        funcOps.push_back(gpuFuncOp);
      }
    }
  }
  // Print the analysis result for each function.
  for (FunctionOpInterface funcOp : funcOps) {
    printFunctionResult(funcOp);
  }
}
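
// For illustration, for a hypothetical kernel named `gemm` the debug output
// produced here looks roughly like (exact formatting follows the print
// routines above):
//   function: gemm:
//   argument: <block argument> of type 'memref<8x16xf16>' at index: 0
//   layout  : Not assigned.
//   op    : %0 = xegpu.create_nd_tdesc ...
//   layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]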

namespace {

//===----------------------------------------------------------------------===//
// LayoutAttrAssignment
//===----------------------------------------------------------------------===//

/// This class is responsible for assigning the layout attributes to the ops
/// and their users based on the layout propagation analysis result.
class LayoutAttrAssignment {
public:
  LayoutAttrAssignment(Operation *top,
                       function_ref<LayoutInfo(Value)> getLayout)
      : getAnalysisResult(getLayout), top(top) {}

  LogicalResult run();

private:
  LogicalResult assign(Operation *op);
  void assignToUsers(Value v, xegpu::LayoutAttr layout);
  xegpu::LayoutAttr getLayoutAttrForValue(Value v);
  LogicalResult resolveConflicts();
  // Callable to get the layout of a value based on the layout propagation
  // analysis.
  function_ref<LayoutInfo(Value)> getAnalysisResult;
  Operation *top;
};

} // namespace

/// Helper to assign the layout attribute to the users of the value.
void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
  for (OpOperand &user : v.getUses()) {
    Operation *owner = user.getOwner();
    std::string attrName = xegpu::getLayoutName(user);
    owner->setAttr(attrName, layout);
  }
}

/// Convert the layout assigned to a value to xegpu::LayoutAttr.
xegpu::LayoutAttr LayoutAttrAssignment::getLayoutAttrForValue(Value v) {
  LayoutInfo layout = getAnalysisResult(v);
  if (!layout.isAssigned())
    return {};
  SmallVector<int, 2> laneLayout, laneData;
  for (auto [layout, data] : llvm::zip_equal(layout.getLayoutAsArrayRef(),
                                             layout.getDataAsArrayRef())) {
    laneLayout.push_back(static_cast<int>(layout));
    laneData.push_back(static_cast<int>(data));
  }
  return xegpu::LayoutAttr::get(v.getContext(), laneLayout, laneData);
}

/// Assign xegpu::LayoutAttr to the op and its users. The layout is assigned
/// based on the layout propagation analysis result.
LogicalResult LayoutAttrAssignment::assign(Operation *op) {
  // For function ops, propagate the function argument layout to the users.
  if (auto func = dyn_cast<FunctionOpInterface>(op)) {
    for (BlockArgument arg : func.getArguments()) {
      xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(arg);
      if (layoutInfo) {
        assignToUsers(arg, layoutInfo);
      }
    }
    return success();
  }
  // If no results, move on.
  if (op->getNumResults() == 0)
    return success();
  // If all the results are scalars, move on.
  if (llvm::all_of(op->getResultTypes(),
                   [](Type t) { return t.isIntOrIndexOrFloat(); }))
    return success();
  // If the op has more than one result and at least one result is a tensor
  // descriptor, exit. This case is not supported yet.
  // TODO: Support this case.
  if (op->getNumResults() > 1 && llvm::any_of(op->getResultTypes(), [](Type t) {
        return isa<xegpu::TensorDescType>(t);
      })) {
    LLVM_DEBUG(
        DBGS() << op->getName()
               << " op has more than one result and at least one is a tensor "
                  "descriptor. This case is not handled.\n");
    return failure();
  }
  // If the result is a tensor descriptor, attach the layout to the tensor
  // descriptor itself.
  if (auto tensorDescTy =
          dyn_cast<xegpu::TensorDescType>(op->getResultTypes()[0])) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(op->getResult(0));
    if (!layoutInfo) {
      LLVM_DEBUG(DBGS() << "No layout for result of " << *op << "\n");
      return failure();
    }

    // Clone the op, attach the layout to the result tensor descriptor, and
    // remove the original op.
    OpBuilder builder(op);
    Operation *newOp = builder.clone(*op);
    auto newTensorDescTy = xegpu::TensorDescType::get(
        tensorDescTy.getContext(), tensorDescTy.getShape(),
        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layoutInfo);
    newOp->getResult(0).setType(newTensorDescTy);
    op->replaceAllUsesWith(newOp->getResults());
    op->erase();
    return success();
  }
  // Otherwise simply attach the layout to the op itself.
  for (auto r : op->getOpResults()) {
    xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
    if (layoutInfo) {
      std::string attrName = xegpu::getLayoutName(r);
      op->setAttr(attrName, layoutInfo);
      // Attach the layout attribute to the users of the result.
      assignToUsers(r, layoutInfo);
    }
  }
  return success();
}

/// Walk the IR and attach xegpu::LayoutAttr to all ops and their users.
LogicalResult LayoutAttrAssignment::run() {
  auto walkResult = top->walk([&](Operation *op) {
    if (failed(assign(op)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return failure();

  return resolveConflicts();
}

/// TODO: Implement the layout conflict resolution. This must ensure mainly two
/// things:
/// 1) Is a given layout supported by the op? (need to query the target
///    HW info). Otherwise can we achieve this layout using a layout conversion?
/// 2) Do all the operands have the required layout? If not, can it
///    be resolved using a layout conversion?
LogicalResult LayoutAttrAssignment::resolveConflicts() { return success(); }

namespace {

//===----------------------------------------------------------------------===//
// SIMT Distribution Patterns
//===----------------------------------------------------------------------===//

/// Helper function to get distributed vector type for a source vector type
/// according to the lane_layout. We simply divide each dimension of tensor
/// descriptor shape by the corresponding lane_layout dimension. If
/// array_length > 1, that is appended to the front of the distributed shape.
/// NOTE: This is the vector type that will be returned by the
/// gpu.warp_execute_on_lane0 op.
///
/// Examples:
/// | original vector shape | lane_layout | distributed vector shape |
/// |-----------------------|-------------|--------------------------|
/// | 32x16                 | [1, 16]     | 32x1                     |
/// | 32x16                 | [2, 8]      | 16x2                     |
/// | 2x32x16               | [1, 16]     | 2x32x1                   |
static FailureOr<VectorType>
getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
                                VectorType originalType) {
  if (!layout)
    return failure();

  auto laneLayout = layout.getLaneLayout().asArrayRef();
  assert(originalType.getShape().size() >= laneLayout.size() &&
         "Rank of the original vector type should be greater or equal to the "
         "size of the lane layout to distribute the vector type.");
  SmallVector<int64_t> distributedShape(originalType.getShape());
  // Only distribute the last `laneLayout.size()` dimensions. The remaining
  // dimensions are not distributed.
  unsigned distributionStart = originalType.getRank() - laneLayout.size();
  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart) {
      continue;
    }
    // Check if the dimension can be distributed evenly.
    if (dim % laneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / laneLayout[i - distributionStart];
  }
  return VectorType::get(distributedShape, originalType.getElementType());
}

/// Helper function to resolve types if the distributed type out of
/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
/// Example 1:
///   distributed type: vector<8x1xf32>
///   expected type: vector<8xf32>
///   resolved using,
///   %0 = vector.shape_cast %1 : vector<8x1xf32> to vector<8xf32>
/// Example 2:
///   distributed type: xegpu.tensor_desc<8x16xf32, #xegpu.layout<...>>
///   expected type: xegpu.tensor_desc<8x16xf32>
///   resolved using,
///   %0 = unrealized_conversion_cast %1 :
///      xegpu.tensor_desc<8x16xf32, #xegpu.layout<..>> ->
///      xegpu.tensor_desc<8x16xf32>
template <typename T>
static Value resolveDistributedTy(Value orig, T expected,
                                  PatternRewriter &rewriter) {
  // If orig and expected types are the same, return orig.
  if (orig.getType() == expected)
    return orig;
  // If orig is a vector type, create a shape cast op to reconcile the types.
  if (isa<VectorType>(orig.getType())) {
    auto castOp =
        rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
    return castOp.getResult();
  }
  // If orig is a tensor descriptor type, create an unrealized conversion cast
  // op to reconcile the types.
  if (isa<xegpu::TensorDescType>(orig.getType())) {
    auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
                                                              expected, orig);
    return castOp.getResult(0);
  }
  llvm_unreachable("Unsupported type for reconciliation");
  return orig;
}

/// Helper function to filter out the temporary layout attributes attached
/// during the layout assignment process. These are not needed after going to
/// SIMT.
static SmallVector<NamedAttribute>
removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute> newAttrs;
  for (NamedAttribute attr : attrs) {
    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
      newAttrs.push_back(attr);
  }
  return newAttrs;
}

/// Helper function to check if the layout is packed. Layout is packed if it is
/// 2D and lane_data[0] != 1 (data packed from col dimension).
static bool hasPackedLayout(xegpu::LayoutAttr layout) {
  if (layout == xegpu::LayoutAttr())
    return false;
  DenseI32ArrayAttr laneData = layout.getLaneData();
  if (!laneData || laneData.size() != 2)
    return false;
  return laneData.asArrayRef()[0] != 1;
}
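
// For illustration, a layout with lane_data = [2, 1] (e.g. the f16 DPAS B
// operand) is considered packed and causes the packed flag to be set on the
// distributed load below, whereas [1, 1] or [1, 2] is not.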

/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that the entire body is
/// contained within a WarpExecuteOnLane0Op.
/// Example:
///
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   ...
///   ...
///   gpu.return %result: vector<8x16xf32>
/// }
/// ```
/// To
/// ```
/// gpu.func @foo(%arg0: memref<*xf16>) -> vector<8x16xf32> {
///   %laneid = gpu.lane_id : index
///   %0 = gpu.warp_execute_on_lane_0(%laneid) -> vector<8x16xf32> {
///     ...
///     ...
///     gpu.yield %result: vector<8x16xf32>
///   }
///   return %0
/// }
/// ```
struct MoveFuncBodyToWarpExecuteOnLane0
    : public OpRewritePattern<gpu::GPUFuncOp> {
  using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                PatternRewriter &rewriter) const override {
    // If the function only contains a single void return, skip.
    if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
        }))
      return failure();
    // If the function already moved inside a warp_execute_on_lane0, skip.
    if (llvm::any_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
          return isa<gpu::WarpExecuteOnLane0Op>(op);
        }))
      return failure();
    // Create a new function with the same signature.
    auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
    // Create a WarpExecuteOnLane0Op with same arguments and results as the
    // original gpuFuncOp.
    rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
    auto laneId = rewriter.create<gpu::LaneIdOp>(
        newGpuFunc.getLoc(), rewriter.getIndexType(),
        /** upperBound = **/ mlir::IntegerAttr());
    ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
    auto warpOp = rewriter.create<gpu::WarpExecuteOnLane0Op>(
        laneId.getLoc(), gpuFuncResultType, laneId, subgroupSize,
        newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes());
    Block &warpBodyBlock = warpOp.getBodyRegion().front();
    // Replace the ReturnOp of the original gpu function with a YieldOp.
    auto origReturnOp =
        cast<gpu::ReturnOp>(gpuFuncOp.getBlocks().back().getTerminator());
    rewriter.setInsertionPointAfter(origReturnOp);
    rewriter.create<gpu::YieldOp>(origReturnOp.getLoc(),
                                  origReturnOp.getOperands());
    rewriter.eraseOp(origReturnOp);
    // Move the original function body to the WarpExecuteOnLane0Op body.
    rewriter.inlineRegionBefore(gpuFuncOp.getBody(), warpOp.getBodyRegion(),
                                warpOp.getBodyRegion().begin());
    rewriter.eraseBlock(&warpBodyBlock);
    // Insert a new ReturnOp after the WarpExecuteOnLane0Op.
    rewriter.setInsertionPointAfter(warpOp);
    rewriter.create<gpu::ReturnOp>(newGpuFunc.getLoc(), warpOp.getResults());
    rewriter.replaceOp(gpuFuncOp, newGpuFunc);
    return success();
  }
};

/// Distribute a create_nd_tdesc feeding into vector.yield op of the enclosing
/// `gpu.warp_execute_on_lane_0` region. After the sinking, the warp op will
/// still contain the original op that will not be used by the yield op (and
/// should be cleaned up later). The yield op will bypass the create_nd_tdesc's
/// arguments. The tensor descriptor shape is not distributed because it is a
/// uniform value across all work items within the subgroup. However, the
/// layout information is dropped in the new tensor descriptor type.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
///                   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
///   ...
///   %td = xegpu.create_nd_tdesc %arg0[0, 0]
///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///   vector.yield %td
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
///   ...
///   %dead = xegpu.create_nd_tdesc %arg0[0, 0]
///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
///   vector.yield %arg0, %dead
/// }
/// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
///   -> !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
    if (!operand)
      return rewriter.notifyMatchFailure(
          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
    auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
    unsigned operandIdx = operand->getOperandNumber();

    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");

    SmallVector<size_t> newRetIndices;
    SmallVector<Value> newYieldValues;
    SmallVector<Type> newYieldTypes;

    for (Value operand : descOp->getOperands()) {
      newYieldValues.push_back(operand);
      newYieldTypes.push_back(operand.getType());
    }
    rewriter.setInsertionPoint(subgroupOp);
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
        /* new yielded types = */ newYieldTypes, newRetIndices);

    SmallVector<Value> newDescOperands;
    for (size_t i : newRetIndices) {
      newDescOperands.push_back(newWarpOp.getResult(i));
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    xegpu::TensorDescType distributedTensorDescTy =
        descOp.getType().dropLayouts(); // Distributed tensor descriptor type
                                        // does not contain layout info.
    auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
        descOp->getAttrs());

    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newDescOp);
    return success();
  }
};

/// Distribute a store_nd op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
/// through the warp op interface, they would be propagated as returned values.
/// The source vector is distributed based on the lane layout. Appropriate cast
/// ops are inserted if the distributed types do not match the expected xegpu
/// SIMT types.
///
/// Example:
///
/// ```
/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
///   ...
///   xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
///     !xegpu.tensor_desc<4x8xf32, #layout0>
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
///   gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
///     #layout0>
/// }
/// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
/// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
///   #layout0> -> !xegpu.tensor_desc<4x8xf32>
/// xegpu.store_nd %0, %1: vector<4xf32>,
///   !xegpu.tensor_desc<4x8xf32>
///
/// ```
struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
    if (!storeOp)
      return failure();

    xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
    if (!layout)
      return rewriter.notifyMatchFailure(
          storeOp, "the source tensor descriptor lacks layout attribute");

    FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
        getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
    if (failed(distributedTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(storeOp,
                                         "Failed to distribute the type");
    VectorType distributedTypeByWarpOp =
        distributedTypeByWarpOpOrFailure.value();

    SmallVector<size_t> newRetIndices;
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp,
        /* new yielded values = */
        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
        /* new yielded types = */
        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
        newRetIndices);
    // Create a new store op outside the warp op with the distributed vector
    // type. Tensor descriptor is not distributed.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newStoreOperands;

    // For the value operand, there can be a mismatch between the vector type
    // distributed by the warp op and (xegpu-specific) distributed type
    // supported by the store op. Type mismatch must be resolved using
    // appropriate cast op.
    FailureOr<VectorType> storeNdDistributedValueTyOrFailure =
        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
    if (failed(storeNdDistributedValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          storeOp, "Failed to get distributed vector type for the store op");
    newStoreOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]),
        storeNdDistributedValueTyOrFailure.value(), rewriter));
    // For the tensor descriptor operand, the layout attribute is dropped after
    // distribution. Types need to be resolved in this case also.
    xegpu::TensorDescType distributedTensorDescTy =
        storeOp.getTensorDescType().dropLayouts();
    newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));

    rewriter.create<xegpu::StoreNdOp>(
        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
    rewriter.eraseOp(storeOp);
    return success();
  }
};
1191 | |
1192 | /// Distribute a load_nd op feeding into vector.yield op for the enclosing |
1193 | /// `gpu.warp_execute_on_lane_0` and put it after the warp op. |
1194 | /// The warp op will still contain the original op that will not be used by |
1195 | /// the yield op (and should be cleaned up later). The yield op will |
1196 | /// bypass the load's arguments. Only the loaded vector is distributed |
1197 | /// according to lane layout and, tensor descriptor types is not |
1198 | /// distributed. Appropriate cast ops are inserted if the distributed types does |
1199 | /// not match expected xegpu SIMT types. |
1200 | /// |
1201 | /// Example: |
1202 | /// |
1203 | /// ``` |
1204 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
1205 | /// %r = gpu.warp_execute_on_lane_0(%laneid) -> |
1206 | /// (vector<4x1xf32>) { |
1207 | /// ... |
1208 | /// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32, #layout0> |
1209 | /// -> |
1210 | /// vector<4x8xf32> |
1211 | /// gpu.yield %ld |
1212 | /// } |
1213 | /// ``` |
1214 | /// To |
1215 | /// ``` |
1216 | /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>, |
1217 | /// !xegpu.tensor_desc<4x8xf32, #layout0>) { |
1218 | /// ... |
1219 | /// %dead = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> -> |
1220 | /// vector<4x8xf32> gpu.yield %dead, %arg0 |
1221 | /// } |
1222 | /// %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32, |
1223 | /// #layout0> -> !xegpu.tensor_desc<4x8xf32> |
1224 | /// %1 = xegpu.load_nd %0: !xegpu.tensor_desc<4x8xf32> -> vector<4xf32> |
1225 | /// %2 = vector.shape_cast %r#0: vector<4xf32> to vector<4x1xf32> |
1226 | /// |
1227 | /// ``` |
1228 | struct LoadNdDistribution final : public gpu::WarpDistributionPattern { |
1229 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
1230 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
1231 | PatternRewriter &rewriter) const override { |
1232 | OpOperand *operand = |
1233 | getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>); |
1234 | if (!operand) |
1235 | return rewriter.notifyMatchFailure( |
1236 | subgroupOp, "warp result is not a xegpu::LoadNd op" ); |
1237 | |
1238 | auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>(); |
1239 | xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType(); |
1240 | xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr(); |
1241 | if (!layout) |
1242 | return rewriter.notifyMatchFailure( |
          loadOp, "the source tensor descriptor lacks layout attribute");
1244 | |
1245 | unsigned operandIdx = operand->getOperandNumber(); |
1246 | VectorType distributedTypeByWarpOp = |
1247 | cast<VectorType>(subgroupOp.getResult(operandIdx).getType()); |
1248 | |
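    // Yield the load's tensor descriptor from the warp op. The descriptor is
    // uniform across lanes, so it is yielded with its original (layouted)
    // type; only the loaded vector value is distributed.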
1249 | SmallVector<size_t> newRetIndices; |
1250 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1251 | rewriter, subgroupOp, |
1252 | /* new yielded values = */ loadOp.getTensorDesc(), |
1253 | /* new yielded types = */ tensorDescTy, newRetIndices); |
1254 | |
1255 | // Create a new load op outside the warp op with the distributed vector |
1256 | // type. |
1257 | rewriter.setInsertionPointAfter(newWarpOp); |
1258 | FailureOr<VectorType> loadNdDistValueTyOrFailure = |
1259 | xegpu::getDistributedVectorType(loadOp.getTensorDescType()); |
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
1263 | xegpu::TensorDescType distributedTensorDescTy = |
1264 | loadOp.getTensorDescType().dropLayouts(); // Distributed tensor |
1265 | // descriptor type does not |
1266 | // contain layout info. |
1267 | auto newLoadOp = rewriter.create<xegpu::LoadNdOp>( |
1268 | newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(), |
1269 | resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]), |
1270 | distributedTensorDescTy, rewriter), |
1271 | removeTemporaryLayoutAttributes(loadOp->getAttrs())); |
1272 | // Set the packed attribute if the layout requires it. |
1273 | newLoadOp.setPacked(hasPackedLayout(layout)); |
1274 | Value distributedVal = newWarpOp.getResult(operandIdx); |
1275 | // There can be a conflict between the vector type distributed by the |
1276 | // warp op and (xegpu-specific) distributed type supported by the load |
1277 | // op. Resolve these mismatches by inserting a cast. |
1278 | Value tyResolvedVal = resolveDistributedTy( |
1279 | newLoadOp.getResult(), distributedTypeByWarpOp, rewriter); |
    rewriter.replaceAllUsesWith(distributedVal, tyResolvedVal);
1281 | return success(); |
1282 | } |
1283 | }; |
1284 | |
1285 | /// Distribute a dpas op feeding into vector.yield op for the enclosing |
1286 | /// `gpu.warp_execute_on_lane_0` and put it after the warp op. |
1287 | /// The warp op will still contain the original op that will not be used by |
1288 | /// the yield op (and should be cleaned up later). The yield op will |
1289 | /// bypass the dpas's arguments. Appropriate cast ops are inserted if the |
/// distributed types do not match the expected xegpu SIMT types.
1291 | /// Example: |
1292 | /// ``` |
1293 | /// #lo_a = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]> |
1294 | /// #lo_b = #xegpu.layout<wi_layout = [1, 16], wi_data = [2, 1]> |
1295 | /// #lo_c = #xegpu.layout<wi_layout = [1, 16], wi_data = [1, 1]> |
1296 | /// %r = gpu.warp_execute_on_lane_0(%laneid) -> |
1297 | /// (vector<8x1xf32>) { |
1298 | /// ... |
1299 | /// %dpas = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> -> |
1300 | /// vector<8x16xf32> |
1301 | /// gpu.yield %dpas |
1302 | /// } |
1303 | /// ``` |
1304 | /// To |
1305 | /// ``` |
1306 | /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<8x1xf32>, |
1307 | /// vector<8x1xf16>, vector<16x1xf16>) { |
1308 | /// ... |
1309 | /// %dead = xegpu.dpas %arg0, %arg1: vector<8x16xf16>, vector<16x16xf16> |
1310 | /// -> vector<8x16xf32> |
1311 | /// gpu.yield %dead, %arg0, %arg1 |
1312 | /// } |
1313 | /// %0 = vector.shape_cast %r#1: vector<8x1xf16> to vector<8xf16> |
1314 | /// %1 = vector.shape_cast %r#2: vector<16x1xf16> to vector<16xf16> |
1315 | /// %2 = xegpu.dpas %0, %1: vector<8xf16>, vector<16xf16> -> |
1316 | /// vector<8xf32> |
1317 | /// %dpas = vector.shape_cast %2: vector<8xf32> to vector<8x1xf32> |
1318 | /// ``` |
1319 | struct DpasDistribution final : public gpu::WarpDistributionPattern { |
1320 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
1321 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
1322 | PatternRewriter &rewriter) const override { |
1323 | OpOperand *operand = |
1324 | getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>); |
1325 | if (!operand) |
      return rewriter.notifyMatchFailure(subgroupOp,
                                         "warp result is not a xegpu::Dpas op");
1328 | |
1329 | auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>(); |
1330 | unsigned operandIdx = operand->getOperandNumber(); |
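    // The operand and result layouts of the dpas op are carried as temporary
    // named attributes (attached during layout assignment); look them up by
    // their attribute names.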
1331 | std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0)); |
1332 | std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1)); |
1333 | std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0)); |
1334 | |
1335 | xegpu::LayoutAttr layoutA = |
1336 | dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName); |
1337 | xegpu::LayoutAttr layoutB = |
1338 | dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName); |
1339 | xegpu::LayoutAttr layoutOut = |
1340 | dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName); |
1341 | if (!layoutA || !layoutB || !layoutOut) |
1342 | return rewriter.notifyMatchFailure( |
1343 | dpasOp, |
          "the xegpu::Dpas op lacks layout attribute for A, B or output");
1345 | |
1346 | FailureOr<VectorType> distLhsTypeByWarpOpOrFailure = |
1347 | getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType()); |
1348 | FailureOr<VectorType> distRhsTypeByWarpOpOrFailure = |
1349 | getDistVecTypeBasedOnLaneLayout(layoutB, dpasOp.getRhsType()); |
1350 | FailureOr<VectorType> distResultTypeByWarpOpOrFailure = |
1351 | getDistVecTypeBasedOnLaneLayout(layoutOut, dpasOp.getResultType()); |
    if (failed(distLhsTypeByWarpOpOrFailure) ||
        failed(distRhsTypeByWarpOpOrFailure) ||
        failed(distResultTypeByWarpOpOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to distribute the A, B or output types in xegpu::Dpas op");
1358 | |
1359 | llvm::SmallVector<Value, 3> newYieldValues{dpasOp.getLhs(), |
1360 | dpasOp.getRhs()}; |
1361 | llvm::SmallVector<Type, 3> newYieldTypes{ |
1362 | distLhsTypeByWarpOpOrFailure.value(), |
1363 | distRhsTypeByWarpOpOrFailure.value()}; |
1364 | // Dpas acc operand is optional. |
1365 | if (dpasOp.getAcc()) { |
      newYieldValues.push_back(dpasOp.getAcc());
      newYieldTypes.push_back(distResultTypeByWarpOpOrFailure.value());
1368 | } |
1369 | // Create a new warp op without the dpas. |
1370 | SmallVector<size_t> newRetIndices; |
1371 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1372 | rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices); |
1373 | |
1374 | FailureOr<VectorType> expectedDistLhsTyOrFailure = |
1375 | xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA); |
1376 | FailureOr<VectorType> expectedDistRhsTyOrFailure = |
1377 | xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB); |
1378 | FailureOr<VectorType> expectedDistResultTyOrFailure = |
1379 | xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut); |
    if (failed(expectedDistLhsTyOrFailure) ||
        failed(expectedDistRhsTyOrFailure) ||
        failed(expectedDistResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          dpasOp,
          "Failed to get distributed vector type for the dpas operands.");
1386 | // Create a new dpas op outside the warp op. |
1387 | rewriter.setInsertionPointAfter(newWarpOp); |
1388 | SmallVector<Value> newDpasOperands; |
1389 | SmallVector<VectorType> newDpasOperandExpectedTypes; |
1390 | |
1391 | // Resolve the distributed types with the original types. |
1392 | newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value()); |
1393 | newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value()); |
1394 | VectorType distributedResultTy = expectedDistResultTyOrFailure.value(); |
1395 | if (dpasOp.getAcc()) |
1396 | newDpasOperandExpectedTypes.push_back(distributedResultTy); |
1397 | |
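    // newRetIndices[i] corresponds, in order, to the newly yielded LHS, RHS
    // and (if present) accumulator; resolve each one to the vector type the
    // distributed dpas op expects.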
1398 | for (unsigned i = 0; i < newRetIndices.size(); i++) { |
      newDpasOperands.push_back(
          resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                               newDpasOperandExpectedTypes[i], rewriter));
1402 | } |
1403 | Value newDpasOp = rewriter.create<xegpu::DpasOp>( |
1404 | newWarpOp->getLoc(), distributedResultTy, newDpasOperands, |
1405 | removeTemporaryLayoutAttributes(dpasOp->getAttrs())); |
1406 | Value distributedVal = newWarpOp.getResult(operandIdx); |
1407 | // Resolve the output type. |
1408 | newDpasOp = resolveDistributedTy( |
1409 | newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter); |
    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
1411 | return success(); |
1412 | } |
1413 | }; |
1414 | |
1415 | /// Sink an update_nd_offset op feeding into yield op of an enclosing |
1416 | /// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the |
1417 | /// original op that will not be used by the yield op (and should be cleaned |
1418 | /// up later). The yield op will bypass the updateOp's arguments. The tensor |
1419 | /// descriptor type is not distributed. Appropriate cast ops are inserted if |
/// the distributed types do not match the expected xegpu SIMT types.
1421 | /// Example: |
1422 | /// ``` |
1423 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
1424 | /// %r = gpu.warp_execute_on_lane_0(%laneid) -> |
1425 | /// (!xegpu.tensor_desc<4x8xf32, #layout0>) { |
1426 | /// ... |
1427 | /// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]: |
1428 | /// !xegpu.tensor_desc<4x8xf32, #layout0> |
1429 | /// gpu.yield %update |
1430 | /// } |
1431 | /// ... |
1432 | /// ``` |
1433 | /// To |
1434 | /// ``` |
1435 | /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> ( |
1436 | /// !xegpu.tensor_desc<4x8xf32, #layout0>, |
1437 | /// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) { |
1438 | /// ... |
1439 | /// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]: |
///       !xegpu.tensor_desc<4x8xf32, #layout0>
///     gpu.yield %dead, %arg0, %c32, %c16
1442 | /// } |
///   %0 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
1444 | /// #layout0> -> !xegpu.tensor_desc<4x8xf32> |
1445 | /// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]: |
1446 | /// !xegpu.tensor_desc<4x8xf32> |
1447 | /// ... |
1448 | /// ``` |
1449 | struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern { |
1450 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
1451 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
1452 | PatternRewriter &rewriter) const override { |
1453 | OpOperand *operand = |
1454 | getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>); |
1455 | if (!operand) |
1456 | return rewriter.notifyMatchFailure( |
          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
1458 | auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>(); |
1459 | unsigned operandIdx = operand->getOperandNumber(); |
    // The new update op does not have a layout attribute.
1461 | xegpu::TensorDescType newTensorDescTy = |
1462 | updateOp.getTensorDescType().dropLayouts(); |
1463 | |
1464 | SmallVector<Value, 3> newYieldValues; |
1465 | SmallVector<Type, 3> newYieldTypes; |
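    // Yield all operands of the update op from the warp op. The tensor
    // descriptor is yielded with its layout dropped; the remaining (offset)
    // operands keep their original types.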
1466 | for (Value operand : updateOp->getOperands()) { |
1467 | newYieldValues.push_back(operand); |
1468 | if (isa<xegpu::TensorDescType>(operand.getType())) { |
1469 | newYieldTypes.push_back(newTensorDescTy); |
1470 | } else { |
1471 | newYieldTypes.push_back(operand.getType()); |
1472 | } |
1473 | } |
1474 | SmallVector<size_t> newRetIndices; |
1475 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1476 | rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices); |
1477 | rewriter.setInsertionPointAfter(newWarpOp); |
1478 | SmallVector<Value> newUpdateOperands; |
1479 | for (size_t i : newRetIndices) { |
1480 | // For the tensor descriptor operand, the layout attribute is dropped |
      // after distribution. Types need to be resolved in this case.
1482 | if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) { |
        newUpdateOperands.push_back(resolveDistributedTy(
            newWarpOp.getResult(i), newTensorDescTy, rewriter));
1485 | } else { |
        newUpdateOperands.push_back(newWarpOp.getResult(i));
1487 | } |
1488 | } |
1489 | // Create a new update op outside the warp op. |
1490 | auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>( |
1491 | newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands, |
1492 | removeTemporaryLayoutAttributes(updateOp->getAttrs())); |
1493 | Value distributedVal = newWarpOp.getResult(operandIdx); |
1494 | rewriter.replaceAllUsesWith(distributedVal, newUpdateOp); |
1495 | return success(); |
1496 | } |
1497 | }; |
1498 | |
/// Distribute a prefetch_nd op at the end of the enclosing
/// `gpu.warp_execute_on_lane_0`. If the arguments of the prefetch are passed
/// through the warp op interface, they are propagated as returned values.
/// The tensor descriptor shape is not distributed because it is a uniform
/// value across all work items within the subgroup. Appropriate cast ops are
/// inserted if the distributed types do not match the expected xegpu SIMT
/// types.
1505 | /// |
1506 | /// Example: |
1507 | /// |
1508 | /// ``` |
1509 | /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]> |
1510 | /// gpu.warp_execute_on_lane_0(%laneid) -> () { |
1511 | /// ... |
1512 | /// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0> |
1513 | /// } |
1514 | /// ``` |
1515 | /// To |
1516 | /// ``` |
1517 | /// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> ( |
1518 | /// !xegpu.tensor_desc<4x8xf32, #layout0>) { |
1519 | /// gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0> |
1520 | /// } |
1521 | /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32, |
1522 | /// #layout0> -> !xegpu.tensor_desc<4x8xf32> |
1523 | /// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32> |
1524 | /// |
1525 | /// ``` |
1526 | struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern { |
1527 | using gpu::WarpDistributionPattern::WarpDistributionPattern; |
1528 | LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp, |
1529 | PatternRewriter &rewriter) const override { |
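    // The prefetch op produces no results, so it cannot be matched through a
    // value yielded by the warp op. Instead, look for a prefetch_nd op
    // immediately preceding the warp op terminator.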
1530 | auto yield = cast<gpu::YieldOp>( |
1531 | subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
1532 | Operation *lastNode = yield->getPrevNode(); |
1533 | auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode); |
1534 | if (!prefetchOp) |
1535 | return failure(); |
1536 | xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr(); |
1537 | if (!layout) |
1538 | return rewriter.notifyMatchFailure( |
          prefetchOp, "the source tensor descriptor lacks layout attribute");
1540 | |
1541 | SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()}; |
1542 | SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()}; |
1543 | SmallVector<size_t> newRetIndices; |
1544 | gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
1545 | rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices); |
    // Create a new prefetch op outside the warp op with the updated tensor
    // descriptor type. The source tensor descriptor requires type resolution.
1548 | xegpu::TensorDescType newTensorDescTy = |
1549 | prefetchOp.getTensorDescType().dropLayouts(); |
1550 | rewriter.setInsertionPointAfter(newWarpOp); |
1551 | SmallVector<Value> newPrefetchOperands = {resolveDistributedTy( |
1552 | newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)}; |
1553 | rewriter.create<xegpu::PrefetchNdOp>( |
1554 | newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands, |
1555 | removeTemporaryLayoutAttributes(prefetchOp->getAttrs())); |
    rewriter.eraseOp(prefetchOp);
1557 | return success(); |
1558 | } |
1559 | }; |
1560 | |
1561 | } // namespace |
1562 | |
1563 | namespace { |
1564 | struct XeGPUSubgroupDistributePass final |
1565 | : public xegpu::impl::XeGPUSubgroupDistributeBase< |
1566 | XeGPUSubgroupDistributePass> { |
1567 | XeGPUSubgroupDistributePass() = default; |
1568 | XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) = |
1569 | default; |
1570 | XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options) |
1571 | : XeGPUSubgroupDistributeBase(options) {} |
1572 | void runOnOperation() override; |
1573 | }; |
1574 | } // namespace |
1575 | |
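// A minimal usage sketch for reusing these patterns outside this pass
// (assuming a target operation `op` provided by the caller):
//   RewritePatternSet patterns(op->getContext());
//   xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     /* handle failure */;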
1576 | void xegpu::populateXeGPUSubgroupDistributePatterns( |
1577 | RewritePatternSet &patterns) { |
1578 | patterns.add<CreateNdDescDistribution, StoreNdDistribution, |
1579 | LoadNdDistribution, DpasDistribution, PrefetchNdDistribution, |
               UpdateNdOffsetDistribution>(patterns.getContext());
1581 | } |
1582 | |
1583 | void XeGPUSubgroupDistributePass::runOnOperation() { |
  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
  // Print the analysis result and exit (for testing purposes).
1586 | if (printOnly) { |
1587 | auto &os = llvm::outs(); |
    analysis.printAnalysisResult(os);
1589 | return; |
1590 | } |
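  // Helper that queries the propagated layout for a given value; used by the
  // layout assignment below.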
1591 | auto getPropagatedLayout = [&](Value val) { |
    return analysis.getLayoutInfo(val);
1593 | }; |
1594 | |
1595 | // Assign xegpu::LayoutAttr to all ops and their users based on the layout |
1596 | // propagation analysis result. |
1597 | LayoutAttrAssignment layoutAssignment(getOperation(), getPropagatedLayout); |
  if (failed(layoutAssignment.run())) {
1599 | signalPassFailure(); |
1600 | return; |
1601 | } |
1602 | |
1603 | // Move all operations of a GPU function inside gpu.warp_execute_on_lane_0 |
1604 | // operation. |
1605 | { |
1606 | RewritePatternSet patterns(&getContext()); |
1607 | patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext()); |
1608 | |
1609 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
1610 | signalPassFailure(); |
1611 | return; |
1612 | } |
1613 | // At this point, we have moved the entire function body inside the warpOp. |
1614 | // Now move any scalar uniform code outside of the warpOp (like GPU index |
1615 | // ops, scalar constants, etc.). This will simplify the later lowering and |
1616 | // avoid custom patterns for these ops. |
1617 | getOperation()->walk([&](Operation *op) { |
1618 | if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) { |
1619 | vector::moveScalarUniformCode(warpOp); |
1620 | } |
1621 | }); |
1622 | } |
1623 | // Finally, do the SIMD to SIMT distribution. |
1624 | RewritePatternSet patterns(&getContext()); |
1625 | xegpu::populateXeGPUSubgroupDistributePatterns(patterns); |
1626 | // TODO: distributionFn and shuffleFn are not used at this point. |
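  // The distribution map relates the vector dimensions of a distributed value
  // to lane IDs; a multi-dim identity map is used here as a placeholder
  // (rank-0 values get an empty map).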
1627 | auto distributionFn = [](Value val) { |
1628 | VectorType vecType = dyn_cast<VectorType>(val.getType()); |
1629 | int64_t vecRank = vecType ? vecType.getRank() : 0; |
1630 | OpBuilder builder(val.getContext()); |
1631 | if (vecRank == 0) |
      return AffineMap::get(val.getContext());
1633 | return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); |
1634 | }; |
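  // The shuffle function would materialize a cross-lane exchange of `val`
  // from lane `srcIdx`; it is unused by the XeGPU patterns, so it returns a
  // null Value.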
1635 | auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, |
1636 | int64_t warpSz) { return Value(); }; |
1637 | vector::populatePropagateWarpVectorDistributionPatterns( |
1638 | patterns, distributionFn, shuffleFn); |
1639 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
1640 | signalPassFailure(); |
1641 | return; |
1642 | } |
1643 | } |
1644 | |