| 1 | //===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
| 10 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
| 11 | #include "mlir/Analysis/DataFlow/Utils.h" |
| 12 | #include "mlir/Analysis/DataFlowFramework.h" |
| 13 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 14 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 15 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 16 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
| 17 | #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" |
| 18 | #include "mlir/Dialect/XeGPU/Transforms/Passes.h" |
| 19 | #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
| 20 | #include "mlir/IR/Attributes.h" |
| 21 | #include "mlir/IR/Builders.h" |
| 22 | #include "mlir/IR/BuiltinAttributes.h" |
| 23 | #include "mlir/IR/BuiltinTypes.h" |
| 24 | #include "mlir/IR/Operation.h" |
| 25 | #include "mlir/IR/Value.h" |
| 26 | #include "mlir/IR/Visitors.h" |
| 27 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| 28 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 29 | #include "mlir/Support/LLVM.h" |
| 30 | #include "llvm/ADT/ArrayRef.h" |
| 31 | #include "llvm/ADT/STLExtras.h" |
| 32 | #include "llvm/ADT/SmallVector.h" |
| 33 | #include "llvm/ADT/TypeSwitch.h" |
| 34 | #include "llvm/Support/Casting.h" |
| 35 | #include "llvm/Support/Debug.h" |
| 36 | #include "llvm/Support/InterleavedRange.h" |
| 37 | #include "llvm/Support/LogicalResult.h" |
| 38 | #include "llvm/Support/raw_ostream.h" |
| 39 | |
| 40 | namespace mlir { |
| 41 | namespace xegpu { |
| 42 | #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT |
| 43 | #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" |
| 44 | } // namespace xegpu |
| 45 | } // namespace mlir |
| 46 | |
| 47 | #define DEBUG_TYPE "xegpu-propagate-layout" |
| 48 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 49 | |
| 50 | using namespace mlir; |
| 51 | using namespace mlir::dataflow; |
| 52 | |
| 53 | namespace { |
| 54 | |
| 55 | //===----------------------------------------------------------------------===// |
| 56 | // Layout |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | |
/// Helper class to store the ND layout of lanes within a subgroup and data
/// owned by each lane.
struct Layout {
  // Per-dimension sizes; inline capacity of 3 covers the 1D/2D (and
  // potential 3D) layouts used by this pass without heap allocation.
  SmallVector<int64_t, 3> layout;
  Layout() = default;
  Layout(std::initializer_list<int64_t> list) : layout(list) {}
  // Prints the layout as a bracketed list (see Layout::print below).
  void print(llvm::raw_ostream &os) const;
  // Number of dimensions in this layout (0 means "empty/unassigned").
  size_t size() const { return layout.size(); }
};
| 68 | |
| 69 | void Layout::print(llvm::raw_ostream &os) const { |
| 70 | os << llvm::interleaved_array(R: layout); |
| 71 | } |
| 72 | |
| 73 | /// LaneLayout represents the logical layout of lanes within a subgroup when it |
| 74 | /// accesses some value. LaneData represents the logical layout of data owned by |
| 75 | /// each work item. |
| 76 | using LaneLayout = Layout; |
| 77 | using LaneData = Layout; |
| 78 | |
| 79 | //===----------------------------------------------------------------------===// |
| 80 | // LayoutInfo |
| 81 | //===----------------------------------------------------------------------===// |
| 82 | |
/// Helper class for tracking the analysis state of an mlir value. For layout
/// propagation, the analysis state is simply the lane_layout and lane_data of
/// each value. Purpose of this analysis to propagate some unique layout for
/// each value in the program starting from a set of anchor operations (like
/// DPAS, StoreNd, etc.).
///
/// Given this, LayoutInfo satisfies the following properties:
///  1) A LayoutInfo value can be in one of two states - `assigned` or `not
///  assigned`.
///  2) Two LayoutInfo values are equal if they are both assigned or
///  both not assigned. The concrete value of assigned state does not matter.
///  3) The meet operator works as follows:
///     - If current state is assigned, return the current state. (already
///     a unique layout is assigned. don't change it)
///     - Otherwise, return the other state.

struct LayoutInfo {
private:
  // Logical arrangement of lanes over the value's dimensions.
  LaneLayout laneLayout;
  // Number of contiguous elements each lane owns per dimension.
  LaneData laneData;
  // NOTE(review): this member is never read or written anywhere in this
  // file's visible code — presumably reserved for later use; confirm before
  // relying on it.
  xegpu::LayoutAttr layoutAttr;

public:
  LayoutInfo() = default;
  LayoutInfo(const LaneLayout &layout, const LaneData &data)
      : laneLayout(layout), laneData(data) {}

  // Two lattice values are equal if they have `some` layout. The actual
  // content of the layout does not matter.
  bool operator==(const LayoutInfo &other) const {
    return this->isAssigned() == other.isAssigned();
  }

  // Lattice meet: keeps the first assigned layout (see property 3 above).
  static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);

  // Unused in this backward analysis; calling it is a programming error.
  static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);

  void print(raw_ostream &os) const;

  // A layout is "assigned" once both components are non-empty.
  bool isAssigned() const {
    return laneLayout.size() > 0 && laneData.size() > 0;
  }

  // Returns this layout with both components permuted by `permutation`.
  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;

  const LaneLayout &getLayout() const { return laneLayout; }
  const LaneData &getData() const { return laneData; }
  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
};
| 133 | |
| 134 | void LayoutInfo::print(raw_ostream &os) const { |
| 135 | if (isAssigned()) { |
| 136 | os << "lane_layout: " ; |
| 137 | laneLayout.print(os); |
| 138 | os << ", lane_data: " ; |
| 139 | laneData.print(os); |
| 140 | } else { |
| 141 | os << "Not assigned." ; |
| 142 | } |
| 143 | } |
| 144 | |
| 145 | LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) { |
| 146 | if (!lhs.isAssigned()) |
| 147 | return rhs; |
| 148 | return lhs; |
| 149 | } |
| 150 | |
| 151 | /// Since this is a backward analysis, join method is not used. |
| 152 | LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) { |
| 153 | llvm_unreachable("Join should not be triggered by layout propagation." ); |
| 154 | } |
| 155 | |
| 156 | /// Get the transposed layout according to the given permutation. |
| 157 | LayoutInfo |
| 158 | LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const { |
| 159 | if (!isAssigned()) |
| 160 | return {}; |
| 161 | LaneLayout newLayout; |
| 162 | LaneData newData; |
| 163 | for (int64_t idx : permutation) { |
| 164 | newLayout.layout.push_back(Elt: laneLayout.layout[idx]); |
| 165 | newData.layout.push_back(Elt: laneData.layout[idx]); |
| 166 | } |
| 167 | return LayoutInfo(newLayout, newData); |
| 168 | } |
| 169 | |
| 170 | //===----------------------------------------------------------------------===// |
| 171 | // LayoutInfoLattice |
| 172 | //===----------------------------------------------------------------------===// |
| 173 | |
| 174 | /// Lattice holding the LayoutInfo for each value. |
| 175 | struct LayoutInfoLattice : public Lattice<LayoutInfo> { |
| 176 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice) |
| 177 | using Lattice::Lattice; |
| 178 | }; |
| 179 | |
| 180 | /// Helper Functions to get default layouts. A `default layout` is a layout that |
| 181 | /// is assigned to a value when the layout is not fixed by some anchor operation |
| 182 | /// (like DPAS). |
| 183 | |
| 184 | /// Helper Function to get the default layout for uniform values like constants. |
| 185 | /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1]. |
| 186 | /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1]. |
| 187 | static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) { |
| 188 | assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector." ); |
| 189 | if (rank == 1) |
| 190 | return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}), |
| 191 | LaneData({1})); |
| 192 | return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}), |
| 193 | LaneData({1, 1})); |
| 194 | } |
| 195 | |
| 196 | /// Helper to get the default layout for a vector type. |
| 197 | static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) { |
| 198 | // Expecting a 1D or 2D vector. |
| 199 | assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) && |
| 200 | "Expected 1D or 2D vector." ); |
| 201 | // Expecting int or float element type. |
| 202 | assert(vectorTy.getElementType().isIntOrFloat() && |
| 203 | "Expected int or float element type." ); |
| 204 | // If the rank is 1, then return default layout for 1D vector. |
| 205 | if (vectorTy.getRank() == 1) |
| 206 | return getDefaultSIMTLayoutInfo(rank: 1); |
| 207 | // Packing factor is determined by the element type bitwidth. |
| 208 | int packingFactor = 1; |
| 209 | unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); |
| 210 | if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) |
| 211 | packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth; |
| 212 | return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}), |
| 213 | LaneData({1, packingFactor})); |
| 214 | } |
| 215 | |
| 216 | /// Helper to get the default layout for a vector type. |
| 217 | static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) { |
| 218 | // Expecting a 1D or 2D vector. |
| 219 | assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) && |
| 220 | "Expected 1D or 2D TensorDesc." ); |
| 221 | // Expecting int or float element type. |
| 222 | assert(tdescTy.getElementType().isIntOrFloat() && |
| 223 | "Expected int or float element type." ); |
| 224 | // If the rank is 1, then return default layout for 1D vector. |
| 225 | if (tdescTy.getRank() == 1) |
| 226 | return getDefaultSIMTLayoutInfo(rank: 1); |
| 227 | // Packing factor is determined by the element type bitwidth. |
| 228 | unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); |
| 229 | |
| 230 | if (tdescTy.isScattered()) { |
| 231 | int packingFactor = |
| 232 | bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter |
| 233 | ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth |
| 234 | : 1; |
| 235 | return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}), |
| 236 | LaneData({1, packingFactor})); |
| 237 | } |
| 238 | |
| 239 | int packingFactor = |
| 240 | (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) |
| 241 | ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth |
| 242 | : 1; |
| 243 | return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}), |
| 244 | LaneData({1, packingFactor})); |
| 245 | } |
| 246 | |
| 247 | /// Helper Function to get the expected layouts for DPAS operands. `lane_data` |
| 248 | /// is set according to the following criteria: |
| 249 | /// * For A operand, the data must be packed in minimum |
| 250 | /// `packedSizeInBitsForDefault` |
| 251 | /// * For B operand, the data must be packed in minimum |
| 252 | /// `packedSizeInBitsForDpasB` |
| 253 | static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, |
| 254 | unsigned operandNum) { |
| 255 | Type elementTy = vectorTy.getElementType(); |
| 256 | assert(elementTy.isIntOrFloat() && |
| 257 | "Expected int or float type in DPAS operands" ); |
| 258 | LaneLayout layout({1, xegpu::targetinfo::subgroupSize}); |
| 259 | // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and |
| 260 | // must have the VNNI format. |
| 261 | if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < |
| 262 | xegpu::targetinfo::packedSizeInBitsForDpasB) { |
| 263 | LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB / |
| 264 | elementTy.getIntOrFloatBitWidth(), |
| 265 | 1}); |
| 266 | return LayoutInfo(layout, data); |
| 267 | } |
| 268 | // Otherwise, return the default layout for the vector type. |
| 269 | return getDefaultSIMTLayoutInfo(vectorTy); |
| 270 | } |
| 271 | |
| 272 | //===----------------------------------------------------------------------===// |
| 273 | // LayoutInfoPropagation |
| 274 | //===----------------------------------------------------------------------===// |
| 275 | |
| 276 | /// Backward data flow analysis to propagate the lane_layout and lane_data of |
| 277 | /// each value in the program. Currently, the layouts for operands DPAS, |
| 278 | /// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of |
| 279 | /// this analysis is to propagate those known layouts to all their producers and |
| 280 | /// (other) consumers. |
| 281 | class LayoutInfoPropagation |
| 282 | : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> { |
| 283 | private: |
| 284 | void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, |
| 285 | ArrayRef<const LayoutInfoLattice *> results); |
| 286 | |
| 287 | void visitStoreNdOp(xegpu::StoreNdOp store, |
| 288 | ArrayRef<LayoutInfoLattice *> operands, |
| 289 | ArrayRef<const LayoutInfoLattice *> results); |
| 290 | |
| 291 | void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter, |
| 292 | ArrayRef<LayoutInfoLattice *> operands, |
| 293 | ArrayRef<const LayoutInfoLattice *> results); |
| 294 | |
| 295 | void visitLoadNdOp(xegpu::LoadNdOp load, |
| 296 | ArrayRef<LayoutInfoLattice *> operands, |
| 297 | ArrayRef<const LayoutInfoLattice *> results); |
| 298 | |
| 299 | void visitLoadGatherOp(xegpu::LoadGatherOp load, |
| 300 | ArrayRef<LayoutInfoLattice *> operands, |
| 301 | ArrayRef<const LayoutInfoLattice *> results); |
| 302 | |
| 303 | void visitTransposeOp(vector::TransposeOp transpose, |
| 304 | ArrayRef<LayoutInfoLattice *> operands, |
| 305 | ArrayRef<const LayoutInfoLattice *> results); |
| 306 | |
| 307 | void visitVectorBitcastOp(vector::BitCastOp bitcast, |
| 308 | ArrayRef<LayoutInfoLattice *> operands, |
| 309 | ArrayRef<const LayoutInfoLattice *> results); |
| 310 | |
| 311 | void visitCreateDescOp(xegpu::CreateDescOp createDesc, |
| 312 | ArrayRef<LayoutInfoLattice *> operands, |
| 313 | ArrayRef<const LayoutInfoLattice *> results); |
| 314 | |
| 315 | void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset, |
| 316 | ArrayRef<LayoutInfoLattice *> operands, |
| 317 | ArrayRef<const LayoutInfoLattice *> results); |
| 318 | |
| 319 | void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch, |
| 320 | ArrayRef<LayoutInfoLattice *> operands, |
| 321 | ArrayRef<const LayoutInfoLattice *> results); |
| 322 | |
| 323 | void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction, |
| 324 | ArrayRef<LayoutInfoLattice *> operands, |
| 325 | ArrayRef<const LayoutInfoLattice *> results); |
| 326 | |
| 327 | public: |
| 328 | LayoutInfoPropagation(DataFlowSolver &solver, |
| 329 | SymbolTableCollection &symbolTable) |
| 330 | : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} |
| 331 | using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; |
| 332 | |
| 333 | LogicalResult |
| 334 | visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands, |
| 335 | ArrayRef<const LayoutInfoLattice *> results) override; |
| 336 | |
| 337 | void visitBranchOperand(OpOperand &operand) override {}; |
| 338 | |
| 339 | void visitCallOperand(OpOperand &operand) override {}; |
| 340 | |
| 341 | void visitExternalCall(CallOpInterface call, |
| 342 | ArrayRef<LayoutInfoLattice *> operands, |
| 343 | ArrayRef<const LayoutInfoLattice *> results) override { |
| 344 | }; |
| 345 | |
| 346 | void setToExitState(LayoutInfoLattice *lattice) override { |
| 347 | (void)lattice->meet(rhs: LayoutInfo()); |
| 348 | } |
| 349 | }; |
| 350 | } // namespace |
| 351 | |
| 352 | LogicalResult LayoutInfoPropagation::visitOperation( |
| 353 | Operation *op, ArrayRef<LayoutInfoLattice *> operands, |
| 354 | ArrayRef<const LayoutInfoLattice *> results) { |
| 355 | TypeSwitch<Operation *>(op) |
| 356 | .Case<xegpu::DpasOp>( |
| 357 | caseFn: [&](auto dpasOp) { visitDpasOp(dpas: dpasOp, operands, results); }) |
| 358 | .Case<xegpu::StoreNdOp>( |
| 359 | caseFn: [&](auto storeNdOp) { visitStoreNdOp(store: storeNdOp, operands, results); }) |
| 360 | .Case<xegpu::StoreScatterOp>(caseFn: [&](auto storeScatterOp) { |
| 361 | visitStoreScatterOp(storeScatter: storeScatterOp, operands, results); |
| 362 | }) |
| 363 | .Case<xegpu::LoadNdOp>( |
| 364 | caseFn: [&](auto loadNdOp) { visitLoadNdOp(load: loadNdOp, operands, results); }) |
| 365 | .Case<xegpu::LoadGatherOp>(caseFn: [&](auto loadGatherOp) { |
| 366 | visitLoadGatherOp(load: loadGatherOp, operands, results); |
| 367 | }) |
| 368 | .Case<xegpu::CreateDescOp>(caseFn: [&](auto createDescOp) { |
| 369 | visitCreateDescOp(createDesc: createDescOp, operands, results); |
| 370 | }) |
| 371 | .Case<xegpu::UpdateNdOffsetOp>(caseFn: [&](auto updateNdOffsetOp) { |
| 372 | visitUpdateNdOffsetOp(updateNdOffset: updateNdOffsetOp, operands, results); |
| 373 | }) |
| 374 | .Case<xegpu::PrefetchNdOp>(caseFn: [&](auto prefetchNdOp) { |
| 375 | visitPrefetchNdOp(prefetch: prefetchNdOp, operands, results); |
| 376 | }) |
| 377 | .Case<vector::TransposeOp>(caseFn: [&](auto transposeOp) { |
| 378 | visitTransposeOp(transpose: transposeOp, operands, results); |
| 379 | }) |
| 380 | .Case<vector::BitCastOp>(caseFn: [&](auto bitcastOp) { |
| 381 | visitVectorBitcastOp(bitcast: bitcastOp, operands, results); |
| 382 | }) |
| 383 | .Case<vector::MultiDimReductionOp>(caseFn: [&](auto reductionOp) { |
| 384 | visitVectorMultiReductionOp(reduction: reductionOp, operands, results); |
| 385 | }) |
| 386 | // All other ops. |
| 387 | .Default(defaultFn: [&](Operation *op) { |
| 388 | for (const LayoutInfoLattice *resultInfo : results) { |
| 389 | if (!resultInfo->getValue().isAssigned()) |
| 390 | continue; |
| 391 | for (auto [operandInfo, operand] : |
| 392 | llvm::zip(t&: operands, u: op->getOpOperands())) { |
| 393 | // If the operand type is not a vector or tensor descriptor, skip |
| 394 | // it. |
| 395 | if (!isa<xegpu::TensorDescType, VectorType>( |
| 396 | Val: operand.get().getType())) |
| 397 | continue; |
| 398 | // Propagate the result layout to the operand. |
| 399 | meet(lhs: operandInfo, rhs: *resultInfo); |
| 400 | } |
| 401 | } |
| 402 | }); |
| 403 | |
| 404 | return success(); |
| 405 | } |
| 406 | |
| 407 | void LayoutInfoPropagation::visitPrefetchNdOp( |
| 408 | xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands, |
| 409 | ArrayRef<const LayoutInfoLattice *> results) { |
| 410 | // Here we assign the default layout to the tensor descriptor operand of |
| 411 | // prefetch. |
| 412 | auto tdescTy = prefetch.getTensorDescType(); |
| 413 | auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy); |
| 414 | // Propagate the layout to the source tensor descriptor. |
| 415 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: prefetchLayout)); |
| 416 | } |
| 417 | |
| 418 | void LayoutInfoPropagation::visitVectorMultiReductionOp( |
| 419 | vector::MultiDimReductionOp reduction, |
| 420 | ArrayRef<LayoutInfoLattice *> operands, |
| 421 | ArrayRef<const LayoutInfoLattice *> results) { |
| 422 | // The layout of the result must be present. |
| 423 | LayoutInfo resultLayout = results[0]->getValue(); |
| 424 | if (!resultLayout.isAssigned()) |
| 425 | return; |
| 426 | // We only consider 2D -> 1D reductions at this point. |
| 427 | VectorType resultTy = llvm::dyn_cast<VectorType>(Val: reduction.getDestType()); |
| 428 | if (!resultTy || resultTy.getRank() != 1) { |
| 429 | reduction.emitWarning(message: "Expecting output type to be 1D vector." ); |
| 430 | return; |
| 431 | } |
| 432 | // Given that the result is 1D, the layout of the operand should be 2D with |
| 433 | // default layout. |
| 434 | LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(rank: 2); |
| 435 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: operandLayout)); |
| 436 | // Accumulator should have the same layout as the result. |
| 437 | propagateIfChanged(state: operands[1], changed: operands[1]->meet(rhs: resultLayout)); |
| 438 | } |
| 439 | |
| 440 | /// Propagate the layout of the result tensor to the source tensor descriptor in |
| 441 | /// UpdateNdOffsetOp. |
| 442 | void LayoutInfoPropagation::visitUpdateNdOffsetOp( |
| 443 | xegpu::UpdateNdOffsetOp updateNdOffset, |
| 444 | ArrayRef<LayoutInfoLattice *> operands, |
| 445 | ArrayRef<const LayoutInfoLattice *> results) { |
| 446 | // The layout of the result must be present. |
| 447 | LayoutInfo resultLayout = results[0]->getValue(); |
| 448 | if (!resultLayout.isAssigned()) |
| 449 | return; |
| 450 | // Propagate the layout to the source operand. |
| 451 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: resultLayout)); |
| 452 | } |
| 453 | |
| 454 | /// Set the layouts for DPAS A, B, and C operands. |
| 455 | void LayoutInfoPropagation::visitDpasOp( |
| 456 | xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, |
| 457 | ArrayRef<const LayoutInfoLattice *> results) { |
| 458 | VectorType aTy = dpas.getLhsType(); |
| 459 | VectorType bTy = dpas.getRhsType(); |
| 460 | propagateIfChanged( |
| 461 | state: operands[0], changed: operands[0]->meet(rhs: getSIMTLayoutInfoForDPASOperand(vectorTy: aTy, operandNum: 0))); |
| 462 | propagateIfChanged( |
| 463 | state: operands[1], changed: operands[1]->meet(rhs: getSIMTLayoutInfoForDPASOperand(vectorTy: bTy, operandNum: 1))); |
| 464 | if (operands.size() > 2) { |
| 465 | VectorType cTy = dpas.getAccType(); |
| 466 | propagateIfChanged( |
| 467 | state: operands[2], |
| 468 | changed: operands[2]->meet(rhs: getSIMTLayoutInfoForDPASOperand(vectorTy: cTy, operandNum: 2))); |
| 469 | } |
| 470 | } |
| 471 | |
| 472 | /// Set the layout for the value and tensor descriptor operands in StoreNdOp. |
| 473 | void LayoutInfoPropagation::visitStoreNdOp( |
| 474 | xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands, |
| 475 | ArrayRef<const LayoutInfoLattice *> results) { |
| 476 | LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(vectorTy: store.getValueType()); |
| 477 | // Both operands should have the same layout |
| 478 | for (LayoutInfoLattice *operand : operands) |
| 479 | propagateIfChanged(state: operand, changed: operand->meet(rhs: storeLayout)); |
| 480 | } |
| 481 | |
| 482 | /// Propagate the layout of the value to the tensor descriptor operand in |
| 483 | /// LoadNdOp. |
| 484 | void LayoutInfoPropagation::visitLoadNdOp( |
| 485 | xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands, |
| 486 | ArrayRef<const LayoutInfoLattice *> results) { |
| 487 | LayoutInfo valueLayout = results[0]->getValue(); |
| 488 | // Need the layout of the value to propagate to the tensor descriptor. |
| 489 | if (!valueLayout.isAssigned()) |
| 490 | return; |
| 491 | LayoutInfo tensorDescLayout = valueLayout; |
| 492 | // LoadNdOp has the transpose effect. However, at the stage of this analysis |
| 493 | // this effect is not expected and should be abstracted away. Emit a |
| 494 | // warning. |
| 495 | if (auto transpose = load.getTranspose()) { |
| 496 | load.emitWarning(message: "Transpose effect is not expected for LoadNdOp at " |
| 497 | "LayoutInfoPropagation stage." ); |
| 498 | tensorDescLayout = valueLayout.getTransposedLayout(permutation: transpose.value()); |
| 499 | } |
| 500 | // Propagate the new layout to the tensor descriptor operand. |
| 501 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: tensorDescLayout)); |
| 502 | } |
| 503 | |
| 504 | /// For vector::TransposeOp, the layout of the result is transposed and |
| 505 | /// propagated to the operand. |
| 506 | void LayoutInfoPropagation::visitTransposeOp( |
| 507 | vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands, |
| 508 | ArrayRef<const LayoutInfoLattice *> results) { |
| 509 | // Need the layout of transpose result to propagate to the operands. |
| 510 | LayoutInfo resultLayout = results[0]->getValue(); |
| 511 | if (!resultLayout.isAssigned()) |
| 512 | return; |
| 513 | LayoutInfo newLayout = |
| 514 | resultLayout.getTransposedLayout(permutation: transpose.getPermutation()); |
| 515 | // Propagate the new layout to the vector operand. |
| 516 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: newLayout)); |
| 517 | } |
| 518 | |
| 519 | /// For vector::BitCastOp, the lane_data of the source layout is changed based |
| 520 | /// on the bit width of the source and result types. |
| 521 | void LayoutInfoPropagation::visitVectorBitcastOp( |
| 522 | vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands, |
| 523 | ArrayRef<const LayoutInfoLattice *> results) { |
| 524 | // Need the layout of bitcast result to propagate to the operands. |
| 525 | LayoutInfo resultLayout = results[0]->getValue(); |
| 526 | if (!resultLayout.isAssigned()) |
| 527 | return; |
| 528 | int inElemTyBitWidth = |
| 529 | bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth(); |
| 530 | int outElemTyBitWidth = |
| 531 | bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth(); |
| 532 | |
| 533 | // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit |
| 534 | // a warning and return. |
| 535 | if (inElemTyBitWidth != outElemTyBitWidth) { |
| 536 | bitcast.emitWarning(message: "Widening or narrowing bitcasts are not expected at " |
| 537 | "layout propagation stage." ); |
| 538 | return; |
| 539 | } |
| 540 | |
| 541 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: resultLayout)); |
| 542 | } |
| 543 | |
| 544 | /// Propagate the layout of the result to the tensor descriptor and mask |
| 545 | /// operands in LoadGatherOp. |
| 546 | void LayoutInfoPropagation::visitLoadGatherOp( |
| 547 | xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands, |
| 548 | ArrayRef<const LayoutInfoLattice *> results) { |
| 549 | // The layout is strictly determined by the tensor descriptor type. |
| 550 | LayoutInfo layout = getDefaultSIMTLayoutInfo(tdescTy: load.getTensorDescType()); |
| 551 | |
| 552 | // Mask operand should have 1D default layout. |
| 553 | LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(rank: 1); |
| 554 | |
| 555 | // Propagate the new layout to the tensor descriptor operand. |
| 556 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: layout)); |
| 557 | // Propagate the new layout to the mask operand. |
| 558 | propagateIfChanged(state: operands[1], changed: operands[1]->meet(rhs: maskLayout)); |
| 559 | } |
| 560 | |
| 561 | /// Propagate the layout of the descriptor to the vector offset operand in |
| 562 | /// CreateDescOp. |
| 563 | void LayoutInfoPropagation::visitCreateDescOp( |
| 564 | xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands, |
| 565 | ArrayRef<const LayoutInfoLattice *> results) { |
| 566 | LayoutInfo descLayout = results[0]->getValue(); |
| 567 | // Need the layout of the descriptor to propagate to the operands. |
| 568 | if (!descLayout.isAssigned()) |
| 569 | return; |
| 570 | // For offset operand propagate 1D default layout. |
| 571 | LayoutInfo layout = getDefaultSIMTLayoutInfo(rank: 1); |
| 572 | propagateIfChanged(state: operands[1], changed: operands[1]->meet(rhs: layout)); |
| 573 | } |
| 574 | |
| 575 | /// Set the layout for the value, tensor descriptor, and mask operands in the |
| 576 | /// StoreScatterOp. |
| 577 | void LayoutInfoPropagation::visitStoreScatterOp( |
| 578 | xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands, |
| 579 | ArrayRef<const LayoutInfoLattice *> results) { |
| 580 | // Currently, for 2D StoreScatterOp we expect that the height dimension of |
| 581 | // the tensor descriptor is equal to the subgroup size. This is ensured by |
| 582 | // the op verifier. |
| 583 | ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape(); |
| 584 | if (tdescShape.size() > 1) |
| 585 | assert( |
| 586 | tdescShape[0] == xegpu::targetinfo::subgroupSize && |
| 587 | "Expected the first dimension of 2D tensor descriptor to be equal to " |
| 588 | "subgroup size." ); |
| 589 | |
| 590 | LayoutInfo layout = |
| 591 | getDefaultSIMTLayoutInfo(tdescTy: storeScatter.getTensorDescType()); |
| 592 | |
| 593 | // Propagate the value layout. |
| 594 | propagateIfChanged(state: operands[0], changed: operands[0]->meet(rhs: layout)); |
| 595 | // Propagate the tensor descriptor layout. |
| 596 | propagateIfChanged(state: operands[1], changed: operands[1]->meet(rhs: layout)); |
| 597 | // Use default 1D layout for mask operand. |
| 598 | LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(rank: 1); |
| 599 | propagateIfChanged(state: operands[2], changed: operands[2]->meet(rhs: maskLayout)); |
| 600 | } |
| 601 | |
| 602 | namespace { |
| 603 | //===----------------------------------------------------------------------===// |
| 604 | // RunLayoutInfoPropagation |
| 605 | //===----------------------------------------------------------------------===// |
| 606 | |
| 607 | /// Driver class for running the LayoutInfoPropagation analysis. |
| 608 | class RunLayoutInfoPropagation { |
| 609 | public: |
| 610 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation) |
| 611 | |
| 612 | RunLayoutInfoPropagation(Operation *op) : target(op) { |
| 613 | SymbolTableCollection symbolTable; |
| 614 | loadBaselineAnalyses(solver); |
| 615 | solver.load<LayoutInfoPropagation>(args&: symbolTable); |
| 616 | (void)solver.initializeAndRun(top: op); |
| 617 | } |
| 618 | |
| 619 | LayoutInfo getLayoutInfo(Value val); |
| 620 | |
| 621 | void printAnalysisResult(llvm::raw_ostream &os); |
| 622 | |
| 623 | private: |
| 624 | DataFlowSolver solver; |
| 625 | const Operation *target; |
| 626 | }; |
| 627 | } // namespace |
| 628 | |
| 629 | LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) { |
| 630 | auto *state = solver.lookupState<LayoutInfoLattice>(anchor: val); |
| 631 | if (!state) |
| 632 | return {}; |
| 633 | return state->getValue(); |
| 634 | } |
| 635 | |
| 636 | // Print the analysis result for debugging purposes. |
| 637 | void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) { |
| 638 | auto printFunctionResult = [&](FunctionOpInterface funcOp) { |
| 639 | os << "function: " << funcOp.getName() << ":\n" ; |
| 640 | // Function arguments |
| 641 | for (BlockArgument arg : funcOp.getArguments()) { |
| 642 | LayoutInfo layout = getLayoutInfo(val: arg); |
| 643 | os << "argument: " << arg << "\n" ; |
| 644 | os << "layout : " ; |
| 645 | layout.print(os); |
| 646 | os << "\n" ; |
| 647 | } |
| 648 | // Function ops |
| 649 | funcOp.walk(callback: [&](Operation *op) { |
| 650 | // Skip ops that do not have results |
| 651 | if (op->getResults().empty()) |
| 652 | return; |
| 653 | os << "op : " ; |
| 654 | // For control-flow ops, print the op name only. |
| 655 | if (isa<BranchOpInterface>(Val: op) || isa<RegionBranchOpInterface>(Val: op)) |
| 656 | os << op->getName(); |
| 657 | else |
| 658 | op->print(os); |
| 659 | os << "\n" ; |
| 660 | // Print the layout for each result. |
| 661 | for (auto [i, r] : llvm::enumerate(First: op->getResults())) { |
| 662 | LayoutInfo layout = getLayoutInfo(val: r); |
| 663 | os << "layout for result #" << i << ": " ; |
| 664 | layout.print(os); |
| 665 | os << "\n" ; |
| 666 | } |
| 667 | }); |
| 668 | }; |
| 669 | |
| 670 | SmallVector<FunctionOpInterface> funcOps; |
| 671 | if (auto modOp = dyn_cast<ModuleOp>(Val: target)) { |
| 672 | for (auto funcOp : modOp.getOps<FunctionOpInterface>()) |
| 673 | funcOps.push_back(Elt: funcOp); |
| 674 | |
| 675 | // Collect all GpuFuncOps in the module. |
| 676 | for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) { |
| 677 | for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) |
| 678 | funcOps.push_back(Elt: gpuFuncOp); |
| 679 | } |
| 680 | } |
| 681 | // Print the analysis result for each function. |
| 682 | for (FunctionOpInterface funcOp : funcOps) |
| 683 | printFunctionResult(funcOp); |
| 684 | } |
| 685 | |
| 686 | using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>; |
| 687 | /// Update an operation with the layout of its results. If the result type is a |
| 688 | /// vector type, a temporary layout attribute is added to the operation. If the |
| 689 | /// result type is a tensor descriptor type, the type is updated with the layout |
| 690 | /// attribute. The users of the result are also updated with the layout |
| 691 | /// attribute. |
| 692 | static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, |
| 693 | GetLayoutFnTy getLayoutOfValue) { |
| 694 | // Region ops (like scf.for) are already handled by the updateControlFlowOps. |
| 695 | if (mlir::isa<mlir::RegionBranchOpInterface>(Val: op)) |
| 696 | return success(); |
| 697 | |
| 698 | // Iterate over all the results. |
| 699 | for (OpResult result : op->getResults()) { |
| 700 | Type resultType = result.getType(); |
| 701 | // Layouts are needed only for vector and tensor descriptor types. |
| 702 | if (!isa<VectorType, xegpu::TensorDescType>(Val: resultType)) |
| 703 | continue; |
| 704 | // If the result has no layout but has users, emit a warning and continue. |
| 705 | xegpu::LayoutAttr layout = getLayoutOfValue(result); |
| 706 | if (!layout && result.getNumUses() > 0) { |
| 707 | op->emitWarning(message: "op has users but no layout assigned for its result" ); |
| 708 | continue; |
| 709 | } |
| 710 | // If the result is a tensor descriptor type, update the tensor desc type |
| 711 | // with layout. |
| 712 | if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(Val&: resultType)) { |
| 713 | auto typeWithLayout = xegpu::TensorDescType::get( |
| 714 | context: tensorDescTy.getContext(), shape: tensorDescTy.getShape(), |
| 715 | elementType: tensorDescTy.getElementType(), encoding: tensorDescTy.getEncoding(), layout); |
| 716 | result.setType(typeWithLayout); |
| 717 | continue; |
| 718 | } |
| 719 | // If the result is a vector type, add a temporary layout attribute to the |
| 720 | // op. |
| 721 | xegpu::setLayoutAttr(operandOrResult: result, layout); |
| 722 | } |
| 723 | return success(); |
| 724 | } |
| 725 | |
| 726 | /// Region ops like scf.for need special handling because they have blocks |
/// inside. If the blocks have tensor descriptor type as block arguments, their
| 728 | /// types must be updated. Also region op can have results that may not have any |
| 729 | /// users (e.g. A and B tiles). They are not assigned a layout by layout |
| 730 | /// analysis because they have no users. However inside the region op |
| 731 | /// corresponding block arguments for these results do have layouts. Therefore, |
| 732 | /// in this case we still need to update the result types with the layout |
| 733 | /// attribute. This function function updates the internal block arguments and |
| 734 | /// the result types of the region op with the assigned layouts. |
| 735 | /// clang-format off |
| 736 | /// Example: scf.for ... iter_args(...) -> (out types) { |
| 737 | /// ^bb0(block types): |
| 738 | /// ... |
| 739 | /// scf.yield ... : (yield types) |
| 740 | /// } |
| 741 | /// clang-format on |
| 742 | /// In this example, at scf.yield, control-flow can transfer to two successor |
| 743 | /// regions. One is the ^bb0 (for loop body) and the other is the scf.for op |
| 744 | /// itself (yield the results). So we update both the block arguments of the |
| 745 | /// successor region (i.e. block types) and the result types of the scf.for op |
| 746 | /// (i.e. out types). Note that yield types are updated by respective producers |
| 747 | /// inside bb0. |
| 748 | static LogicalResult |
| 749 | updateControlFlowOps(mlir::OpBuilder &builder, |
| 750 | mlir::RegionBranchTerminatorOpInterface terminator, |
| 751 | GetLayoutFnTy getLayoutOfValue) { |
| 752 | // Only process if the terminator is inside a region branch op. |
| 753 | if (!mlir::isa<mlir::RegionBranchOpInterface>(Val: terminator->getParentOp())) |
| 754 | return success(); |
| 755 | |
| 756 | llvm::SmallVector<mlir::RegionSuccessor> successors; |
| 757 | llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(), |
| 758 | nullptr); |
| 759 | terminator.getSuccessorRegions(operands, regions&: successors); |
| 760 | |
| 761 | for (mlir::RegionSuccessor &successor : successors) { |
| 762 | mlir::OperandRange successorOperands = |
| 763 | terminator.getSuccessorOperands(point: successor); |
| 764 | mlir::ValueRange successorInputs = successor.getSuccessorInputs(); |
| 765 | for (auto [successorOperand, successorInput] : |
| 766 | llvm::zip(t&: successorOperands, u&: successorInputs)) { |
| 767 | Type inputType = successorInput.getType(); |
| 768 | // We only need to operate on tensor descriptor or vector types. |
| 769 | if (!isa<xegpu::TensorDescType, VectorType>(Val: inputType)) |
| 770 | continue; |
| 771 | xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput); |
| 772 | xegpu::LayoutAttr successorOperandLayout = |
| 773 | getLayoutOfValue(successorOperand); |
| 774 | |
| 775 | // If either of the layouts is not assigned, we cannot proceed. |
| 776 | if (!successorOperandLayout) { |
| 777 | LLVM_DEBUG( |
| 778 | DBGS() |
| 779 | << "No layout assigned for forwarded operand in branch terminator: " |
| 780 | << successorOperand << "\n" ); |
| 781 | return failure(); |
| 782 | } |
| 783 | // We expect the layouts to match. |
| 784 | if (successorInputLayout && |
| 785 | successorInputLayout != successorOperandLayout) { |
| 786 | LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and " |
| 787 | "operand forwarded as the argument: " |
| 788 | << successorInputLayout << " vs " |
| 789 | << successorOperandLayout << "\n" ); |
| 790 | return failure(); |
| 791 | } |
| 792 | // Get tensor descriptor type with the layout. |
| 793 | if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(Val&: inputType)) { |
| 794 | auto newTdescTy = xegpu::TensorDescType::get( |
| 795 | context: tdescTy.getContext(), shape: tdescTy.getShape(), elementType: tdescTy.getElementType(), |
| 796 | encoding: tdescTy.getEncoding(), layout: successorOperandLayout); |
| 797 | successorInput.setType(newTdescTy); |
| 798 | continue; |
| 799 | } |
| 800 | // If the type is a vector type and this region argument is an OpResult, |
| 801 | // set the layout attribute on the OpResult. |
| 802 | if (auto result = dyn_cast<OpResult>(Val&: successorInput)) |
| 803 | xegpu::setLayoutAttr(operandOrResult: result, layout: successorOperandLayout); |
| 804 | } |
| 805 | } |
| 806 | return success(); |
| 807 | } |
| 808 | |
| 809 | /// Update the function arguments and results with the layouts. |
| 810 | static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder, |
| 811 | mlir::FunctionOpInterface funcOp, |
| 812 | GetLayoutFnTy getLayoutOfValue) { |
| 813 | SmallVector<Type> newArgTypes; |
| 814 | // Update the function arguments. |
| 815 | for (BlockArgument arg : funcOp.getArguments()) { |
| 816 | Type argType = arg.getType(); |
| 817 | newArgTypes.push_back(Elt: argType); |
| 818 | if (!isa<VectorType, xegpu::TensorDescType>(Val: argType)) |
| 819 | continue; |
| 820 | xegpu::LayoutAttr layout = getLayoutOfValue(arg); |
| 821 | if (!layout) { |
| 822 | LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg |
| 823 | << " but got none.\n" ); |
| 824 | return failure(); |
| 825 | } |
| 826 | if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(Val&: argType)) { |
| 827 | auto newTdescTy = xegpu::TensorDescType::get( |
| 828 | context: tensorDescTy.getContext(), shape: tensorDescTy.getShape(), |
| 829 | elementType: tensorDescTy.getElementType(), encoding: tensorDescTy.getEncoding(), layout); |
| 830 | arg.setType(newTdescTy); |
| 831 | newArgTypes.back() = newTdescTy; |
| 832 | } |
| 833 | } |
| 834 | // Update the function type with the new argument types. |
| 835 | // NOTE: We assume that function results are not expected to have layouts. |
| 836 | funcOp.setType(FunctionType::get(context: funcOp.getContext(), inputs: newArgTypes, |
| 837 | results: funcOp.getResultTypes())); |
| 838 | return success(); |
| 839 | } |
| 840 | |
namespace {
/// Pass wrapper: runs the layout-propagation analysis and rewrites the IR
/// with the resulting layout attributes (see runOnOperation below).
struct XeGPUPropagateLayoutPass final
    : public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
  XeGPUPropagateLayoutPass() = default;
  XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
  // Constructor taking tablegen-generated pass options (e.g. print-only).
  XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
      : XeGPUPropagateLayoutBase(options) {}
  void runOnOperation() override;
};

} // namespace
| 852 | |
| 853 | void XeGPUPropagateLayoutPass::runOnOperation() { |
| 854 | auto &analysis = getAnalysis<RunLayoutInfoPropagation>(); |
| 855 | // Print the analysis result and exit. (for debugging purposes) |
| 856 | if (printOnly) { |
| 857 | auto &os = llvm::outs(); |
| 858 | analysis.printAnalysisResult(os); |
| 859 | return; |
| 860 | } |
| 861 | // Helper to convert LayoutInfo to xegpu::LayoutAttr. |
| 862 | auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr { |
| 863 | LayoutInfo layout = analysis.getLayoutInfo(val); |
| 864 | if (!layout.isAssigned()) |
| 865 | return {}; |
| 866 | return xegpu::LayoutAttr::get( |
| 867 | context: val.getContext(), lane_layout: llvm::to_vector_of<int>(Range: layout.getLayoutAsArrayRef()), |
| 868 | lane_data: llvm::to_vector_of<int>(Range: layout.getDataAsArrayRef())); |
| 869 | }; |
| 870 | |
| 871 | mlir::OpBuilder builder(&getContext()); |
| 872 | Operation *op = getOperation(); |
| 873 | auto walkResult = op->walk(callback: [&](mlir::Block *block) -> WalkResult { |
| 874 | for (mlir::Operation &op : llvm::reverse(C&: block->getOperations())) { |
| 875 | LogicalResult r = success(); |
| 876 | TypeSwitch<Operation *>(&op) |
| 877 | .Case<mlir::RegionBranchTerminatorOpInterface>( |
| 878 | caseFn: [&](mlir::RegionBranchTerminatorOpInterface branchTermOp) { |
| 879 | r = updateControlFlowOps(builder, terminator: branchTermOp, |
| 880 | getLayoutOfValue: getXeGPULayoutForValue); |
| 881 | }) |
| 882 | .Case<mlir::FunctionOpInterface>( |
| 883 | caseFn: [&](mlir::FunctionOpInterface funcOp) { |
| 884 | r = updateFunctionOpInterface(builder, funcOp, |
| 885 | getLayoutOfValue: getXeGPULayoutForValue); |
| 886 | }) |
| 887 | .Default(defaultFn: [&](Operation *op) { |
| 888 | r = updateOp(builder, op, getLayoutOfValue: getXeGPULayoutForValue); |
| 889 | }); |
| 890 | if (failed(Result: r)) { |
| 891 | op.emitError(message: "Failed to update operation with the layout." ); |
| 892 | return WalkResult::interrupt(); |
| 893 | } |
| 894 | } |
| 895 | return WalkResult::advance(); |
| 896 | }); |
| 897 | if (walkResult.wasInterrupted()) { |
| 898 | signalPassFailure(); |
| 899 | return; |
| 900 | } |
| 901 | } |
| 902 | |