//===-- CUFGPUToLLVMConversion.cpp ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Support/Fortran.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

namespace fir {
#define GEN_PASS_DEF_CUFGPUTOLLVMCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace fir;
using namespace mlir;
using namespace Fortran::runtime;

namespace {

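/// Pack the kernel launch operands into a stack-allocated argument array, in
/// the form expected by the CUF runtime launch entry points: the operands are
/// first stored into a literal struct, then the address of each struct member
/// is collected into an array of `!llvm.ptr` that is returned to the caller.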
static mlir::Value createKernelArgArray(mlir::Location loc,
                                        mlir::ValueRange operands,
                                        mlir::PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  llvm::SmallVector<mlir::Type> structTypes(operands.size(), nullptr);

  for (auto [i, arg] : llvm::enumerate(operands))
    structTypes[i] = arg.getType();

  auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes);
  auto ptrTy = mlir::LLVM::LLVMPointerType::get(ctx);
  mlir::Type i32Ty = rewriter.getI32Type();
  auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));
  auto one = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1));
  mlir::Value argStruct =
      rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, structTy, one);
  auto size = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, structTypes.size()));
  mlir::Value argArray =
      rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, ptrTy, size);

  for (auto [i, arg] : llvm::enumerate(operands)) {
    auto index = rewriter.create<mlir::LLVM::ConstantOp>(
        loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i));
    mlir::Value structMember = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, structTy, argStruct,
        mlir::ArrayRef<mlir::Value>({zero, index}));
    rewriter.create<LLVM::StoreOp>(loc, arg, structMember);
    mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({index}));
    rewriter.create<LLVM::StoreOp>(loc, structMember, arrayMember);
  }
  return argArray;
}

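/// Rewrite `gpu.launch_func` operations that carry a CUDA Fortran proc
/// attribute into calls to the CUF runtime (`CUFLaunchKernel`,
/// `CUFLaunchCooperativeKernel` for grid_global kernels, or
/// `CUFLaunchClusterKernel` when a cluster size is present), forwarding the
/// grid and block dimensions, the stream, the dynamic shared-memory size, and
/// the packed argument array. As a sketch, a plain launch becomes roughly:
///   llvm.call @_FortranACUFLaunchKernel(%kernel, %gx, %gy, %gz, %bx, %by,
///       %bz, %stream, %dynSharedMem, %args, %null)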
struct GPULaunchKernelConversion
    : public mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp> {
  explicit GPULaunchKernelConversion(
      const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp>(typeConverter,
                                                              benefit) {}

  using OpAdaptor = typename mlir::gpu::LaunchFuncOp::Adaptor;

  mlir::LogicalResult
  matchAndRewrite(mlir::gpu::LaunchFuncOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Only convert gpu.launch_func for CUDA Fortran.
    if (!op.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
            cuf::getProcAttrName()))
      return mlir::failure();

    mlir::Location loc = op.getLoc();
    auto *ctx = rewriter.getContext();
    mlir::ModuleOp mod = op->getParentOfType<mlir::ModuleOp>();
    mlir::Value dynamicMemorySize = op.getDynamicSharedMemorySize();
    mlir::Type i32Ty = rewriter.getI32Type();
    if (!dynamicMemorySize)
      dynamicMemorySize = rewriter.create<mlir::LLVM::ConstantOp>(
          loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));

    mlir::Value kernelArgs =
        createKernelArgArray(loc, adaptor.getKernelOperands(), rewriter);

    auto ptrTy = mlir::LLVM::LLVMPointerType::get(ctx);
    auto kernel = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(op.getKernelName());
    mlir::Value kernelPtr;
    if (!kernel) {
      auto funcOp = mod.lookupSymbol<mlir::func::FuncOp>(op.getKernelName());
      if (!funcOp)
        return mlir::failure();
      kernelPtr =
          rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, funcOp.getName());
    } else {
      kernelPtr =
          rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, kernel.getName());
    }

    auto llvmIntPtrType = mlir::IntegerType::get(
        ctx, this->getTypeConverter()->getPointerBitwidth(0));
    auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);

    mlir::Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);

    // The runtime entry points take at most one stream, so at most one async
    // dependency can be converted here.
    mlir::Value stream = nullPtr;
    if (!adaptor.getAsyncDependencies().empty()) {
      if (adaptor.getAsyncDependencies().size() != 1)
        return rewriter.notifyMatchFailure(
            op, "can only convert with exactly one stream dependency");
      stream = adaptor.getAsyncDependencies().front();
    }

    if (op.hasClusterSize()) {
      auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
          RTNAME_STRING(CUFLaunchClusterKernel));
      auto funcTy = mlir::LLVM::LLVMFunctionType::get(
          voidTy,
          {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
           llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
           llvmIntPtrType, llvmIntPtrType, ptrTy, i32Ty, ptrTy, ptrTy},
          /*isVarArg=*/false);
      auto cufLaunchClusterKernel = mlir::SymbolRefAttr::get(
          mod.getContext(), RTNAME_STRING(CUFLaunchClusterKernel));
      if (!funcOp) {
        mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(mod.getBody());
        auto launchKernelFuncOp = rewriter.create<mlir::LLVM::LLVMFuncOp>(
            loc, RTNAME_STRING(CUFLaunchClusterKernel), funcTy);
        launchKernelFuncOp.setVisibility(
            mlir::SymbolTable::Visibility::Private);
      }

      rewriter.create<mlir::LLVM::CallOp>(
          loc, funcTy, cufLaunchClusterKernel,
          mlir::ValueRange{kernelPtr, adaptor.getClusterSizeX(),
                           adaptor.getClusterSizeY(), adaptor.getClusterSizeZ(),
                           adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                           adaptor.getGridSizeZ(), adaptor.getBlockSizeX(),
                           adaptor.getBlockSizeY(), adaptor.getBlockSizeZ(),
                           stream, dynamicMemorySize, kernelArgs, nullPtr});
    } else {
      auto procAttr =
          op->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName());
      bool isGridGlobal =
          procAttr && procAttr.getValue() == cuf::ProcAttribute::GridGlobal;
      llvm::StringRef fctName = isGridGlobal
                                    ? RTNAME_STRING(CUFLaunchCooperativeKernel)
                                    : RTNAME_STRING(CUFLaunchKernel);
      auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(fctName);
      auto funcTy = mlir::LLVM::LLVMFunctionType::get(
          voidTy,
          {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
           llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, ptrTy, i32Ty, ptrTy,
           ptrTy},
          /*isVarArg=*/false);
      auto cufLaunchKernel =
          mlir::SymbolRefAttr::get(mod.getContext(), fctName);
      if (!funcOp) {
        mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(mod.getBody());
        auto launchKernelFuncOp =
            rewriter.create<mlir::LLVM::LLVMFuncOp>(loc, fctName, funcTy);
        launchKernelFuncOp.setVisibility(
            mlir::SymbolTable::Visibility::Private);
      }

      rewriter.create<mlir::LLVM::CallOp>(
          loc, funcTy, cufLaunchKernel,
          mlir::ValueRange{kernelPtr, adaptor.getGridSizeX(),
                           adaptor.getGridSizeY(), adaptor.getGridSizeZ(),
                           adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                           adaptor.getBlockSizeZ(), stream, dynamicMemorySize,
                           kernelArgs, nullPtr});
    }

    rewriter.eraseOp(op);
    return mlir::success();
  }
};

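/// Return the name of the function-like operation enclosing \p op, or an
/// empty string if the parent is not a gpu.func, func.func, or llvm.func.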
static std::string getFuncName(cuf::SharedMemoryOp op) {
  if (auto gpuFuncOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
    return gpuFuncOp.getName().str();
  if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>())
    return funcOp.getName().str();
  if (auto llvmFuncOp = op->getParentOfType<mlir::LLVM::LLVMFuncOp>())
    return llvmFuncOp.getSymName().str();
  return "";
}

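/// Take the address of the shared-memory global named \p sharedGlobalName in
/// \p gpuMod, whether it is still a fir.global or has already been lowered to
/// an llvm.mlir.global; return a null value if no such symbol exists.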
static mlir::Value createAddressOfOp(mlir::ConversionPatternRewriter &rewriter,
                                     mlir::Location loc,
                                     gpu::GPUModuleOp gpuMod,
                                     std::string &sharedGlobalName) {
  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(
      rewriter.getContext(), mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace);
  if (auto g = gpuMod.lookupSymbol<fir::GlobalOp>(sharedGlobalName))
    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
                                                    g.getSymName());
  if (auto g = gpuMod.lookupSymbol<mlir::LLVM::GlobalOp>(sharedGlobalName))
    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
                                                    g.getSymName());
  return {};
}

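/// Lower cuf.shared_memory to a byte-offset GEP into the enclosing function's
/// shared-memory global, address-space-cast back to a generic pointer.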
struct CUFSharedMemoryOpConversion
    : public mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp> {
  explicit CUFSharedMemoryOpConversion(
      const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp>(typeConverter,
                                                          benefit) {}
  using OpAdaptor = typename cuf::SharedMemoryOp::Adaptor;

  mlir::LogicalResult
  matchAndRewrite(cuf::SharedMemoryOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Location loc = op->getLoc();
    if (!op.getOffset())
      return mlir::emitError(
          loc, "cuf.shared_memory must have an offset for code gen");

    auto gpuMod = op->getParentOfType<gpu::GPUModuleOp>();
    std::string sharedGlobalName =
        (getFuncName(op) + llvm::Twine(cudaSharedMemSuffix)).str();
    mlir::Value sharedGlobalAddr =
        createAddressOfOp(rewriter, loc, gpuMod, sharedGlobalName);

    if (!sharedGlobalAddr)
      return mlir::emitError(loc, "could not find the shared global operation");

    auto castPtr = rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
        loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()),
        sharedGlobalAddr);
    mlir::Type baseType = castPtr->getResultTypes().front();
    llvm::SmallVector<mlir::LLVM::GEPArg> gepArgs = {op.getOffset()};
    mlir::Value shmemPtr = rewriter.create<mlir::LLVM::GEPOp>(
        loc, baseType, rewriter.getI8Type(), castPtr, gepArgs);
    rewriter.replaceOp(op, {shmemPtr});
    return mlir::success();
  }
};

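/// cuf.stream_cast becomes a no-op once gpu.async.token values are lowered to
/// bare pointers; simply forward the converted stream value.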
struct CUFStreamCastConversion
    : public mlir::ConvertOpToLLVMPattern<cuf::StreamCastOp> {
  explicit CUFStreamCastConversion(const fir::LLVMTypeConverter &typeConverter,
                                   mlir::PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<cuf::StreamCastOp>(typeConverter,
                                                        benefit) {}
  using OpAdaptor = typename cuf::StreamCastOp::Adaptor;

  mlir::LogicalResult
  matchAndRewrite(cuf::StreamCastOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getStream());
    return mlir::success();
  }
};

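/// Pass driver: applies the patterns above with a partial conversion,
/// requiring that CUDA Fortran launches and cuf.shared_memory ops be
/// rewritten while leaving other GPU launches untouched.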
class CUFGPUToLLVMConversion
    : public fir::impl::CUFGPUToLLVMConversionBase<CUFGPUToLLVMConversion> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::ConversionTarget target(*ctx);

    mlir::Operation *op = getOperation();
    mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
    if (!module)
      return signalPassFailure();

    std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
        module, /*allowDefaultLayout=*/false);
    if (!dl) {
      mlir::emitError(module.getLoc(), "module has no data layout");
      return signalPassFailure();
    }
    fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                         /*forceUnifiedTBAATree=*/false, *dl);
    cuf::populateCUFGPUToLLVMConversionPatterns(typeConverter, patterns);

    // gpu.launch_func is only illegal (and thus rewritten) when it carries a
    // CUDA Fortran proc attribute; other launches are left to the standard
    // GPU lowering.
    target.addDynamicallyLegalOp<mlir::gpu::LaunchFuncOp>(
        [&](mlir::gpu::LaunchFuncOp op) {
          return !op.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
              cuf::getProcAttrName());
        });

    target.addIllegalOp<cuf::SharedMemoryOp>();
    target.addLegalDialect<mlir::LLVM::LLVMDialect>();
    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                  std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(ctx),
                      "error in CUF GPU op conversion");
      signalPassFailure();
    }
  }
};
} // namespace

void cuf::populateCUFGPUToLLVMConversionPatterns(
    fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
    mlir::PatternBenefit benefit) {
  // Lower gpu.async.token values to bare pointers (CUDA streams).
  converter.addConversion([&converter](mlir::gpu::AsyncTokenType) -> Type {
    return mlir::LLVM::LLVMPointerType::get(&converter.getContext());
  });
  patterns.add<CUFSharedMemoryOpConversion, GPULaunchKernelConversion,
               CUFStreamCastConversion>(converter, benefit);
}