| 1 | //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" |
| 10 | |
| 11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
| 12 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| 13 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 14 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 18 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| 19 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| 20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 21 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
| 22 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 23 | #include "mlir/IR/BuiltinTypes.h" |
| 24 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
| 25 | #include "mlir/IR/PatternMatch.h" |
| 26 | #include "mlir/IR/TypeUtilities.h" |
| 27 | #include "mlir/IR/Value.h" |
| 28 | #include "mlir/Pass/Pass.h" |
| 29 | #include "llvm/Support/Debug.h" |
| 30 | #include "llvm/Support/ErrorHandling.h" |
| 31 | #include "llvm/Support/raw_ostream.h" |
| 32 | #include <optional> |
| 33 | |
| 34 | #define DEBUG_TYPE "nvgpu-to-nvvm" |
| 35 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
| 36 | #define DBGSE() (llvm::dbgs()) |
| 37 | |
| 38 | namespace mlir { |
| 39 | #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS |
| 40 | #include "mlir/Conversion/Passes.h.inc" |
| 41 | } // namespace mlir |
| 42 | |
| 43 | using namespace mlir; |
| 44 | |
| 45 | /// Number of bits that need to be excluded when building the matrix |
| 46 | /// descriptor for wgmma operations. |
| 47 | constexpr int exclude4LSB = 4; |
| 48 | |
| 49 | /// GPUs have 32-bit registers; this function truncates values when a larger |
| 50 | /// width is not needed. |
| 51 | static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { |
| 52 | Type type = value.getType(); |
| 53 | assert(llvm::isa<IntegerType>(type) && "expected an integer Value"); |
| 54 | if (type.getIntOrFloatBitWidth() <= 32) |
| 55 | return value; |
| 56 | return b.create<LLVM::TruncOp>(b.getI32Type(), value); |
| 57 | } |
| 58 | |
| 59 | /// Returns the type for the intrinsic given the vectorResultType of the |
| 60 | /// `gpu.mma.sync` operation. |
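|  | /// |
|  | /// For example (illustrative): a `vectorResultType` of |
|  | /// `!llvm.array<2 x vector<2xf16>>` maps to the intrinsic result type |
|  | /// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while |
|  | /// `!llvm.array<1 x vector<2xf32>>` maps to `!llvm.struct<(f32, f32)>`. |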
| 61 | static Type inferIntrinsicResultType(Type vectorResultType) { |
| 62 | MLIRContext *ctx = vectorResultType.getContext(); |
| 63 | auto a = cast<LLVM::LLVMArrayType>(vectorResultType); |
| 64 | auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx)); |
| 65 | auto i32Ty = IntegerType::get(ctx, 32); |
| 66 | auto i32x2Ty = VectorType::get(2, i32Ty); |
| 67 | Type f64Ty = Float64Type::get(ctx); |
| 68 | Type f64x2Ty = VectorType::get(2, f64Ty); |
| 69 | Type f32Ty = Float32Type::get(ctx); |
| 70 | Type f32x2Ty = VectorType::get(2, f32Ty); |
| 71 | if (a.getElementType() == f16x2Ty) { |
| 72 | return LLVM::LLVMStructType::getLiteral( |
| 73 | ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); |
| 74 | } |
| 75 | if (a.getElementType() == i32x2Ty) { |
| 76 | return LLVM::LLVMStructType::getLiteral( |
| 77 | ctx, |
| 78 | SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty)); |
| 79 | } |
| 80 | if (a.getElementType() == f64x2Ty) { |
| 81 | return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); |
| 82 | } |
| 83 | if (a.getElementType() == f32x2Ty) { |
| 84 | return LLVM::LLVMStructType::getLiteral( |
| 85 | ctx, |
| 86 | SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); |
| 87 | } |
| 88 | if (a.getElementType() == VectorType::get(1, f32Ty)) { |
| 89 | return LLVM::LLVMStructType::getLiteral( |
| 90 | ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty)); |
| 91 | } |
| 92 | return vectorResultType; |
| 93 | } |
| 94 | |
| 95 | /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is |
| 96 | /// always an LLVM struct) into a fragment that is compatible with the vector |
| 97 | /// type of this operation. This involves extracting elements from the struct |
| 98 | /// and inserting them into an LLVM array. These extra data-movement |
| 99 | /// operations should be canonicalized away by the LLVM backend. |
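|  | /// |
|  | /// For example (illustrative): an intrinsic result of type |
|  | /// `!llvm.struct<(f32, f32)>` with a desired result type of |
|  | /// `!llvm.array<1 x vector<2xf32>>` is rebuilt by extracting both scalars |
|  | /// and inserting them into a single `vector<2xf32>` row. |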
| 100 | static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, |
| 101 | Type resultType, Value intrinsicResult, |
| 102 | RewriterBase &rewriter) { |
| 103 | MLIRContext *ctx = rewriter.getContext(); |
| 104 | auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType); |
| 105 | auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType); |
| 106 | Type i32Ty = rewriter.getI32Type(); |
| 107 | Type f32Ty = rewriter.getF32Type(); |
| 108 | Type f64Ty = rewriter.getF64Type(); |
| 109 | Type f16x2Ty = VectorType::get(2, rewriter.getF16Type()); |
| 110 | Type i32x2Ty = VectorType::get(2, i32Ty); |
| 111 | Type f64x2Ty = VectorType::get(2, f64Ty); |
| 112 | Type f32x2Ty = VectorType::get(2, f32Ty); |
| 113 | Type f32x1Ty = VectorType::get(1, f32Ty); |
| 114 | |
| 115 | auto makeConst = [&](int32_t index) -> Value { |
| 116 | return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), |
| 117 | rewriter.getI32IntegerAttr(index)); |
| 118 | }; |
| 119 | |
| 120 | if (arrayType) { |
| 121 | SmallVector<Value, 4> elements; |
| 122 | |
| 123 | // The intrinsic returns 32-bit wide elements in a form which can be |
| 124 | // directly bitcasted and inserted into the result vector. |
| 125 | if (arrayType.getElementType() == f16x2Ty || |
| 126 | arrayType.getElementType() == f32x1Ty) { |
| 127 | for (unsigned i = 0; i < structType.getBody().size(); i++) { |
| 128 | Value el = |
| 129 | rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i); |
| 130 | el = rewriter.createOrFold<LLVM::BitcastOp>( |
| 131 | loc, arrayType.getElementType(), el); |
| 132 | elements.push_back(el); |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | // The intrinsic returns i32, f64, and f32 values as individual scalars, |
| 137 | // even when the result is notionally a 64-bit wide element (e.g. f32x2). We |
| 138 | // need to extract them from the struct and pack them into the 64-bit wide |
| 139 | // rows of the vector result. |
| 140 | if (arrayType.getElementType() == i32x2Ty || |
| 141 | arrayType.getElementType() == f64x2Ty || |
| 142 | arrayType.getElementType() == f32x2Ty) { |
| 143 | |
| 144 | for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { |
| 145 | Value vec = |
| 146 | rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType()); |
| 147 | Value x1 = |
| 148 | rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2); |
| 149 | Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, |
| 150 | i * 2 + 1); |
| 151 | vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, |
| 152 | x1, makeConst(0)); |
| 153 | vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, |
| 154 | x2, makeConst(1)); |
| 155 | elements.push_back(vec); |
| 156 | } |
| 157 | } |
| 158 | |
| 159 | // Create the final vectorized result. |
| 160 | Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType); |
| 161 | for (const auto &el : llvm::enumerate(elements)) { |
| 162 | result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(), |
| 163 | el.index()); |
| 164 | } |
| 165 | return result; |
| 166 | } |
| 167 | |
| 168 | return intrinsicResult; |
| 169 | } |
| 170 | |
| 171 | /// The `gpu.mma.sync` converter below expects matrix fragment operands to be |
| 172 | /// given as 2D `vectors` where the rows are 32b or 64b wide. The |
| 173 | /// `nvvm.mma.sync` op expects these arguments to be given as a long list of |
| 174 | /// scalars of certain types. This function helps unpack the `vector` arguments |
| 175 | /// and cast them to the types expected by `nvvm.mma.sync`. |
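|  | /// |
|  | /// For example (illustrative): an operand of type |
|  | /// `!llvm.array<2 x vector<2xf16>>` is unpacked into two `vector<2xf16>` |
|  | /// values, whereas `!llvm.array<4 x vector<1xf32>>` with a tf32 PTX type is |
|  | /// unpacked into four `i32` values via bitcasts. |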
| 176 | static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, |
| 177 | Value operand, |
| 178 | NVVM::MMATypes operandPtxType) { |
| 179 | SmallVector<Value> result; |
| 180 | Type i32Ty = b.getI32Type(); |
| 181 | Type f64Ty = b.getF64Type(); |
| 182 | Type f32Ty = b.getF32Type(); |
| 183 | Type i64Ty = b.getI64Type(); |
| 184 | Type i8x4Ty = VectorType::get(4, b.getI8Type()); |
| 185 | Type i4x8Ty = VectorType::get(8, b.getIntegerType(4)); |
| 186 | Type f32x1Ty = VectorType::get(1, f32Ty); |
| 187 | auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType()); |
| 188 | |
| 189 | for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { |
| 190 | Value toUse = b.create<LLVM::ExtractValueOp>(operand, i); |
| 191 | |
| 192 | // The intrinsic expects 4xi8 and 8xi4 vectors (and f32 operands used as |
| 193 | // tf32) to be provided as i32 scalars, so bitcast them here. |
| 194 | if (arrayTy.getElementType() == i8x4Ty || |
| 195 | arrayTy.getElementType() == i4x8Ty || |
| 196 | (arrayTy.getElementType() == f32x1Ty && |
| 197 | operandPtxType == NVVM::MMATypes::tf32)) { |
| 198 | result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse)); |
| 199 | continue; |
| 200 | } |
| 201 | |
| 202 | // For some element types (i32, f32, f64), we need to unpack the inner |
| 203 | // vector/array type as well because the intrinsic expects individual |
| 204 | // scalars to be provided. |
| 205 | VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType()); |
| 206 | if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || |
| 207 | innerArrayTy.getElementType() == f64Ty || |
| 208 | innerArrayTy.getElementType() == f32Ty)) { |
| 209 | for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); |
| 210 | idx < innerSize; idx++) { |
| 211 | result.push_back(b.create<LLVM::ExtractElementOp>( |
| 212 | toUse, |
| 213 | b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx)))); |
| 214 | } |
| 215 | continue; |
| 216 | } |
| 217 | result.push_back(toUse); |
| 218 | } |
| 219 | return result; |
| 220 | } |
| 221 | |
| 222 | /// Returns whether the mbarrier object has a shared memory address space. |
| 223 | static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) { |
| 224 | return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace( |
| 225 | barrierType.getMemorySpace())); |
| 226 | } |
| 227 | |
| 228 | /// Returns the memory space attribute of the mbarrier object. |
| 229 | Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context, |
| 230 | nvgpu::MBarrierGroupType barrierType) { |
| 231 | Attribute memorySpace = {}; |
| 232 | if (isMbarrierShared(barrierType)) { |
| 233 | memorySpace = |
| 234 | IntegerAttr::get(IntegerType::get(context, 64), |
| 235 | nvgpu::NVGPUDialect::kSharedMemoryAddressSpace); |
| 236 | } |
| 237 | return memorySpace; |
| 238 | } |
| 239 | |
| 240 | /// Returns the memref type of the mbarrier object. The type is defined in the |
| 241 | /// MBarrierGroupType. |
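|  | /// |
|  | /// For example (illustrative): a barrier group with 4 barriers in shared |
|  | /// memory maps to `memref<4xi64, 3>`, i.e. one 64-bit mbarrier word per |
|  | /// barrier in the shared memory address space. |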
| 242 | MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context, |
| 243 | nvgpu::MBarrierGroupType barrierType) { |
| 244 | Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType); |
| 245 | MemRefLayoutAttrInterface layout; |
| 246 | return MemRefType::get({barrierType.getNumBarriers()}, |
| 247 | IntegerType::get(context, 64), layout, memorySpace); |
| 248 | } |
| 249 | |
| 250 | namespace { |
| 251 | |
| 252 | struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { |
| 253 | using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern; |
| 254 | |
| 255 | LogicalResult |
| 256 | matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor, |
| 257 | ConversionPatternRewriter &rewriter) const override { |
| 258 | MLIRContext *ctx = getContext(); |
| 259 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| 260 | |
| 261 | // The result type of ldmatrix is a struct of 32-bit integer registers when |
| 262 | // more than one 32-bit value is returned; otherwise, the result is a single |
| 263 | // i32. The result type of the GPU operation is always a vector of shape |
| 264 | // (NumRegisters, VectorRegister) where VectorRegister is the vector type of |
| 265 | // the result and is always 32 bits wide. We bitcast the result of the |
| 266 | // NVVM::LdMatrixOp to this vector type. |
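|  | // |
|  | // For example (illustrative): an `nvgpu.ldmatrix` result of type |
|  | // `vector<4x2xf16>` yields num32BitRegs = 4, so the NVVM op returns |
|  | // `!llvm.struct<(i32, i32, i32, i32)>` and each i32 is bitcast back to a |
|  | // `vector<2xf16>`. |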
| 267 | auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]); |
| 268 | if (!vectorResultType) { |
| 269 | return failure(); |
| 270 | } |
| 271 | Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1), |
| 272 | vectorResultType.getElementType()); |
| 273 | |
| 274 | int64_t num32BitRegs = vectorResultType.getDimSize(0); |
| 275 | |
| 276 | Type ldMatrixResultType; |
| 277 | if (num32BitRegs > 1) { |
| 278 | ldMatrixResultType = LLVM::LLVMStructType::getLiteral( |
| 279 | ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type())); |
| 280 | } else { |
| 281 | ldMatrixResultType = rewriter.getI32Type(); |
| 282 | } |
| 283 | |
| 284 | auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType()); |
| 285 | Value srcPtr = |
| 286 | getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, |
| 287 | adaptor.getSrcMemref(), adaptor.getIndices()); |
| 288 | Value ldMatrixResult = b.create<NVVM::LdMatrixOp>( |
| 289 | ldMatrixResultType, srcPtr, |
| 290 | /*num=*/op.getNumTiles(), |
| 291 | /*layout=*/op.getTranspose() ? NVVM::MMALayout::col |
| 292 | : NVVM::MMALayout::row); |
| 293 | |
| 294 | // The ldmatrix operation returns either a single i32 value or a struct of |
| 295 | // i32 values. Here we unpack those values and cast them back to their |
| 296 | // actual vector type (still of width 32b) and repack them into a result |
| 297 | // struct. |
| 298 | Type finalResultType = typeConverter->convertType(vectorResultType); |
| 299 | Value result = b.create<LLVM::PoisonOp>(finalResultType); |
| 300 | for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { |
| 301 | Value i32Register = |
| 302 | num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i) |
| 303 | : ldMatrixResult; |
| 304 | Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register); |
| 305 | result = b.create<LLVM::InsertValueOp>(result, casted, i); |
| 306 | } |
| 307 | |
| 308 | rewriter.replaceOp(op, result); |
| 309 | return success(); |
| 310 | } |
| 311 | }; |
| 312 | |
| 313 | /// Convert the given type into the corresponding PTX type (NVVM::MMATypes |
| 314 | /// enum). |
| 315 | static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) { |
| 316 | Type elType = getElementTypeOrSelf(t); |
| 317 | if (elType.isInteger(8)) |
| 318 | return NVVM::MMATypes::s8; |
| 319 | if (elType.isInteger(4)) |
| 320 | return NVVM::MMATypes::s4; |
| 321 | if (elType.isF16()) |
| 322 | return NVVM::MMATypes::f16; |
| 323 | if (elType.isF64()) |
| 324 | return NVVM::MMATypes::f64; |
| 325 | if (elType.isF32()) |
| 326 | return NVVM::MMATypes::tf32; |
| 327 | return failure(); |
| 328 | } |
| 329 | |
| 330 | struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> { |
| 331 | using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern; |
| 332 | |
| 333 | LogicalResult |
| 334 | matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor, |
| 335 | ConversionPatternRewriter &rewriter) const override { |
| 336 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| 337 | // Get the shapes of the MMAMatrix type being used. The shapes determine |
| 338 | // which intrinsic this op will be lowered to. |
| 339 | VectorType aType = op.getMatrixA().getType(); |
| 340 | VectorType bType = op.getMatrixB().getType(); |
| 341 | VectorType cType = op.getMatrixC().getType(); |
| 342 | |
| 343 | std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray(); |
| 344 | |
| 345 | // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32). |
| 346 | bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); |
| 347 | if (aType.getElementType().isF32() && !tf32Enabled) |
| 348 | return failure(); |
| 349 | |
| 350 | FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType); |
| 351 | if (failed(ptxTypeA)) |
| 352 | return op->emitOpError("failed to deduce operand PTX types" ); |
| 353 | FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType); |
| 354 | if (failed(ptxTypeB)) |
| 355 | return op->emitOpError("failed to deduce operand PTX types" ); |
| 356 | std::optional<NVVM::MMATypes> ptxTypeC = |
| 357 | NVVM::MmaOp::inferOperandMMAType(cType.getElementType(), |
| 358 | /*isAccumulator=*/true); |
| 359 | if (!ptxTypeC) |
| 360 | return op->emitError( |
| 361 | "could not infer the PTX type for the accumulator/result" ); |
| 362 | |
| 363 | // TODO: add an attribute to the op to customize this behavior. |
| 364 | std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); |
| 365 | if (isa<IntegerType>(aType.getElementType())) |
| 366 | overflow = NVVM::MMAIntOverflow::satfinite; |
| 367 | |
| 368 | SmallVector<Value> matA = |
| 369 | unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); |
| 370 | SmallVector<Value> matB = |
| 371 | unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); |
| 372 | SmallVector<Value> matC = |
| 373 | unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); |
| 374 | |
| 375 | Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); |
| 376 | Type intrinsicResTy = inferIntrinsicResultType( |
| 377 | typeConverter->convertType(op->getResultTypes()[0])); |
| 378 | Value intrinsicResult = b.create<NVVM::MmaOp>( |
| 379 | intrinsicResTy, matA, matB, matC, |
| 380 | /*shape=*/gemmShape, |
| 381 | /*b1Op=*/std::nullopt, |
| 382 | /*intOverflow=*/overflow, |
| 383 | /*multiplicandPtxTypes=*/ |
| 384 | std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, |
| 385 | /*multiplicandLayouts=*/ |
| 386 | std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, |
| 387 | NVVM::MMALayout::col}); |
| 388 | rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, |
| 389 | desiredRetTy, intrinsicResult, |
| 390 | rewriter)); |
| 391 | return success(); |
| 392 | } |
| 393 | }; |
| 394 | |
| 395 | struct ConvertNVGPUToNVVMPass |
| 396 | : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> { |
| 397 | using Base::Base; |
| 398 | |
| 399 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 400 | registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect, |
| 401 | arith::ArithDialect>(); |
| 402 | } |
| 403 | |
| 404 | void runOnOperation() override { |
| 405 | LowerToLLVMOptions options(&getContext()); |
| 406 | RewritePatternSet patterns(&getContext()); |
| 407 | LLVMTypeConverter converter(&getContext(), options); |
| 408 | IRRewriter rewriter(&getContext()); |
| 409 | populateGpuMemorySpaceAttributeConversions( |
| 410 | converter, [](gpu::AddressSpace space) -> unsigned { |
| 411 | switch (space) { |
| 412 | case gpu::AddressSpace::Global: |
| 413 | return static_cast<unsigned>( |
| 414 | NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
| 415 | case gpu::AddressSpace::Workgroup: |
| 416 | return static_cast<unsigned>( |
| 417 | NVVM::NVVMMemorySpace::kSharedMemorySpace); |
| 418 | case gpu::AddressSpace::Private: |
| 419 | return 0; |
| 420 | } |
| 421 | llvm_unreachable("unknown address space enum value" ); |
| 422 | return 0; |
| 423 | }); |
| 424 | /// device-side async tokens cannot be materialized in nvvm. We just |
| 425 | /// convert them to a dummy i32 type in order to easily drop them during |
| 426 | /// conversion. |
| 427 | converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { |
| 428 | return converter.convertType(IntegerType::get(type.getContext(), 32)); |
| 429 | }); |
| 430 | converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type { |
| 431 | Type elemType = type.getFragmented().getElementType(); |
| 432 | int64_t sizeM = type.getFragmented().getDimSize(0); |
| 433 | int64_t sizeN = type.getFragmented().getDimSize(1); |
| 434 | |
| 435 | unsigned numMembers; |
| 436 | if (elemType.isF32() || elemType.isInteger(32)) |
| 437 | numMembers = sizeN / 2; |
| 438 | else if (elemType.isF16()) |
| 439 | numMembers = sizeN / 4; |
| 440 | else |
| 441 | llvm_unreachable("unsupported type for warpgroup accumulator" ); |
| 442 | |
| 443 | SmallVector<Type> innerStructBody; |
| 444 | for (unsigned i = 0; i < numMembers; i++) |
| 445 | innerStructBody.push_back(elemType); |
| 446 | auto innerStructType = |
| 447 | LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody); |
| 448 | |
| 449 | SmallVector<Type> structBody; |
| 450 | for (int i = 0; i < sizeM; i += kWgmmaSizeM) |
| 451 | structBody.push_back(innerStructType); |
| 452 | |
| 453 | auto convertedType = |
| 454 | LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); |
| 455 | return converter.convertType(convertedType); |
| 456 | }); |
| 457 | converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { |
| 458 | return converter.convertType(IntegerType::get(type.getContext(), 64)); |
| 459 | }); |
| 460 | converter.addConversion( |
| 461 | [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { |
| 462 | return converter.convertType(IntegerType::get(type.getContext(), 64)); |
| 463 | }); |
| 464 | converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { |
| 465 | return converter.convertType( |
| 466 | nvgpu::getMBarrierMemrefType(rewriter.getContext(), type)); |
| 467 | }); |
| 468 | converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type { |
| 469 | return LLVM::LLVMPointerType::get(type.getContext()); |
| 470 | }); |
| 471 | populateNVGPUToNVVMConversionPatterns(converter, patterns); |
| 472 | LLVMConversionTarget target(getContext()); |
| 473 | target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); |
| 474 | target.addLegalDialect<::mlir::arith::ArithDialect>(); |
| 475 | target.addLegalDialect<::mlir::memref::MemRefDialect>(); |
| 476 | target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); |
| 477 | mlir::scf::populateSCFStructuralTypeConversionsAndLegality( |
| 478 | converter, patterns, target); |
| 479 | if (failed(applyPartialConversion(getOperation(), target, |
| 480 | std::move(patterns)))) |
| 481 | signalPassFailure(); |
| 482 | } |
| 483 | }; |
| 484 | |
| 485 | /// Returns the constraints for the sparse MMA inline assembly instruction. |
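|  | /// |
|  | /// For example (illustrative): matASize = 4, matBSize = 2, matCSize = 2 |
|  | /// yields "=r,=r,r,r,r,r,r,r,r,r,r" (two outputs, eight register inputs, |
|  | /// plus the trailing sparsity-metadata operand). |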
| 486 | static std::string buildMmaSparseAsmConstraintString(unsigned matASize, |
| 487 | unsigned matBSize, |
| 488 | unsigned matCSize) { |
| 489 | std::string str; |
| 490 | llvm::raw_string_ostream ss(str); |
| 491 | for (unsigned i = 0; i < matCSize; i++) |
| 492 | ss << "=r," ; |
| 493 | for (unsigned i = 0; i < matASize + matBSize + matCSize; i++) |
| 494 | ss << "r," ; |
| 495 | // The final operand is for the sparsity metadata. |
| 496 | // The sparsity selector appears as direct literal. |
| 497 | ss << "r" ; |
| 498 | return str; |
| 499 | } |
| 500 | |
| 501 | /// Returns the string for the `mma.sp.sync` instruction that corresponds to |
| 502 | /// the given parameters. Note that this function doesn't do any validation, |
| 503 | /// it's expected that the provided parameters correspond to a valid |
| 504 | /// instruction. |
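|  | /// |
|  | /// For example (illustrative, assuming an m16n8k32 s8 GEMM with an s32 |
|  | /// accumulator and satfinite), the generated string has the schematic form: |
|  | ///   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32 |
|  | ///     {<C regs>},{<A regs>},{<B regs>},{<C regs>},$<metadata>,0x0; |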
| 505 | static std::string buildMmaSparseAsmString( |
| 506 | const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize, |
| 507 | unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, |
| 508 | NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, |
| 509 | std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) { |
| 510 | auto ptxTypeStr = [](NVVM::MMATypes ptxType) { |
| 511 | return NVVM::stringifyMMATypes(ptxType); |
| 512 | }; |
| 513 | |
| 514 | std::string asmStr; |
| 515 | llvm::raw_string_ostream ss(asmStr); |
| 516 | ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k" |
| 517 | << shape[2] << ".row.col."; |
| 518 | |
| 519 | if (overflow) |
| 520 | ss << NVVM::stringifyMMAIntOverflow(*overflow) << "."; |
| 521 | |
| 522 | ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "." |
| 523 | << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " "; |
| 524 | unsigned asmArgIdx = 0; |
| 525 | |
| 526 | // The operand string is structured into sections `{matC elements...}, |
| 527 | // {matA elements...}, {matB elements...}, {matC elements}`. |
| 528 | for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) { |
| 529 | ss << "{" ; |
| 530 | for (unsigned i = 0; i < arrSize; i++) |
| 531 | ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "" ); |
| 532 | ss << "}," ; |
| 533 | } |
| 534 | ss << "$" << asmArgIdx++ << "," ; |
| 535 | assert(metaDataSelector <= 1); |
| 536 | ss << "0x" << metaDataSelector << ";" ; |
| 537 | return asmStr; |
| 538 | } |
| 539 | |
| 540 | /// Builds an inline assembly operation corresponding to the specified MMA |
| 541 | /// sparse sync operation. |
| 542 | static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm( |
| 543 | ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, |
| 544 | NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, |
| 545 | std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData, |
| 546 | ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData, |
| 547 | int64_t metadataSelector, const std::array<int64_t, 3> &shape, |
| 548 | Type intrinsicResultType) { |
| 549 | auto asmDialectAttr = |
| 550 | LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT); |
| 551 | |
| 552 | const unsigned matASize = unpackedAData.size(); |
| 553 | const unsigned matBSize = unpackedB.size(); |
| 554 | const unsigned matCSize = unpackedC.size(); |
| 555 | |
| 556 | std::string asmStr = buildMmaSparseAsmString( |
| 557 | shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC, |
| 558 | ptxTypeD, overflow, metadataSelector); |
| 559 | std::string constraintStr = |
| 560 | buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize); |
| 561 | |
| 562 | SmallVector<Value> asmVals; |
| 563 | asmVals.reserve(matASize + matBSize + matCSize + 1); |
| 564 | for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC}) |
| 565 | llvm::append_range(asmVals, args); |
| 566 | asmVals.push_back(indexData); |
| 567 | |
| 568 | return b.create<LLVM::InlineAsmOp>( |
| 569 | /*resultTypes=*/intrinsicResultType, |
| 570 | /*operands=*/asmVals, |
| 571 | /*asm_string=*/asmStr, |
| 572 | /*constraints=*/constraintStr, |
| 573 | /*has_side_effects=*/true, |
| 574 | /*is_align_stack=*/false, LLVM::TailCallKind::None, |
| 575 | /*asm_dialect=*/asmDialectAttr, |
| 576 | /*operand_attrs=*/ArrayAttr()); |
| 577 | } |
| 578 | |
| 579 | /// Lowers `nvgpu.mma.sp.sync` to inline assembly. |
| 580 | struct NVGPUMmaSparseSyncLowering |
| 581 | : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> { |
| 582 | using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern; |
| 583 | |
| 584 | LogicalResult |
| 585 | matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor, |
| 586 | ConversionPatternRewriter &rewriter) const override { |
| 587 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| 588 | // Get the shapes of the MMAMatrix type being used. The shapes determine |
| 589 | // which intrinsic this op will be lowered to. |
| 590 | VectorType aType = op.getMatrixA().getType(); |
| 591 | VectorType bType = op.getMatrixB().getType(); |
| 592 | VectorType cType = op.getMatrixC().getType(); |
| 593 | |
| 594 | FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType); |
| 595 | if (failed(ptxTypeA)) |
| 596 | return op->emitOpError("failed to deduce operand PTX types" ); |
| 597 | FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType); |
| 598 | if (failed(ptxTypeB)) |
| 599 | return op->emitOpError("failed to deduce operand PTX types" ); |
| 600 | std::optional<NVVM::MMATypes> ptxTypeC = |
| 601 | NVVM::MmaOp::inferOperandMMAType(cType.getElementType(), |
| 602 | /*isAccumulator=*/true); |
| 603 | if (!ptxTypeC) |
| 604 | return op->emitError( |
| 605 | "could not infer the PTX type for the accumulator/result" ); |
| 606 | |
| 607 | // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32). |
| 608 | bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); |
| 609 | if (aType.getElementType().isF32() && !tf32Enabled) |
| 610 | return failure(); |
| 611 | |
| 612 | // TODO: add an attribute to the op to customize this behavior. |
| 613 | std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); |
| 614 | if (isa<IntegerType>(aType.getElementType())) |
| 615 | overflow = NVVM::MMAIntOverflow::satfinite; |
| 616 | |
| 617 | SmallVector<Value> matA = |
| 618 | unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); |
| 619 | SmallVector<Value> matB = |
| 620 | unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); |
| 621 | SmallVector<Value> matC = |
| 622 | unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); |
| 623 | |
| 624 | Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); |
| 625 | Type intrinsicResTy = inferIntrinsicResultType( |
| 626 | typeConverter->convertType(op->getResultTypes()[0])); |
| 627 | |
| 628 | // Bitcast the sparse metadata from vector<2xi16> to an i32. |
| 629 | Value sparseMetadata = adaptor.getSparseMetadata(); |
| 630 | if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type())) |
| 631 | return op->emitOpError() << "Expected metadata type to be LLVM " |
| 632 | "VectorType of 2 i16 elements" ; |
| 633 | sparseMetadata = |
| 634 | b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata); |
| 635 | |
| 636 | FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm( |
| 637 | b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, |
| 638 | matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(), |
| 639 | intrinsicResTy); |
| 640 | if (failed(intrinsicResult)) |
| 641 | return failure(); |
| 642 | |
| 643 | assert((*intrinsicResult).getNumResults() == 1 && |
| 644 | "expected inline asm op returns a single LLVM struct type" ); |
| 645 | rewriter.replaceOp( |
| 646 | op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, |
| 647 | (*intrinsicResult)->getResult(0), rewriter)); |
| 648 | return success(); |
| 649 | } |
| 650 | }; |
| 651 | |
| 652 | struct NVGPUAsyncCopyLowering |
| 653 | : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> { |
| 654 | using ConvertOpToLLVMPattern< |
| 655 | nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern; |
| 656 | |
| 657 | LogicalResult |
| 658 | matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, |
| 659 | ConversionPatternRewriter &rewriter) const override { |
| 660 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| 661 | Location loc = op.getLoc(); |
| 662 | auto dstMemrefType = cast<MemRefType>(op.getDst().getType()); |
| 663 | Value dstPtr = |
| 664 | getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType, |
| 665 | adaptor.getDst(), adaptor.getDstIndices()); |
| 666 | FailureOr<unsigned> dstAddressSpace = |
| 667 | getTypeConverter()->getMemRefAddressSpace(dstMemrefType); |
| 668 | if (failed(dstAddressSpace)) |
| 669 | return rewriter.notifyMatchFailure( |
| 670 | loc, "destination memref address space not convertible to integer"); |
| 671 | |
| 672 | auto srcMemrefType = cast<MemRefType>(op.getSrc().getType()); |
| 673 | FailureOr<unsigned> srcAddressSpace = |
| 674 | getTypeConverter()->getMemRefAddressSpace(srcMemrefType); |
| 675 | if (failed(srcAddressSpace)) |
| 676 | return rewriter.notifyMatchFailure( |
| 677 | loc, "source memref address space not convertible to integer"); |
| 678 | |
| 679 | Value srcPtr = |
| 680 | getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(), |
| 681 | adaptor.getSrcIndices()); |
| 682 | // The intrinsic takes a global pointer, so we need an address space cast. |
| 683 | auto srcPointerGlobalType = LLVM::LLVMPointerType::get( |
| 684 | op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
| 685 | srcPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, srcPtr); |
| 686 | int64_t dstElements = adaptor.getDstElements().getZExtValue(); |
| 687 | int64_t sizeInBytes = |
| 688 | (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; |
| 689 | // When the optional SrcElements argument is *not* present, the regular |
| 690 | // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global |
| 691 | // memory) to fill DstElements number of elements in the destination |
| 692 | // (shared memory). |
| 693 | Value srcBytes = adaptor.getSrcElements(); |
| 694 | if (srcBytes) { |
| 695 | // When the optional SrcElements argument is present, the source (global |
| 696 | // memory) of CpAsyncOp is read only for SrcElements number of elements. |
| 697 | // The rest of the DstElements in the destination (shared memory) are |
| 698 | // filled with zeros. |
| 699 | Value c3I32 = |
| 700 | b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3)); |
| 701 | Value bitwidth = b.create<LLVM::ConstantOp>( |
| 702 | b.getI32Type(), |
| 703 | b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth())); |
| 704 | Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes); |
| 705 | srcBytes = b.create<LLVM::LShrOp>( |
| 706 | b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32); |
| 707 | } |
| 708 | // Use cache-global (.cg) when bypassing L1 with a 16-byte copy; otherwise |
| 709 | // use cache-all (.ca). |
| 710 | NVVM::LoadCacheModifierKind cacheModifier = |
| 711 | (op.getBypassL1().value_or(false) && sizeInBytes == 16) |
| 712 | ? NVVM::LoadCacheModifierKind::CG |
| 713 | : NVVM::LoadCacheModifierKind::CA; |
| 714 | |
| 715 | b.create<NVVM::CpAsyncOp>( |
| 716 | dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes), |
| 717 | NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier), |
| 718 | srcBytes); |
| 719 | |
| 720 | // Drop the result token. |
| 721 | Value zero = b.create<LLVM::ConstantOp>( |
| 722 | IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0)); |
| 723 | rewriter.replaceOp(op, zero); |
| 724 | return success(); |
| 725 | } |
| 726 | }; |
| 727 | |
| 728 | struct NVGPUAsyncCreateGroupLowering |
| 729 | : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> { |
| 730 | using ConvertOpToLLVMPattern< |
| 731 | nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern; |
| 732 | |
| 733 | LogicalResult |
| 734 | matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, |
| 735 | ConversionPatternRewriter &rewriter) const override { |
| 736 | rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc()); |
| 737 | // Drop the result token. |
| 738 | Value zero = rewriter.create<LLVM::ConstantOp>( |
| 739 | op->getLoc(), IntegerType::get(op.getContext(), 32), |
| 740 | rewriter.getI32IntegerAttr(0)); |
| 741 | rewriter.replaceOp(op, zero); |
| 742 | return success(); |
| 743 | } |
| 744 | }; |
| 745 | |
| 746 | struct NVGPUAsyncWaitLowering |
| 747 | : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> { |
| 748 | using ConvertOpToLLVMPattern< |
| 749 | nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern; |
| 750 | |
| 751 | LogicalResult |
| 752 | matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor, |
| 753 | ConversionPatternRewriter &rewriter) const override { |
| 754 | // If numGroups is not present, pick 0 as a conservative, correct value. |
| 755 | int32_t numGroups = adaptor.getNumGroups().value_or(0); |
| 756 | rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups); |
| 757 | rewriter.eraseOp(op); |
| 758 | return success(); |
| 759 | } |
| 760 | }; |
| 761 | |
| 762 | /// Creates an mbarrier object in shared memory. |
| 763 | struct NVGPUMBarrierCreateLowering |
| 764 | : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> { |
| 765 | using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern; |
| 766 | |
| 767 | template <typename moduleT> |
| 768 | memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter, |
| 769 | Operation *funcOp, moduleT moduleOp, |
| 770 | MemRefType barrierType) const { |
| 771 | SymbolTable symbolTable(moduleOp); |
| 772 | OpBuilder::InsertionGuard guard(rewriter); |
| 773 | rewriter.setInsertionPoint(&moduleOp.front()); |
| 774 | auto global = rewriter.create<memref::GlobalOp>( |
| 775 | funcOp->getLoc(), "__mbarrier" , |
| 776 | /*sym_visibility=*/rewriter.getStringAttr("private" ), |
| 777 | /*type=*/barrierType, |
| 778 | /*initial_value=*/ElementsAttr(), |
| 779 | /*constant=*/false, |
| 780 | /*alignment=*/rewriter.getI64IntegerAttr(8)); |
| 781 | symbolTable.insert(global); |
| 782 | return global; |
| 783 | } |
| 784 | |
| 785 | LogicalResult |
| 786 | matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor, |
| 787 | ConversionPatternRewriter &rewriter) const override { |
| 788 | Operation *funcOp = op->getParentOp(); |
| 789 | MemRefType barrierType = nvgpu::getMBarrierMemrefType( |
| 790 | rewriter.getContext(), op.getBarriers().getType()); |
| 791 | |
| 792 | memref::GlobalOp global; |
| 793 | if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>()) |
| 794 | global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); |
| 795 | else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>()) |
| 796 | global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); |
| 797 | |
| 798 | rewriter.setInsertionPoint(op); |
| 799 | rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType, |
| 800 | global.getName()); |
| 801 | return success(); |
| 802 | } |
| 803 | }; |
| 804 | |
| 805 | /// Base class for lowering mbarrier operations to nvvm intrinsics. |
| 806 | template <typename SourceOp> |
| 807 | struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> { |
| 808 | public: |
| 809 | using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; |
| 810 | /// Returns the base pointer of the mbarrier object. |
| 811 | Value getMbarrierPtr(ImplicitLocOpBuilder &b, |
| 812 | nvgpu::MBarrierGroupType mbarType, Value memrefDesc, |
| 813 | Value mbarId, |
| 814 | ConversionPatternRewriter &rewriter) const { |
| 815 | MemRefType mbarrierMemrefType = |
| 816 | nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType); |
| 817 | return ConvertToLLVMPattern::getStridedElementPtr( |
| 818 | rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}); |
| 819 | } |
| 820 | }; |
| 821 | |
| 822 | struct NVGPUMBarrierGetLowering |
| 823 | : public MBarrierBasePattern<nvgpu::MBarrierGetOp> { |
| 824 | using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern; |
| 825 | |
| 826 | LogicalResult |
| 827 | matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor, |
| 828 | ConversionPatternRewriter &rewriter) const override { |
| 829 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 830 | nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType(); |
| 831 | rewriter.setInsertionPoint(op); |
| 832 | Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), |
| 833 | adaptor.getMbarId(), rewriter); |
| 834 | Type resType = op.getMbarrierPointer().getType(); |
| 835 | rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier); |
| 836 | return success(); |
| 837 | } |
| 838 | }; |
| 839 | |
| 840 | /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init` |
| 841 | struct NVGPUMBarrierInitLowering |
| 842 | : public MBarrierBasePattern<nvgpu::MBarrierInitOp> { |
| 843 | using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern; |
| 844 | |
| 845 | LogicalResult |
| 846 | matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor, |
| 847 | ConversionPatternRewriter &rewriter) const override { |
| 848 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 849 | nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType(); |
| 850 | rewriter.setInsertionPoint(op); |
| 851 | Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), |
| 852 | adaptor.getMbarId(), rewriter); |
| 853 | Value count = truncToI32(b, adaptor.getCount()); |
| 854 | if (isMbarrierShared(mbarrierType)) { |
| 855 | rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>( |
| 856 | op, barrier, count, adaptor.getPredicate()); |
| 857 | } else { |
| 858 | rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, |
| 859 | adaptor.getPredicate()); |
| 860 | } |
| 861 | return success(); |
| 862 | } |
| 863 | }; |
| 864 | |
| 865 | /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive` |
| 866 | struct NVGPUMBarrierArriveLowering |
| 867 | : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> { |
| 868 | using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern; |
| 869 | LogicalResult |
| 870 | matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor, |
| 871 | ConversionPatternRewriter &rewriter) const override { |
| 872 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 873 | Value barrier = |
| 874 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
| 875 | adaptor.getMbarId(), rewriter); |
| 876 | Type tokenType = getTypeConverter()->convertType( |
| 877 | nvgpu::MBarrierTokenType::get(op->getContext())); |
| 878 | if (isMbarrierShared(op.getBarriers().getType())) { |
| 879 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType, |
| 880 | barrier); |
| 881 | } else { |
| 882 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, |
| 883 | barrier); |
| 884 | } |
| 885 | return success(); |
| 886 | } |
| 887 | }; |
| 888 | |
| 889 | /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to |
| 890 | /// `nvvm.mbarrier.arrive.nocomplete` |
| 891 | struct NVGPUMBarrierArriveNoCompleteLowering |
| 892 | : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> { |
| 893 | using MBarrierBasePattern< |
| 894 | nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern; |
| 895 | LogicalResult |
| 896 | matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor, |
| 897 | ConversionPatternRewriter &rewriter) const override { |
| 898 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 899 | Value barrier = |
| 900 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
| 901 | adaptor.getMbarId(), rewriter); |
| 902 | Type tokenType = getTypeConverter()->convertType( |
| 903 | nvgpu::MBarrierTokenType::get(op->getContext())); |
| 904 | Value count = truncToI32(b, adaptor.getCount()); |
| 905 | if (isMbarrierShared(op.getBarriers().getType())) { |
| 906 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>( |
| 907 | op, tokenType, barrier, count); |
| 908 | } else { |
| 909 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( |
| 910 | op, tokenType, barrier, count); |
| 911 | } |
| 912 | return success(); |
| 913 | } |
| 914 | }; |
| 915 | |
| 916 | /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait` |
| 917 | struct NVGPUMBarrierTestWaitLowering |
| 918 | : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> { |
| 919 | using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern; |
| 920 | LogicalResult |
| 921 | matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor, |
| 922 | ConversionPatternRewriter &rewriter) const override { |
| 923 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 924 | Value barrier = |
| 925 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
| 926 | adaptor.getMbarId(), rewriter); |
| 927 | Type retType = rewriter.getI1Type(); |
| 928 | if (isMbarrierShared(op.getBarriers().getType())) { |
| 929 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>( |
| 930 | op, retType, barrier, adaptor.getToken()); |
| 931 | } else { |
| 932 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>( |
| 933 | op, retType, barrier, adaptor.getToken()); |
| 934 | } |
| 935 | return success(); |
| 936 | } |
| 937 | }; |
| 938 | |
| 939 | struct NVGPUMBarrierArriveExpectTxLowering |
| 940 | : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> { |
| 941 | using MBarrierBasePattern< |
| 942 | nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern; |
| 943 | LogicalResult |
| 944 | matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor, |
| 945 | ConversionPatternRewriter &rewriter) const override { |
| 946 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 947 | Value barrier = |
| 948 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
| 949 | adaptor.getMbarId(), rewriter); |
| 950 | Value txcount = truncToI32(b, adaptor.getTxcount()); |
| 951 | |
| 952 | if (isMbarrierShared(op.getBarriers().getType())) { |
| 953 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( |
| 954 | op, barrier, txcount, adaptor.getPredicate()); |
| 955 | return success(); |
| 956 | } |
| 957 | |
| 958 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( |
| 959 | op, barrier, txcount, adaptor.getPredicate()); |
| 960 | return success(); |
| 961 | } |
| 962 | }; |
| 963 | |
| 964 | struct NVGPUMBarrierTryWaitParityLowering |
| 965 | : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> { |
| 966 | using MBarrierBasePattern< |
| 967 | nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern; |
| 968 | LogicalResult |
| 969 | matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor, |
| 970 | ConversionPatternRewriter &rewriter) const override { |
| 971 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 972 | Value barrier = |
| 973 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
| 974 | adaptor.getMbarId(), rewriter); |
| 975 | Value ticks = truncToI32(b, adaptor.getTicks()); |
| 976 | Value phase = |
| 977 | b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity()); |
| 978 | |
| 979 | if (isMbarrierShared(op.getBarriers().getType())) { |
| 980 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( |
| 981 | op, barrier, phase, ticks); |
| 982 | return success(); |
| 983 | } |
| 984 | |
| 985 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, |
| 986 | phase, ticks); |
| 987 | return success(); |
| 988 | } |
| 989 | }; |
| 990 | |
| 991 | struct NVGPUTmaAsyncLoadOpLowering |
| 992 | : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> { |
| 993 | using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern; |
| 994 | LogicalResult |
| 995 | matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor, |
| 996 | ConversionPatternRewriter &rewriter) const override { |
| 997 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 998 | auto srcMemrefType = cast<MemRefType>(op.getDst().getType()); |
| 999 | Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType, |
| 1000 | adaptor.getDst(), {}); |
| 1001 | Value barrier = |
| 1002 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
| 1003 | adaptor.getMbarId(), rewriter); |
| 1004 | |
| 1005 | SmallVector<Value> coords = adaptor.getCoordinates(); |
| 1006 | for (auto [index, value] : llvm::enumerate(coords)) { |
| 1007 | coords[index] = truncToI32(b, value); |
| 1008 | } |
| 1009 | rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>( |
| 1010 | op, dest, adaptor.getTensorMapDescriptor(), coords, barrier, |
| 1011 | ValueRange{}, adaptor.getMulticastMask(), Value{}, |
| 1012 | adaptor.getPredicate()); |
| 1013 | return success(); |
| 1014 | } |
| 1015 | }; |
| 1016 | |
| 1017 | struct NVGPUTmaAsyncStoreOpLowering |
| 1018 | : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> { |
| 1019 | using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern; |
| 1020 | LogicalResult |
| 1021 | matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor, |
| 1022 | ConversionPatternRewriter &rewriter) const override { |
| 1023 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 1024 | auto srcMemrefType = cast<MemRefType>(op.getSrc().getType()); |
| 1025 | Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType, |
| 1026 | adaptor.getSrc(), {}); |
| 1027 | SmallVector<Value> coords = adaptor.getCoordinates(); |
| 1028 | for (auto [index, value] : llvm::enumerate(coords)) { |
| 1029 | coords[index] = truncToI32(b, value); |
| 1030 | } |
| 1031 | |
| 1032 | rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>( |
| 1033 | op, adaptor.getTensorMapDescriptor(), dest, coords, |
| 1034 | adaptor.getPredicate()); |
| 1035 | return success(); |
| 1036 | } |
| 1037 | }; |
| 1038 | |
| 1039 | struct NVGPUGenerateWarpgroupDescriptorLowering |
| 1040 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> { |
| 1041 | using ConvertOpToLLVMPattern< |
| 1042 | nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern; |
| 1043 | |
| 1044 | LogicalResult |
| 1045 | matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor, |
| 1046 | ConversionPatternRewriter &rewriter) const override { |
| 1047 | |
| 1048 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 1049 | |
| 1050 | nvgpu::TensorMapSwizzleKind swizzleKind = |
| 1051 | op.getTensorMap().getType().getSwizzle(); |
| 1052 | |
| 1053 | unsigned layout = |
| 1054 | (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128 |
| 1055 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64 |
| 1056 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32 |
| 1057 | : 1; |
| 1058 | unsigned swizzle = |
| 1059 | (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1 |
| 1060 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2 |
| 1061 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3 |
| 1062 | : 0; |
| 1063 | |
| 1064 | auto ti64 = b.getIntegerType(64); |
| 1065 | auto makeConst = [&](uint64_t index) -> Value { |
| 1066 | return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index)); |
| 1067 | }; |
| 1068 | auto shiftLeft = [&](Value value, unsigned shift) -> Value { |
| 1069 | return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift)); |
| 1070 | }; |
| 1071 | auto shiftRight = [&](Value value, unsigned shift) -> Value { |
| 1072 | return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift)); |
| 1073 | }; |
| 1074 | auto insertBit = [&](Value desc, Value val, int startBit) { |
| 1075 | return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit)); |
| 1076 | }; |
| 1077 | |
| 1078 | int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); |
| 1079 | uint64_t strideDimVal = (layout << 3) >> exclude4LSB; |
| 1080 | uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB; |
| 1081 | uint64_t offsetVal = 0; |
| 1082 | |
| 1083 | Value strideDim = makeConst(strideDimVal); |
| 1084 | Value leadDim = makeConst(leadDimVal); |
| 1085 | |
| 1086 | Value baseAddr = getStridedElementPtr( |
| 1087 | rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()), |
| 1088 | adaptor.getTensor(), {}); |
| 1089 | Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr); |
| 1090 | // Use only 14 bits of the base address (bits [4,18), excluding the 4 LSBs). |
| 1091 | Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); |
| 1092 | |
| 1093 | int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32, |
| 1094 | startLeadBit = 16, startBaseAddrBit = 0; |
| 1095 | Value dsc = makeConst(0); |
| 1096 | // [62,64) swizzle type |
| 1097 | dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit); |
| 1098 | // [49,52) base_offset |
| 1099 | dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit); |
| 1100 | // [32,46) stride |
| 1101 | dsc = insertBit(dsc, strideDim, startStrideBit); |
| 1102 | // [16,30) leading dimension |
| 1103 | dsc = insertBit(dsc, leadDim, startLeadBit); |
| 1104 | // [0,14) start_address |
| 1105 | dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); |
| 1106 | |
| 1107 | LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " |
| 1108 | << "leading_off:" << leadDimVal << "\t" |
| 1109 | << "stride_off :" << strideDimVal << "\t" |
| 1110 | << "base_offset:" << offsetVal << "\t" |
| 1111 | << "layout_type:" << swizzle << " (" |
| 1112 | << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) |
| 1113 | << ")\n start_addr : " << baseAddr << "\n" ); |
| 1114 | |
| 1115 | rewriter.replaceOp(op, dsc); |
| 1116 | return success(); |
| 1117 | } |
| 1118 | }; |
| 1119 | |
| 1120 | static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { |
| 1121 | return b.create<LLVM::ConstantOp>(b.getIntegerType(64), |
| 1122 | b.getI32IntegerAttr(index)); |
| 1123 | } |
| 1124 | |
| 1125 | /// Returns a Value holding the data type enum expected by the CUDA driver. |
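|  | /// For example (illustrative): an f16 element type yields a constant holding |
|  | /// CU_TENSOR_MAP_DATA_TYPE_FLOAT16 (6). |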
| 1126 | static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) { |
| 1127 | // Enum is from CUDA driver API |
| 1128 | // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html |
| 1129 | enum CUtensorMapDataTypeEnum { |
| 1130 | CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, |
| 1131 | CU_TENSOR_MAP_DATA_TYPE_UINT16, |
| 1132 | CU_TENSOR_MAP_DATA_TYPE_UINT32, |
| 1133 | CU_TENSOR_MAP_DATA_TYPE_INT32, |
| 1134 | CU_TENSOR_MAP_DATA_TYPE_UINT64, |
| 1135 | CU_TENSOR_MAP_DATA_TYPE_INT64, |
| 1136 | CU_TENSOR_MAP_DATA_TYPE_FLOAT16, |
| 1137 | CU_TENSOR_MAP_DATA_TYPE_FLOAT32, |
| 1138 | CU_TENSOR_MAP_DATA_TYPE_FLOAT64, |
| 1139 | CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, |
| 1140 | CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, |
| 1141 | CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, |
| 1142 | CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ |
| 1143 | }; |
| 1144 | |
| 1145 | if (type.isUnsignedInteger(8)) |
| 1146 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8); |
| 1147 | if (type.isUnsignedInteger(16)) |
| 1148 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16); |
| 1149 | if (type.isUnsignedInteger(32)) |
| 1150 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32); |
| 1151 | if (type.isUnsignedInteger(64)) |
| 1152 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64); |
| 1153 | if (type.isSignlessInteger(32)) |
| 1154 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32); |
| 1155 | if (type.isSignlessInteger(64)) |
| 1156 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64); |
| 1157 | if (type.isF16()) |
| 1158 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16); |
| 1159 | if (type.isF32()) |
| 1160 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32); |
| 1161 | if (type.isF64()) |
| 1162 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64); |
| 1163 | if (type.isBF16()) |
| 1164 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16); |
| 1165 | |
| 1166 | llvm_unreachable("Not supported data type" ); |
| 1167 | } |
| 1168 | |
| 1169 | struct NVGPUTmaCreateDescriptorOpLowering |
| 1170 | : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> { |
| 1171 | using ConvertOpToLLVMPattern< |
| 1172 | nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern; |
| 1173 | LogicalResult |
| 1174 | matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor, |
| 1175 | ConversionPatternRewriter &rewriter) const override { |
| 1176 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 1177 | auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext()); |
| 1178 | Type llvmInt64Type = IntegerType::get(op->getContext(), 64); |
| 1179 | |
| 1180 | Value tensorElementType = |
| 1181 | elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType()); |
| 1182 | auto promotedOperands = getTypeConverter()->promoteOperands( |
| 1183 | b.getLoc(), op->getOperands(), adaptor.getOperands(), b); |
| 1184 | |
| 1185 | Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type, |
| 1186 | makeI64Const(b, 5)); |
| 1187 | for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { |
| 1188 | Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType, |
| 1189 | boxArrayPtr, makeI64Const(b, index)); |
| 1190 | b.create<LLVM::StoreOp>(value, gep); |
| 1191 | } |
| 1192 | |
| 1193 | nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); |
| 1194 | // Set the arguments for the function call. |
| 1195 | SmallVector<Value> arguments; |
| 1196 | arguments.push_back(promotedOperands[0]); // rank |
| 1197 | arguments.push_back(promotedOperands[1]); // descriptor |
| 1198 | arguments.push_back(tensorElementType); // data type |
| 1199 | arguments.push_back( |
| 1200 | makeI64Const(b, (int)desc.getInterleave())); // interleave |
| 1201 | arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle |
| 1202 | arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo |
| 1203 | arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob |
| 1204 | arguments.push_back(boxArrayPtr); // box dimensions |
| 1205 | |
| 1206 | // Set data types of the arguments |
| 1207 | SmallVector<Type> argTypes = { |
| 1208 | llvmInt64Type, /* int64_t tensorRank */ |
| 1209 | llvmPointerType, /* ptr */ |
| 1210 | llvmInt64Type, /* int64_t */ |
| 1211 | llvmInt64Type, /* int64_t */ |
| 1212 | llvmInt64Type, /* int64_t */ |
| 1213 | llvmInt64Type, /* int64_t */ |
| 1214 | llvmInt64Type, /* int64_t */ |
| 1215 | llvmPointerType /* ptr */ |
| 1216 | }; |
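// The callee is a host-side helper expected to come from the MLIR CUDA
// runtime wrappers; it encodes the arguments into a CUtensorMap (via
// cuTensorMapEncodeTiled) and returns a pointer to it.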
FunctionCallBuilder hostRegisterCallBuilder = {
    "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
| 1219 | Value tensorMap = |
| 1220 | hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult(); |
| 1221 | |
| 1222 | rewriter.replaceOp(op, tensorMap); |
| 1223 | return success(); |
| 1224 | } |
| 1225 | }; |
| 1226 | |
| 1227 | struct NVGPUWarpgroupMmaOpLowering |
| 1228 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> { |
| 1229 | using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern; |
| 1230 | |
| 1231 | /// This is a helper class to generate required NVVM Ops for warp-group level |
| 1232 | /// matrix multiplication. |
/// When the given GEMM shape is larger than the shape of a single wgmma
/// instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp Ops,
/// groups them, and executes them asynchronously. The class also handles
/// waiting for completion and iterates through the
/// WarpgroupMatrixDescriptor to create a descriptor for each instruction.
| 1238 | /// |
/// For example, this is the sequence generated when the GEMM shape is
/// 128x128x128:
| 1240 | /// |
| 1241 | /// nvvm.wgmma.fence.aligned |
| 1242 | /// |
| 1243 | /// nvvm.wgmma.mma.async descA, descB |
| 1244 | /// iterate(descA, descB) |
| 1245 | /// nvvm.wgmma.mma.async descA, descB |
/// [6 more times]
| 1247 | /// |
| 1248 | /// nvvm.wgmma.group.sync.aligned |
| 1249 | /// nvvm.wgmma.wait.group.sync [groupId] |
| 1250 | /// |
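/// For instance, with an f8 (e4m3/e5m2) element type the wgmma tile picked
/// below is m64n128k32, so iterationM = 2, iterationN = 1, iterationK = 4,
/// and a total of 8 WgmmaMmaAsyncOps are generated (the sequence sketched
/// above); with f16 inputs the tile is m64n128k16 and 16 ops are generated.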
| 1251 | class WarpgroupGemm { |
| 1252 | nvgpu::WarpgroupMmaOp op; |
| 1253 | ImplicitLocOpBuilder b; |
| 1254 | OpAdaptor adaptor; |
| 1255 | |
| 1256 | // Entire shape of the given Op |
| 1257 | int64_t totalM, totalN, totalK; |
| 1258 | |
| 1259 | // Shape of one wgmma instruction |
| 1260 | int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0; |
| 1261 | |
| 1262 | // Iteration counts for GEMM |
| 1263 | int iterationM = 0, iterationN = 0, iterationK = 0; |
| 1264 | |
/// Determines the shape of a single wgmma instruction, as defined in the
/// PTX programming guide, and stores it in wgmmaM/wgmmaN/wgmmaK.
| 1267 | /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape |
| 1268 | void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) { |
| 1269 | wgmmaM = 64; |
| 1270 | wgmmaN = sizeN; |
| 1271 | if (inputElemType.isTF32()) { |
| 1272 | wgmmaK = 8; |
| 1273 | } else if (inputElemType.isF16() || inputElemType.isBF16()) { |
| 1274 | wgmmaK = 16; |
} else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
           inputElemType.isInteger(8)) {
| 1277 | wgmmaK = 32; |
} else if (inputElemType.isInteger(1)) {
| 1279 | wgmmaK = 256; |
| 1280 | } else { |
llvm_unreachable("unsupported K shape");
| 1282 | } |
| 1283 | LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM |
| 1284 | << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n" ); |
| 1285 | } |
| 1286 | |
| 1287 | /// Generates WGMMATypesAttr from MLIR Type |
| 1288 | NVVM::WGMMATypesAttr generateWgmmaType(Type type, |
| 1289 | bool useF32 = false) const { |
| 1290 | auto getWgmmaType = [=](Type elemType) { |
| 1291 | if (elemType.isF32() || elemType.isTF32()) |
| 1292 | return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32; |
| 1293 | if (elemType.isF16()) |
| 1294 | return NVVM::WGMMATypes::f16; |
| 1295 | if (elemType.isBF16()) |
| 1296 | return NVVM::WGMMATypes::bf16; |
| 1297 | if (isa<Float8E4M3FNType>(elemType)) |
| 1298 | return NVVM::WGMMATypes::e4m3; |
| 1299 | if (isa<Float8E5M2Type>(elemType)) |
| 1300 | return NVVM::WGMMATypes::e5m2; |
| 1301 | if (elemType.isInteger(1)) |
| 1302 | return NVVM::WGMMATypes::b1; |
| 1303 | if (elemType.isInteger(8)) |
| 1304 | return NVVM::WGMMATypes::s8; |
| 1305 | if (elemType.isUnsignedInteger(8)) |
| 1306 | return NVVM::WGMMATypes::u8; |
| 1307 | if (elemType.isInteger(32)) |
| 1308 | return NVVM::WGMMATypes::s32; |
| 1309 | llvm_unreachable("unsupported type" ); |
| 1310 | }; |
| 1311 | return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type)); |
| 1312 | } |
| 1313 | |
| 1314 | /// Generates layout attribute for the input matrix for wgmma instruction |
| 1315 | NVVM::MMALayoutAttr |
| 1316 | generateWgmmaLayout(std::optional<bool> transpose) const { |
| 1317 | if (transpose.value_or(false)) |
| 1318 | return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col); |
| 1319 | return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row); |
| 1320 | } |
| 1321 | |
| 1322 | /// Generates shape attribute for wgmma instruction |
| 1323 | NVVM::MMAShapeAttr generateWgmmaShape() const { |
| 1324 | return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK); |
| 1325 | } |
| 1326 | |
| 1327 | /// Generates scale attributes of output matrix for wgmma instruction |
| 1328 | NVVM::WGMMAScaleOutAttr generateScaleOut() const { |
| 1329 | return NVVM::WGMMAScaleOutAttr::get(op->getContext(), |
| 1330 | NVVM::WGMMAScaleOut::one); |
| 1331 | } |
| 1332 | /// Generates scale attributes of input matrix for wgmma instruction |
| 1333 | NVVM::WGMMAScaleInAttr generateScaleIn() const { |
| 1334 | return NVVM::WGMMAScaleInAttr::get(op->getContext(), |
| 1335 | NVVM::WGMMAScaleIn::one); |
| 1336 | } |
| 1337 | |
| 1338 | /// Basic function to generate Add |
| 1339 | Value makeAdd(Value lhs, Value rhs) { |
| 1340 | return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); |
| 1341 | }; |
| 1342 | |
| 1343 | /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. |
| 1344 | /// Currently, it only handles row-major. |
| 1345 | /// |
| 1346 | /// It moves the pointer like below for [128][64] size: |
| 1347 | /// +2 +4 +6 |
| 1348 | /// ↓ ↓ ↓ |
| 1349 | /// descA ---> +--+--+--+--+ |
| 1350 | /// |->|->|->|->| |
| 1351 | /// | | | | | |
| 1352 | /// | | | | | |
| 1353 | /// | | | | | |
| 1354 | /// descA+512---> +-----------+ |
| 1355 | /// | | | | | |
| 1356 | /// | | | | | |
| 1357 | /// | | | | | |
| 1358 | /// | | | | | |
| 1359 | /// +-----------+ |
| 1360 | /// |
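/// For example, with f16 elements (2 bytes) and wgmmaK = 16, one step in k
/// advances the descriptor by 16 * 2 = 32 bytes; since the 4 LSBs are
/// excluded from wgmma descriptors, the encoded increment is 32 >> 4 = 2,
/// which matches the +2/+4/+6 offsets shown above.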
| 1361 | Value iterateDescriptorA(Value desc, int i, int j, int k) { |
| 1362 | MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor(); |
| 1363 | Type elemA = matrixTypeA.getElementType(); |
| 1364 | int byte = elemA.getIntOrFloatBitWidth() / 8; |
| 1365 | int tileShapeA = matrixTypeA.getDimSize(1); |
| 1366 | int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; |
| 1367 | incrementVal = incrementVal >> exclude4LSB; |
| 1368 | LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k |
| 1369 | << "] [wgmma descriptors] Descriptor A + " |
| 1370 | << incrementVal << " | \t " ); |
| 1371 | if (!incrementVal) |
| 1372 | return desc; |
return makeAdd(desc, makeI64Const(b, incrementVal));
| 1374 | } |
| 1375 | |
| 1376 | /// Moves the descriptor pointer of matrix-B for the next wgmma instruction. |
| 1377 | /// Currently, it only handles column-major. |
| 1378 | /// |
| 1379 | /// It moves the pointer like below for [128][64] size: |
| 1380 | /// descB ---> +--+--+--+--+--+--+--+--+ |
| 1381 | /// |↓ | | | | | | | | |
| 1382 | /// |↓ | | | | | | | | |
| 1383 | /// |↓ | | | | | | | | |
| 1384 | /// |↓ | | | | | | | | |
| 1385 | /// +--+--+--+--+--+--+--+--+ |
| 1386 | /// |
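/// For example, for the f16 [128][64] matrix-B above and wgmmaK = 16, one
/// step in k advances the descriptor by 128 * 16 * 2 = 4096 bytes, encoded
/// as 4096 >> 4 = 256.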
| 1387 | Value iterateDescriptorB(Value desc, int i, int j, int k) { |
| 1388 | MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor(); |
| 1389 | Type elemB = matrixTypeB.getElementType(); |
| 1390 | int byte = elemB.getIntOrFloatBitWidth() / 8; |
| 1391 | int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; |
| 1392 | incrementVal = incrementVal >> exclude4LSB; |
| 1393 | LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n" ); |
| 1394 | if (!incrementVal) |
| 1395 | return desc; |
return makeAdd(desc, makeI64Const(b, incrementVal));
| 1397 | } |
| 1398 | |
/// Generates a WgmmaMmaAsyncOp using the provided GMMA matrix descriptors,
/// advancing them according to the induction variables i, j, and k.
| 1401 | Value generateWgmma(int i, int j, int k, Value matrixC) { |
| 1402 | LLVM_DEBUG(DBGS() << "\t wgmma." |
| 1403 | << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK |
| 1404 | << "(A[" << (iterationM * wgmmaM) << ":" |
| 1405 | << (iterationM * wgmmaM) + wgmmaM << "][" |
| 1406 | << (iterationK * wgmmaK) << ":" |
| 1407 | << (iterationK * wgmmaK + wgmmaK) << "] * " |
| 1408 | << " B[" << (iterationK * wgmmaK) << ":" |
| 1409 | << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" |
| 1410 | << wgmmaN << "])\n" ); |
| 1411 | |
| 1412 | Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); |
| 1413 | Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); |
| 1414 | |
| 1415 | Type elemA = op.getDescriptorA().getType().getTensor().getElementType(); |
| 1416 | NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA); |
| 1417 | |
| 1418 | Type elemB = op.getDescriptorB().getType().getTensor().getElementType(); |
| 1419 | NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB); |
| 1420 | |
| 1421 | Type elemD = op.getMatrixC().getType().getFragmented().getElementType(); |
| 1422 | NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true); |
| 1423 | |
| 1424 | NVVM::MMAShapeAttr shape = generateWgmmaShape(); |
| 1425 | NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut(); |
| 1426 | NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn(); |
| 1427 | NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA()); |
| 1428 | NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB()); |
| 1429 | |
| 1430 | auto overflow = NVVM::MMAIntOverflowAttr::get( |
| 1431 | op->getContext(), NVVM::MMAIntOverflow::wrapped); |
| 1432 | |
| 1433 | return b.create<NVVM::WgmmaMmaAsyncOp>( |
| 1434 | matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, |
| 1435 | itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, |
| 1436 | overflow); |
| 1437 | } |
| 1438 | |
| 1439 | /// Generates multiple wgmma instructions to complete the given GEMM shape |
| 1440 | Value generateWgmmaGroup() { |
| 1441 | Value wgmmaResult = |
| 1442 | b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType()); |
| 1443 | |
| 1444 | // Perform GEMM |
| 1445 | SmallVector<Value> wgmmaResults; |
| 1446 | for (int i = 0; i < iterationM; ++i) { |
| 1447 | Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i); |
| 1448 | for (int j = 0; j < iterationN; ++j) |
| 1449 | for (int k = 0; k < iterationK; ++k) |
| 1450 | matrixC = generateWgmma(i, j, k, matrixC); |
wgmmaResults.push_back(matrixC);
| 1452 | } |
for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
| 1454 | wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(), |
| 1455 | wgmmaResult, matrix, idx); |
| 1456 | } |
| 1457 | return wgmmaResult; |
| 1458 | } |
| 1459 | |
| 1460 | public: |
| 1461 | WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b, |
| 1462 | OpAdaptor adaptor) |
| 1463 | : op(op), b(b), adaptor(adaptor) { |
| 1464 | // Find the entire GEMM Shape |
| 1465 | totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); |
| 1466 | totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); |
| 1467 | totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); |
| 1468 | LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN |
| 1469 | << "] += A[" << totalM << "][" << totalK << "] * B[" |
| 1470 | << totalK << "][" << totalN << "] ---===\n" ); |
| 1471 | |
| 1472 | // Find the shape for one wgmma instruction |
findWgmmaShape(
    totalM, totalN,
    op.getDescriptorA().getType().getTensor().getElementType());
| 1476 | |
// Iteration counts needed to cover the entire GEMM shape with the wgmma shape
| 1478 | iterationM = totalM / wgmmaM; |
| 1479 | iterationN = totalN / wgmmaN; |
| 1480 | iterationK = totalK / wgmmaK; |
| 1481 | } |
| 1482 | |
/// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
/// emits a fence (WgmmaFenceAlignedOp) before the instructions, and after
/// them it commits the group (WgmmaGroupSyncAlignedOp) and waits for its
/// completion (WgmmaWaitGroupSyncOp).
| 1488 | Value generateWarpgroupMma() { |
| 1489 | b.create<NVVM::WgmmaFenceAlignedOp>(); |
| 1490 | Value wgmmaResult = generateWgmmaGroup(); |
| 1491 | b.create<NVVM::WgmmaGroupSyncAlignedOp>(); |
| 1492 | b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup()); |
| 1493 | return wgmmaResult; |
| 1494 | } |
| 1495 | }; |
| 1496 | LogicalResult |
| 1497 | matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, |
| 1498 | ConversionPatternRewriter &rewriter) const override { |
| 1499 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 1500 | |
| 1501 | // Step 1. Build a helper class |
| 1502 | WarpgroupGemm warpgroupGemm(op, b, adaptor); |
| 1503 | |
// Step 2. Generate the wgmma ops (with fence and wait) covering the entire GEMM shape
| 1505 | Value wgmmaResult = warpgroupGemm.generateWarpgroupMma(); |
| 1506 | |
| 1507 | // Step 3. Replace fragmented result struct with the op results |
| 1508 | rewriter.replaceOp(op, wgmmaResult); |
| 1509 | return success(); |
| 1510 | } |
| 1511 | }; |
| 1512 | |
| 1513 | struct NVGPUWarpgroupMmaStoreOpLowering |
| 1514 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> { |
| 1515 | using ConvertOpToLLVMPattern< |
| 1516 | nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern; |
| 1517 | |
/// This function stores a fragmented register matrix owned by a warp group
/// (128 threads) into a memref. Each thread holds 64 registers that together
/// form a single struct value.
/// Here is what each thread (T) holds; each `d` is a struct element,
/// numbered by its index.
///
/// Threads in the warp-group (128 threads) and what they own in matrix-D:
| 1525 | /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N] |
| 1526 | /// 32-63 Warp-1 -> MatrixD[16:31][0:N] |
| 1527 | /// 64-95 Warp-2 -> MatrixD[32:47][0:N] |
/// 96-127 Warp-3 -> MatrixD[48:63][0:N]
| 1529 | /// |
| 1530 | /// Matrix-D: |
| 1531 | /// +______________________________________________________________________+ |
| 1532 | /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 | |
| 1533 | /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY| |
| 1534 | /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY| |
| 1535 | /// ..| .........|.........|.........|.........|........|...........|........| |
| 1536 | /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW| |
| 1537 | /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW| |
| 1538 | /// ..| .........|.........|.........|.........|........|...........|........| |
| 1539 | /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........| |
| 1540 | /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........| |
| 1541 | /// ..| .........|.........|.........|.........|........|...........|........| |
| 1542 | /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........| |
| 1543 | /// ..| .........|.........|.........|.........|........|...........|........| |
| 1544 | /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........| |
| 1545 | /// ..| .........|.........|.........|.........|........|...........|........| |
| 1546 | /// +______________________________________________________________________+ |
| 1547 | /// |
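/// For example, lane 5 of warp 0 computes lane4Id = 1 and lane4modId = 1,
/// so it starts at row ti = 1, column tj = 2 and stores d0/d1 to
/// MatrixD[1][2] and MatrixD[1][3], matching the T5:d0-d1 entry above.
///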
/// \param b: The ImplicitLocOpBuilder used to emit the ops.
/// \param matrixD: Result of the warp-group MMA operation (fragmented
/// matrix). Each thread holds it as a struct with 64 elements.
/// \param dstMemref: The memref where the registers will be stored.
/// \param offset: The row offset within the memref where the registers will
/// be stored.
| 1554 | void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD, |
| 1555 | TypedValue<MemRefType> dstMemref, |
| 1556 | int offset) const { |
| 1557 | Type i32 = b.getI32Type(); |
| 1558 | |
| 1559 | auto makeConst = [&](int32_t index) -> Value { |
| 1560 | return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index)); |
| 1561 | }; |
| 1562 | Value c1 = makeConst(1); |
| 1563 | Value c2 = makeConst(2); |
| 1564 | Value c4 = makeConst(4); |
| 1565 | Value c8 = makeConst(8); |
| 1566 | Value c16 = makeConst(16); |
| 1567 | Value warpSize = makeConst(kWarpSize); |
| 1568 | |
| 1569 | auto makeMul = [&](Value lhs, Value rhs) -> Value { |
| 1570 | return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs); |
| 1571 | }; |
| 1572 | auto makeAdd = [&](Value lhs, Value rhs) -> Value { |
| 1573 | return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); |
| 1574 | }; |
| 1575 | |
| 1576 | auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, |
| 1577 | TypedValue<::mlir::MemRefType> memref) { |
| 1578 | Type it = b.getIndexType(); |
| 1579 | Value idx = b.create<arith::IndexCastOp>(it, x); |
| 1580 | Value idy0 = b.create<arith::IndexCastOp>(it, y); |
| 1581 | Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1)); |
| 1582 | Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i); |
| 1583 | Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1); |
| 1584 | b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0}); |
| 1585 | b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1}); |
| 1586 | }; |
| 1587 | |
| 1588 | Value tidx = b.create<NVVM::ThreadIdXOp>(i32); |
| 1589 | Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize); |
| 1590 | Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize); |
| 1591 | Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4); |
| 1592 | Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4); |
| 1593 | |
| 1594 | Value tj = makeMul(lane4modId, c2); |
| 1595 | Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); |
| 1596 | if (offset) |
| 1597 | ti = makeAdd(ti, makeConst(offset)); |
| 1598 | |
| 1599 | auto structType = cast<LLVM::LLVMStructType>(matrixD.getType()); |
| 1600 | |
// Number of adjacent 32-bit registers owned by each thread
constexpr unsigned numAdjacentRegisters = 2;
// Number of 8x8 matrices stacked one below another per warp
constexpr unsigned numStackedMatrices = 2;
| 1605 | |
| 1606 | size_t storeCount = (structType.getBody().size() / |
| 1607 | (numStackedMatrices * numAdjacentRegisters)); |
| 1608 | |
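// For example, with i = 0 and j = 1 the struct index is 4, so each thread's
// registers d4/d5 are stored at row ti, columns tj + 8 and tj + 9 (the
// T0:d4-d5 entries in the diagram above).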
| 1609 | for (size_t i = 0; i < numStackedMatrices; ++i) { |
| 1610 | Value idx = makeAdd(ti, makeMul(makeConst(i), c8)); |
| 1611 | for (size_t j = 0; j < storeCount; ++j) { |
| 1612 | Value idy = makeAdd(tj, makeMul(makeConst(j), c8)); |
| 1613 | size_t structIndex = (i * numAdjacentRegisters) + |
| 1614 | (j * (numStackedMatrices * numAdjacentRegisters)); |
| 1615 | makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref); |
| 1616 | } |
| 1617 | } |
| 1618 | } |
| 1619 | |
| 1620 | LogicalResult |
| 1621 | matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor, |
| 1622 | ConversionPatternRewriter &rewriter) const override { |
| 1623 | int offset = 0; |
| 1624 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
Value matrixDValue = adaptor.getMatrixD();
auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
  auto structType = cast<LLVM::LLVMStructType>(matrixD);
  Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
| 1630 | storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); |
| 1631 | offset += structType.getBody().size(); |
| 1632 | } |
rewriter.eraseOp(op);
| 1634 | return success(); |
| 1635 | } |
| 1636 | }; |
| 1637 | |
| 1638 | struct NVGPUWarpgroupMmaInitAccumulatorOpLowering |
| 1639 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> { |
| 1640 | using ConvertOpToLLVMPattern< |
| 1641 | nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern; |
| 1642 | LogicalResult |
| 1643 | matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor, |
| 1644 | ConversionPatternRewriter &rewriter) const override { |
| 1645 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 1646 | LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>( |
| 1647 | getTypeConverter()->convertType(op.getMatrixC().getType())); |
| 1648 | Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front()) |
| 1649 | .getBody() |
| 1650 | .front(); |
| 1651 | Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType)); |
| 1652 | Value packStruct = b.create<LLVM::PoisonOp>(packStructType); |
| 1653 | SmallVector<Value> innerStructs; |
| 1654 | // Unpack the structs and set all values to zero |
| 1655 | for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { |
| 1656 | auto structType = cast<LLVM::LLVMStructType>(s); |
| 1657 | Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx); |
| 1658 | for (unsigned i = 0; i < structType.getBody().size(); ++i) { |
| 1659 | structValue = b.create<LLVM::InsertValueOp>( |
| 1660 | structType, structValue, zero, ArrayRef<int64_t>({i})); |
| 1661 | } |
| 1662 | innerStructs.push_back(structValue); |
| 1663 | } |
| 1664 | // Pack the inner structs into a single struct |
for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
| 1666 | packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(), |
| 1667 | packStruct, matrix, idx); |
| 1668 | } |
| 1669 | rewriter.replaceOp(op, packStruct); |
| 1670 | return success(); |
| 1671 | } |
| 1672 | }; |
| 1673 | |
| 1674 | struct NVGPUTmaFenceOpLowering |
| 1675 | : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> { |
| 1676 | using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern; |
| 1677 | LogicalResult |
| 1678 | matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor, |
| 1679 | ConversionPatternRewriter &rewriter) const override { |
| 1680 | MLIRContext *ctx = op.getContext(); |
| 1681 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 1682 | auto i32Ty = b.getI32Type(); |
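// A TMA descriptor (CUtensorMap) is an opaque 128-byte object; the proxy
// fence must cover the entire descriptor.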
| 1683 | Value tensormapSize = |
| 1684 | b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128)); |
| 1685 | |
| 1686 | auto memscope = |
| 1687 | NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS); |
| 1688 | |
| 1689 | rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>( |
| 1690 | op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize); |
| 1691 | |
| 1692 | return success(); |
| 1693 | } |
| 1694 | }; |
| 1695 | |
| 1696 | struct NVGPUTmaPrefetchOpLowering |
| 1697 | : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> { |
| 1698 | using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern; |
| 1699 | LogicalResult |
| 1700 | matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor, |
| 1701 | ConversionPatternRewriter &rewriter) const override { |
| 1702 | rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>( |
| 1703 | op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate()); |
| 1704 | return success(); |
| 1705 | } |
| 1706 | }; |
| 1707 | |
| 1708 | struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> { |
| 1709 | using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern; |
| 1710 | LogicalResult |
| 1711 | matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor, |
| 1712 | ConversionPatternRewriter &rewriter) const override { |
| 1713 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
| 1714 | auto i64Ty = b.getI64Type(); |
| 1715 | auto f32Ty = b.getF32Type(); |
| 1716 | VectorType inTy = op.getIn().getType(); |
// Apply rcp.approx.ftz.f32 to each element of the vector.
| 1718 | auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) { |
| 1719 | Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy); |
| 1720 | int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements(); |
| 1721 | for (int i = 0; i < numElems; i++) { |
| 1722 | Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i)); |
| 1723 | Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx); |
| 1724 | Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem); |
| 1725 | ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx); |
| 1726 | } |
| 1727 | return ret1DVec; |
| 1728 | }; |
| 1729 | if (inTy.getRank() == 1) { |
| 1730 | rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn())); |
| 1731 | return success(); |
| 1732 | } |
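// For higher-rank vectors, unroll into 1-D LLVM vectors and convert each
// one individually via the common helper below.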
return LLVM::detail::handleMultidimensionalVectors(
    op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
    [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
      OpAdaptor adaptor(operands);
      return convert1DVec(llvm1DVectorTy, adaptor.getIn());
    },
    rewriter);
| 1740 | } |
| 1741 | }; |
| 1742 | } // namespace |
| 1743 | |
| 1744 | void mlir::populateNVGPUToNVVMConversionPatterns( |
| 1745 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
| 1746 | patterns.add< |
| 1747 | NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create |
| 1748 | NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init |
| 1749 | NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get |
| 1750 | NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive |
| 1751 | NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete |
| 1752 | NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity |
| 1753 | NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity |
| 1754 | NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load |
| 1755 | NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store |
| 1756 | NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor |
| 1757 | NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor |
| 1758 | NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor |
| 1759 | NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx |
| 1760 | NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor |
| 1761 | NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma |
| 1762 | NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store |
| 1763 | NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator |
| 1764 | MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, |
| 1765 | NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, |
NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
| 1767 | } |
| 1768 | |