| 1 | //===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===// |
| 2 | // |
| 3 | // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements utility methods for working with the XeGPU dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
| 14 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 15 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 16 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
| 17 | #include "mlir/IR/Builders.h" |
| 18 | #include "mlir/IR/Operation.h" |
| 19 | #include "mlir/IR/ValueRange.h" |
| 20 | #include "mlir/Interfaces/LoopLikeInterface.h" |
| 21 | #include "mlir/Transforms/DialectConversion.h" |
| 22 | #include "llvm/Support/Debug.h" |
| 23 | #include "llvm/Support/FormatVariadic.h" |
| 24 | #include <cstdint> |
| 25 | #include <numeric> |
| 26 | |
| 27 | using namespace mlir; |
| 28 | |
| 29 | /// convert ArrayRef<ValueRange> into SmallVector<Value> |
| 30 | static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { |
| 31 | SmallVector<Value> result; |
| 32 | for (const auto &vals : values) |
| 33 | llvm::append_range(C&: result, R: vals); |
| 34 | return result; |
| 35 | } |
| 36 | |
| 37 | FailureOr<VectorType> |
| 38 | mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) { |
| 39 | auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout()); |
| 40 | // It only works for subgroup level layout, which only has lane_layout |
| 41 | // and lane_data, and is to distribute a SIMD code into SIMT code. |
| 42 | if (!layout || !layout.isSgLayout()) |
| 43 | return failure(); |
| 44 | |
| 45 | SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef()); |
| 46 | SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef()); |
| 47 | auto tdescShape = tdescTy.getShape(); |
| 48 | auto elementType = tdescTy.getElementType(); |
| 49 | |
| 50 | // compute sgSize by multiply elements of laneLayout |
| 51 | // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1] |
| 52 | // e.g. for 1D layout, sgSize = laneLayout[0] |
| 53 | auto sgSize = std::accumulate(first: laneLayout.begin(), last: laneLayout.end(), init: 1, |
| 54 | binary_op: std::multiplies<int64_t>()); |
| 55 | |
| 56 | // Case 1: regular loads/stores |
| 57 | auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr(); |
| 58 | if (scatterAttr) { |
| 59 | auto chunkSize = scatterAttr.getChunkSize().getInt(); |
| 60 | // Verify if the first dimension of the tensor descriptor shape is |
| 61 | // distributable. |
| 62 | assert(tdescShape[0] == laneLayout[0] && |
| 63 | "tensor descriptor shape is not distributable" ); |
| 64 | return VectorType::get({chunkSize}, elementType); |
| 65 | } |
| 66 | |
| 67 | // Case 2: block loads/stores |
| 68 | // Check if the tensor descriptor shape is distributable. |
| 69 | int64_t tensorSize = 1; |
| 70 | for (auto [tdescDim, laneDim, laneDataDim] : |
| 71 | llvm::zip_equal(tdescShape, laneLayout, laneData)) { |
| 72 | assert((tdescDim % (laneDim * laneDataDim) == 0) && |
| 73 | "tensor descriptor shape is not distributable" ); |
| 74 | tensorSize *= tdescDim; |
| 75 | } |
| 76 | // tensorSize must be adjusted for array_length. |
| 77 | tensorSize *= tdescTy.getArrayLength(); |
| 78 | |
| 79 | return VectorType::get({tensorSize / sgSize}, elementType); |
| 80 | } |
| 81 | |
| 82 | FailureOr<VectorType> |
| 83 | mlir::xegpu::getDistributedVectorType(VectorType originalType, |
| 84 | xegpu::LayoutAttr layout) { |
| 85 | int64_t rank = originalType.getRank(); |
| 86 | // Distributed vector type is only supported for 1D, 2D and 3D vectors. |
| 87 | if (rank < 1 || rank > 3) |
| 88 | return failure(); |
| 89 | ArrayRef<int64_t> shape = originalType.getShape(); |
| 90 | // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension |
| 91 | // of the 3D vector. |
| 92 | int arrayLength = 1; |
| 93 | if (rank == 3) { |
| 94 | arrayLength = shape[0]; |
| 95 | shape = shape.drop_front(); |
| 96 | } |
| 97 | auto helperTdescTy = xegpu::TensorDescType::get( |
| 98 | shape, originalType.getElementType(), arrayLength, |
| 99 | /*boundary_check=*/true, |
| 100 | /*memory_space=*/xegpu::MemorySpace::Global, layout); |
| 101 | return xegpu::getDistributedVectorType(helperTdescTy); |
| 102 | } |
| 103 | |
| 104 | std::string xegpu::getLayoutName(const OpOperand &operand) { |
| 105 | const StringRef prefix("layout_operand_" ); |
| 106 | unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber(); |
| 107 | return llvm::formatv(Fmt: "{0}{1}" , Vals: prefix, Vals&: idx).str(); |
| 108 | } |
| 109 | |
| 110 | std::string xegpu::getLayoutName(const OpResult result) { |
| 111 | const StringRef prefix = "layout_result_" ; |
| 112 | return llvm::formatv(Fmt: "{0}{1}" , Vals: prefix, Vals: result.getResultNumber()).str(); |
| 113 | } |
| 114 | |
| 115 | xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { |
| 116 | if (!value) |
| 117 | return nullptr; |
| 118 | |
| 119 | if (auto tdescTy = |
| 120 | dyn_cast_if_present<xegpu::TensorDescType>(value.getType())) |
| 121 | return tdescTy.getLayoutAttr(); |
| 122 | |
| 123 | if (auto result = dyn_cast<OpResult>(Val: value)) { |
| 124 | Operation *defOp = result.getDefiningOp(); |
| 125 | assert(defOp && "result must have a defining op" ); |
| 126 | |
| 127 | // for LoadNdOp, the layout is stored in the tensor descriptor |
| 128 | if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp)) |
| 129 | return getLayoutAttr(loadNd.getTensorDesc()); |
| 130 | |
| 131 | std::string layoutName = getLayoutName(result); |
| 132 | if (defOp->hasAttr(name: layoutName)) |
| 133 | return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName); |
| 134 | } |
| 135 | |
| 136 | if (auto arg = dyn_cast<BlockArgument>(Val: value)) { |
| 137 | auto parentOp = arg.getOwner()->getParentOp(); |
| 138 | if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) { |
| 139 | OpOperand *tiedInit = loop.getTiedLoopInit(arg); |
| 140 | return getLayoutAttr(tiedInit->get()); |
| 141 | } |
| 142 | } |
| 143 | |
| 144 | return nullptr; |
| 145 | } |
| 146 | |
| 147 | xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) { |
| 148 | Operation *op = opr.getOwner(); |
| 149 | std::string layoutName = xegpu::getLayoutName(operand: opr); |
| 150 | if (op->hasAttr(name: layoutName)) |
| 151 | return op->getAttrOfType<xegpu::LayoutAttr>(layoutName); |
| 152 | return getLayoutAttr(opr.get()); |
| 153 | } |
| 154 | |
| 155 | template <typename T, typename> |
| 156 | void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) { |
| 157 | Operation *owner = operandOrResult.getOwner(); |
| 158 | std::string name = xegpu::getLayoutName(operandOrResult); |
| 159 | if (layout && !owner->hasAttrOfType<LayoutAttr>(name)) |
| 160 | owner->setAttr(name, layout); |
| 161 | } |
| 162 | |
| 163 | // Explicit instantiation for OpResult |
| 164 | template void |
| 165 | xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result, |
| 166 | const mlir::xegpu::LayoutAttr layout); |
| 167 | |
| 168 | // Explicit instantiation for OpOperand |
| 169 | template void |
| 170 | xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand, |
| 171 | const mlir::xegpu::LayoutAttr layout); |
| 172 | |
| 173 | void xegpu::setLayoutAttrs(Operation *op, |
| 174 | function_ref<LayoutAttr(Value)> getLayoutImpl) { |
| 175 | op->walk(callback: [&](Operation *nestOp) { |
| 176 | for (OpOperand &opr : nestOp->getOpOperands()) { |
| 177 | auto layout = getLayoutImpl(opr.get()); |
| 178 | setLayoutAttr(opr, layout); |
| 179 | } |
| 180 | for (OpResult result : nestOp->getOpResults()) { |
| 181 | auto layout = getLayoutImpl(result); |
| 182 | setLayoutAttr(result, layout); |
| 183 | } |
| 184 | }); |
| 185 | } |
| 186 | |
| 187 | SmallVector<Value> |
| 188 | xegpu::(OpBuilder &builder, Location loc, |
| 189 | Value value, ArrayRef<int64_t> shape) { |
| 190 | auto vecTy = dyn_cast<VectorType>(value.getType()); |
| 191 | if (!vecTy) |
| 192 | return {value}; |
| 193 | |
| 194 | ArrayRef<int64_t> srcShape = vecTy.getShape(); |
| 195 | if (!computeShapeRatio(shape: srcShape, subShape: shape)) |
| 196 | return {value}; |
| 197 | |
| 198 | SmallVector<Value> result; |
| 199 | for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) { |
| 200 | SmallVector<int64_t> staticStrides(offsets.size(), 1); |
| 201 | result.push_back(builder.create<vector::ExtractStridedSliceOp>( |
| 202 | loc, value, offsets, shape, staticStrides)); |
| 203 | } |
| 204 | |
| 205 | return result; |
| 206 | } |
| 207 | |
| 208 | Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc, |
| 209 | ValueRange values, |
| 210 | ArrayRef<int64_t> shape) { |
| 211 | VectorType inputTy = dyn_cast<VectorType>(values[0].getType()); |
| 212 | assert(llvm::all_of(values.getTypes(), |
| 213 | [&](Type type) { return type == inputTy; }) && |
| 214 | "values must be of the same VectorType" ); |
| 215 | |
| 216 | Type elemTy = inputTy.getElementType(); |
| 217 | ArrayRef<int64_t> tileShape = inputTy.getShape(); |
| 218 | |
| 219 | VectorType resultTy = VectorType::get(shape, elemTy); |
| 220 | auto zeroAttr = builder.getZeroAttr(elemTy); |
| 221 | Value result = builder.create<arith::ConstantOp>( |
| 222 | loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr)); |
| 223 | |
| 224 | for (auto [src, offsets] : |
| 225 | llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) { |
| 226 | SmallVector<int64_t> staticStrides(offsets.size(), 1); |
| 227 | result = builder.create<vector::InsertStridedSliceOp>( |
| 228 | loc, src, result, offsets, staticStrides); |
| 229 | } |
| 230 | return result; |
| 231 | } |
| 232 | |
| 233 | void xegpu::doSCFStructuralTypeConversionWithTensorType( |
| 234 | Operation *op, TypeConverter converter) { |
| 235 | MLIRContext *context = op->getContext(); |
| 236 | |
| 237 | auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, |
| 238 | Location loc) -> Value { |
| 239 | return builder.create<UnrealizedConversionCastOp>(loc, type, inputs) |
| 240 | .getResult(0); |
| 241 | }; |
| 242 | |
| 243 | { // convert VectorType to RankedTensorType for SCF Structural ops |
| 244 | TypeConverter converter; |
| 245 | converter.addConversion(callback: [](Type type) -> Type { return type; }); |
| 246 | converter.addConversion(callback: [](VectorType type) -> Type { |
| 247 | return RankedTensorType::get(type.getShape(), type.getElementType()); |
| 248 | }); |
| 249 | converter.addSourceMaterialization(callback&: materializeCast); |
| 250 | converter.addTargetMaterialization(callback&: materializeCast); |
| 251 | |
| 252 | mlir::ConversionTarget target(*context); |
| 253 | target.addLegalOp<UnrealizedConversionCastOp>(); |
| 254 | |
| 255 | mlir::RewritePatternSet patterns(context); |
| 256 | scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns, |
| 257 | target); |
| 258 | (void)mlir::applyPartialConversion(op, target, std::move(patterns)); |
| 259 | } |
| 260 | |
| 261 | { // propagate the layout attribute to RankedTensorType by checking |
| 262 | // BuiltInUnrealizedCastOps |
| 263 | // for VectorType to RankedTensorType cast. |
| 264 | op->walk(callback: [](UnrealizedConversionCastOp castOp) { |
| 265 | if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1) |
| 266 | return WalkResult::skip(); |
| 267 | |
| 268 | Value input = castOp.getInputs()[0]; |
| 269 | Value result = castOp.getResults()[0]; |
| 270 | auto inputTy = dyn_cast<VectorType>(input.getType()); |
| 271 | auto resultTy = dyn_cast<RankedTensorType>(result.getType()); |
| 272 | |
| 273 | // Only look at ops casting from VectorType to RankedTensorType |
| 274 | if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy)) |
| 275 | return WalkResult::skip(); |
| 276 | |
| 277 | xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input); |
| 278 | if (!layout) |
| 279 | return WalkResult::skip(); |
| 280 | |
| 281 | RankedTensorType newTy = resultTy.cloneWithEncoding(layout); |
| 282 | result.setType(newTy); |
| 283 | |
| 284 | // update the arguments if user is a LoopLike op. |
| 285 | for (OpOperand &use : result.getUses()) { |
| 286 | if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) { |
| 287 | BlockArgument arg = loop.getTiedLoopRegionIterArg(&use); |
| 288 | arg.setType(newTy); |
| 289 | } |
| 290 | // whileOp has two regions, the BlockArgument of the after region |
| 291 | // is not exposed by LoopLikeOpInterface |
| 292 | if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) { |
| 293 | unsigned idx = use.getOperandNumber(); |
| 294 | BlockArgument arg = whileOp.getAfterArguments()[idx]; |
| 295 | arg.setType(newTy); |
| 296 | } |
| 297 | } |
| 298 | return WalkResult::advance(); |
| 299 | }); |
| 300 | |
| 301 | // using yieldOp as anchor to update the result type of its ParentOp |
| 302 | op->walk(callback: [](scf::YieldOp yieldOp) { |
| 303 | Operation *parentOp = yieldOp->getParentOp(); |
| 304 | for (OpResult r : parentOp->getOpResults()) { |
| 305 | unsigned idx = r.getResultNumber(); |
| 306 | Type resultTy = r.getType(); |
| 307 | Type yieldTy = yieldOp.getResults()[idx].getType(); |
| 308 | if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy) |
| 309 | r.setType(yieldTy); |
| 310 | } |
| 311 | }); |
| 312 | } |
| 313 | |
| 314 | { // perform the conversion from RankedTensorType to VectorType based on the |
| 315 | // LayoutAttr |
| 316 | |
| 317 | // Handle the UnrealizedConversionCastOp introduced by the first step. |
| 318 | // For vector->RankedTensorType, it will simply forward the inputs. |
| 319 | // For RankedTensorType->vector, it will update the inputs with the |
| 320 | // one from the adaptor. |
| 321 | class UnrealizedConversionCastOpPattern |
| 322 | : public OpConversionPattern<mlir::UnrealizedConversionCastOp> { |
| 323 | using OpConversionPattern< |
| 324 | mlir::UnrealizedConversionCastOp>::OpConversionPattern; |
| 325 | |
| 326 | mlir::LogicalResult |
| 327 | matchAndRewrite(mlir::UnrealizedConversionCastOp op, |
| 328 | OneToNOpAdaptor adaptor, |
| 329 | ConversionPatternRewriter &rewriter) const override { |
| 330 | auto inputs = op.getOperands(); |
| 331 | auto outputs = op.getOutputs(); |
| 332 | |
| 333 | if (inputs.size() != 1 || outputs.size() != 1) |
| 334 | return failure(); |
| 335 | |
| 336 | auto inputTy = inputs[0].getType(); |
| 337 | auto outputTy = outputs[0].getType(); |
| 338 | |
| 339 | if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) { |
| 340 | rewriter.replaceOpWithMultiple(op, adaptor.getInputs()); |
| 341 | return success(); |
| 342 | } |
| 343 | |
| 344 | if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) { |
| 345 | SmallVector<Value> values = flattenValues(adaptor.getInputs()); |
| 346 | auto newOp = rewriter.create<UnrealizedConversionCastOp>( |
| 347 | op.getLoc(), outputTy, values); |
| 348 | rewriter.replaceOp(op, newOp); |
| 349 | return success(); |
| 350 | } |
| 351 | return failure(); |
| 352 | } |
| 353 | }; |
| 354 | |
| 355 | converter.addSourceMaterialization(callback&: materializeCast); |
| 356 | converter.addTargetMaterialization(callback: [&](OpBuilder &builder, TypeRange type, |
| 357 | ValueRange inputs, Location loc) { |
| 358 | return builder.create<UnrealizedConversionCastOp>(loc, type, inputs) |
| 359 | .getResults(); |
| 360 | }); |
| 361 | |
| 362 | mlir::ConversionTarget target(*context); |
| 363 | target.addDynamicallyLegalOp<UnrealizedConversionCastOp>( |
| 364 | [](UnrealizedConversionCastOp op) { |
| 365 | auto isTensorTy = [](Type type) { |
| 366 | return isa<RankedTensorType>(type); |
| 367 | }; |
| 368 | return llvm::none_of(op->getOperandTypes(), isTensorTy) && |
| 369 | llvm::none_of(op->getResultTypes(), isTensorTy); |
| 370 | }); |
| 371 | mlir::RewritePatternSet patterns(context); |
| 372 | patterns.insert<UnrealizedConversionCastOpPattern>(arg&: context); |
| 373 | scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns, |
| 374 | target); |
| 375 | (void)mlir::applyPartialConversion(op, target, std::move(patterns)); |
| 376 | } |
| 377 | } |
| 378 | |