1//===-- CUFDeviceGlobal.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Transforms/CUFOpConversion.h"
10#include "flang/Optimizer/Builder/CUFCommon.h"
11#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
12#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13#include "flang/Optimizer/CodeGen/TypeConverter.h"
14#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15#include "flang/Optimizer/Dialect/FIRDialect.h"
16#include "flang/Optimizer/Dialect/FIROps.h"
17#include "flang/Optimizer/HLFIR/HLFIROps.h"
18#include "flang/Optimizer/Support/DataLayout.h"
19#include "flang/Runtime/CUDA/allocatable.h"
20#include "flang/Runtime/CUDA/common.h"
21#include "flang/Runtime/CUDA/descriptor.h"
22#include "flang/Runtime/CUDA/memory.h"
23#include "flang/Runtime/CUDA/pointer.h"
24#include "flang/Runtime/allocatable.h"
25#include "flang/Runtime/allocator-registry-consts.h"
26#include "flang/Support/Fortran.h"
27#include "mlir/Conversion/LLVMCommon/Pattern.h"
28#include "mlir/Dialect/DLTI/DLTI.h"
29#include "mlir/Dialect/GPU/IR/GPUDialect.h"
30#include "mlir/IR/Matchers.h"
31#include "mlir/Pass/Pass.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34
35namespace fir {
36#define GEN_PASS_DEF_CUFOPCONVERSION
37#include "flang/Optimizer/Transforms/Passes.h.inc"
38} // namespace fir
39
40using namespace fir;
41using namespace mlir;
42using namespace Fortran::runtime;
43using namespace Fortran::runtime::cuda;
44
45namespace {
46
47static inline unsigned getMemType(cuf::DataAttribute attr) {
48 if (attr == cuf::DataAttribute::Device)
49 return kMemTypeDevice;
50 if (attr == cuf::DataAttribute::Managed)
51 return kMemTypeManaged;
52 if (attr == cuf::DataAttribute::Unified)
53 return kMemTypeUnified;
54 if (attr == cuf::DataAttribute::Pinned)
55 return kMemTypePinned;
56 llvm::report_fatal_error("unsupported memory type");
57}
58
59template <typename OpTy>
60static bool isPinned(OpTy op) {
61 if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
62 return true;
63 return false;
64}
65
66template <typename OpTy>
67static bool hasDoubleDescriptors(OpTy op) {
68 if (auto declareOp =
69 mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
70 if (mlir::isa_and_nonnull<fir::AddrOfOp>(
71 declareOp.getMemref().getDefiningOp())) {
72 if (isPinned(declareOp))
73 return false;
74 return true;
75 }
76 } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
77 op.getBox().getDefiningOp())) {
78 if (mlir::isa_and_nonnull<fir::AddrOfOp>(
79 declareOp.getMemref().getDefiningOp())) {
80 if (isPinned(declareOp))
81 return false;
82 return true;
83 }
84 }
85 return false;
86}
87
88static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
89 mlir::Location loc, mlir::Type toTy,
90 mlir::Value val) {
91 if (val.getType() != toTy)
92 return rewriter.create<fir::ConvertOp>(loc, toTy, val);
93 return val;
94}
95
// Lower a cuf.allocate or cuf.deallocate operation into a call to the given
// runtime entry point `func`, materializing the stat/errmsg/source-location
// arguments expected by that entry point's ABI. The op's result (the stat
// value) is replaced by the call's result.
template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
                                           mlir::PatternRewriter &rewriter,
                                           mlir::func::FuncOp func) {
  auto mod = op->template getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder builder(rewriter, mod);
  mlir::Location loc = op.getLoc();
  auto fTy = func.getFunctionType();

  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
  mlir::Value sourceLine;
  // The index of the source-line argument differs between the allocate entry
  // points (which take extra source/stream/pinned arguments) and the
  // deallocate ones.
  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
    sourceLine = fir::factory::locationToLineNo(
        builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
  else
    sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));

  // Tell the runtime whether STAT= was present on the statement.
  mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
                                        : builder.createBool(loc, false);

  // Pass an absent box when no ERRMSG= variable was given.
  mlir::Value errmsg;
  if (op.getErrmsg()) {
    errmsg = op.getErrmsg();
  } else {
    mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
    errmsg = builder.create<fir::AbsentOp>(loc, boxNoneTy).getResult();
  }
  llvm::SmallVector<mlir::Value> args;
  if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
    // Null pointer when PINNED= was not specified on the allocate statement.
    mlir::Value pinned =
        op.getPinned()
            ? op.getPinned()
            : builder.createNullConstant(
                  loc, fir::ReferenceType::get(
                           mlir::IntegerType::get(op.getContext(), 1)));
    if (op.getSource()) {
      // SOURCE= variant: stream is argument 2 of the runtime entry point.
      mlir::Value stream =
          op.getStream() ? op.getStream()
                         : builder.createNullConstant(loc, fTy.getInput(2));
      args = fir::runtime::createArguments(
          builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
          hasStat, errmsg, sourceFile, sourceLine);
    } else {
      // No SOURCE=: stream shifts down to argument 1.
      mlir::Value stream =
          op.getStream() ? op.getStream()
                         : builder.createNullConstant(loc, fTy.getInput(1));
      args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
                                           stream, pinned, hasStat, errmsg,
                                           sourceFile, sourceLine);
    }
  } else {
    // Deallocate variant only takes box/stat/errmsg/location.
    args =
        fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
                                      errmsg, sourceFile, sourceLine);
  }
  auto callOp = builder.create<fir::CallOp>(loc, func, args);
  rewriter.replaceOp(op, callOp);
  return mlir::success();
}
155
// Convert cuf.allocate into a call to the appropriate CUF runtime entry
// point. The entry point is selected on three axes: pointer vs allocatable,
// presence of SOURCE=, and whether the variable has both a host and a device
// descriptor that must be kept in sync (module variables).
struct CUFAllocateOpConversion
    : public mlir::OpRewritePattern<cuf::AllocateOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cuf::AllocateOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();

    // Detect POINTER variables from the declare op's Fortran attributes so
    // the pointer-specific runtime entry points are used.
    bool isPointer = false;

    if (auto declareOp =
            mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
      if (declareOp.getFortranAttrs() &&
          bitEnumContainsAny(*declareOp.getFortranAttrs(),
                             fir::FortranVariableFlagsEnum::pointer))
        isPointer = true;

    if (hasDoubleDescriptors(op)) {
      // Allocation for module variable are done with custom runtime entry point
      // so the descriptors can be synchronized.
      mlir::func::FuncOp func;
      if (op.getSource()) {
        func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
                               CUFPointerAllocateSourceSync)>(loc, builder)
                         : fir::runtime::getRuntimeFunc<mkRTKey(
                               CUFAllocatableAllocateSourceSync)>(loc, builder);
      } else {
        func =
            isPointer
                ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
                      loc, builder)
                : fir::runtime::getRuntimeFunc<mkRTKey(
                      CUFAllocatableAllocateSync)>(loc, builder);
      }
      return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
    }

    // Local descriptors: no synchronization needed; use the plain entry
    // points.
    mlir::func::FuncOp func;
    if (op.getSource()) {
      func =
          isPointer
              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
                    loc, builder)
              : fir::runtime::getRuntimeFunc<mkRTKey(
                    CUFAllocatableAllocateSource)>(loc, builder);
    } else {
      func =
          isPointer
              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
                    loc, builder)
              : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
                    loc, builder);
    }

    return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
  }
};
216
// Convert cuf.deallocate into a runtime call: the synchronized CUF entry
// point for module variables (double descriptors), or the standard Fortran
// runtime deallocate otherwise.
struct CUFDeallocateOpConversion
    : public mlir::OpRewritePattern<cuf::DeallocateOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cuf::DeallocateOp op,
                  mlir::PatternRewriter &rewriter) const override {

    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();

    if (hasDoubleDescriptors(op)) {
      // Deallocation for module variable are done with custom runtime entry
      // point so the descriptors can be synchronized.
      mlir::func::FuncOp func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
              loc, builder);
      return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
    }

    // Deallocation for local descriptor falls back on the standard runtime
    // AllocatableDeallocate as the dedicated deallocator is set in the
    // descriptor before the call.
    mlir::func::FuncOp func =
        fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
                                                                     builder);
    return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
  }
};
247
248static bool inDeviceContext(mlir::Operation *op) {
249 if (op->getParentOfType<cuf::KernelOp>())
250 return true;
251 if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
252 return true;
253 if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>())
254 return true;
255 if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
256 if (auto cudaProcAttr =
257 funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
258 cuf::getProcAttrName())) {
259 return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
260 cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
261 }
262 }
263 return false;
264}
265
266static int computeWidth(mlir::Location loc, mlir::Type type,
267 fir::KindMapping &kindMap) {
268 auto eleTy = fir::unwrapSequenceType(type);
269 if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
270 return t.getWidth() / 8;
271 if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
272 return t.getWidth() / 8;
273 if (eleTy.isInteger(1))
274 return 1;
275 if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
276 return kindMap.getLogicalBitsize(t.getFKind()) / 8;
277 if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
278 int elemSize =
279 mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
280 return 2 * elemSize;
281 }
282 if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
283 return kindMap.getCharacterBitsize(t.getFKind()) / 8;
284 mlir::emitError(loc, "unsupported type");
285 return 0;
286}
287
288struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
289 using OpRewritePattern::OpRewritePattern;
290
291 CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
292 const fir::LLVMTypeConverter *typeConverter)
293 : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
294
295 mlir::LogicalResult
296 matchAndRewrite(cuf::AllocOp op,
297 mlir::PatternRewriter &rewriter) const override {
298
299 mlir::Location loc = op.getLoc();
300
301 if (inDeviceContext(op.getOperation())) {
302 // In device context just replace the cuf.alloc operation with a fir.alloc
303 // the cuf.free will be removed.
304 auto allocaOp = rewriter.create<fir::AllocaOp>(
305 loc, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
306 op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
307 op.getShape());
308 allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
309 rewriter.replaceOp(op, allocaOp);
310 return mlir::success();
311 }
312
313 auto mod = op->getParentOfType<mlir::ModuleOp>();
314 fir::FirOpBuilder builder(rewriter, mod);
315 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
316
317 if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
318 // Convert scalar and known size array allocations.
319 mlir::Value bytes;
320 fir::KindMapping kindMap{fir::getKindMapping(mod)};
321 if (fir::isa_trivial(op.getInType())) {
322 int width = computeWidth(loc, op.getInType(), kindMap);
323 bytes =
324 builder.createIntegerConstant(loc, builder.getIndexType(), width);
325 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
326 op.getInType())) {
327 std::size_t size = 0;
328 if (fir::isa_derived(seqTy.getEleTy())) {
329 mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
330 size = dl->getTypeSizeInBits(structTy) / 8;
331 } else {
332 size = computeWidth(loc, seqTy.getEleTy(), kindMap);
333 }
334 mlir::Value width =
335 builder.createIntegerConstant(loc, builder.getIndexType(), size);
336 mlir::Value nbElem;
337 if (fir::sequenceWithNonConstantShape(seqTy)) {
338 assert(!op.getShape().empty() && "expect shape with dynamic arrays");
339 nbElem = builder.loadIfRef(loc, op.getShape()[0]);
340 for (unsigned i = 1; i < op.getShape().size(); ++i) {
341 nbElem = rewriter.create<mlir::arith::MulIOp>(
342 loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
343 }
344 } else {
345 nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
346 seqTy.getConstantArraySize());
347 }
348 bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
349 } else if (fir::isa_derived(op.getInType())) {
350 mlir::Type structTy = typeConverter->convertType(op.getInType());
351 std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
352 bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
353 structSize);
354 } else {
355 mlir::emitError(loc, "unsupported type in cuf.alloc\n");
356 }
357 mlir::func::FuncOp func =
358 fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
359 auto fTy = func.getFunctionType();
360 mlir::Value sourceLine =
361 fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
362 mlir::Value memTy = builder.createIntegerConstant(
363 loc, builder.getI32Type(), getMemType(op.getDataAttr()));
364 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
365 builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
366 auto callOp = builder.create<fir::CallOp>(loc, func, args);
367 callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
368 auto convOp = builder.createConvert(loc, op.getResult().getType(),
369 callOp.getResult(0));
370 rewriter.replaceOp(op, convOp);
371 return mlir::success();
372 }
373
374 // Convert descriptor allocations to function call.
375 auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
376 mlir::func::FuncOp func =
377 fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
378 auto fTy = func.getFunctionType();
379 mlir::Value sourceLine =
380 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
381
382 mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
383 std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
384 mlir::Value sizeInBytes =
385 builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
386
387 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
388 builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
389 auto callOp = builder.create<fir::CallOp>(loc, func, args);
390 callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
391 auto convOp = builder.createConvert(loc, op.getResult().getType(),
392 callOp.getResult(0));
393 rewriter.replaceOp(op, convOp);
394 return mlir::success();
395 }
396
397private:
398 mlir::DataLayout *dl;
399 const fir::LLVMTypeConverter *typeConverter;
400};
401
// Convert cuf.device_address into a call to CUFGetDeviceAddress: take the
// host address of the global and ask the runtime for its device counterpart.
struct CUFDeviceAddressOpConversion
    : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
  using OpRewritePattern::OpRewritePattern;

  CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
                               const mlir::SymbolTable &symtab)
      : OpRewritePattern(context), symTab{symtab} {}

  mlir::LogicalResult
  matchAndRewrite(cuf::DeviceAddressOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only applies when the symbol resolves to a fir.global in the module.
    if (auto global = symTab.lookup<fir::GlobalOp>(
            op.getHostSymbol().getRootReference().getValue())) {
      auto mod = op->getParentOfType<mlir::ModuleOp>();
      mlir::Location loc = op.getLoc();
      auto hostAddr = rewriter.create<fir::AddrOfOp>(
          loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
      fir::FirOpBuilder builder(rewriter, mod);
      mlir::func::FuncOp callee =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
                                                                     builder);
      auto fTy = callee.getFunctionType();
      // Adapt the host address to the runtime's expected pointer type.
      mlir::Value conv =
          createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, conv, sourceFile, sourceLine)};
      auto call = rewriter.create<fir::CallOp>(loc, callee, args);
      // Convert the returned raw device pointer back to the host ref type
      // before replacing the op.
      mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
                                         call->getResult(0));
      rewriter.replaceOp(op, addr.getDefiningOp());
      return success();
    }
    return failure();
  }

private:
  const mlir::SymbolTable &symTab;
};
443
444struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
445 using OpRewritePattern::OpRewritePattern;
446
447 DeclareOpConversion(mlir::MLIRContext *context,
448 const mlir::SymbolTable &symtab)
449 : OpRewritePattern(context), symTab{symtab} {}
450
451 mlir::LogicalResult
452 matchAndRewrite(fir::DeclareOp op,
453 mlir::PatternRewriter &rewriter) const override {
454 if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
455 if (auto global = symTab.lookup<fir::GlobalOp>(
456 addrOfOp.getSymbol().getRootReference().getValue())) {
457 if (cuf::isRegisteredDeviceGlobal(global)) {
458 rewriter.setInsertionPointAfter(addrOfOp);
459 mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
460 op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
461 rewriter.startOpModification(op);
462 op.getMemrefMutable().assign(devAddr);
463 rewriter.finalizeOpModification(op);
464 return success();
465 }
466 }
467 }
468 return failure();
469 }
470
471private:
472 const mlir::SymbolTable &symTab;
473};
474
// Convert cuf.free into the matching runtime call: a no-op erase in device
// context, CUFMemFree for raw data pointers, or CUFFreeDescriptor for
// descriptor references.
struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cuf::FreeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (inDeviceContext(op.getOperation())) {
      // The paired cuf.alloc was turned into a stack fir.alloca; nothing to
      // free on the device.
      rewriter.eraseOp(op);
      return mlir::success();
    }

    if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
      return failure();

    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    mlir::Location loc = op.getLoc();
    mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);

    auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
    if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
      // Raw data pointer: free with the memory-type-aware entry point.
      mlir::func::FuncOp func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
      auto fTy = func.getFunctionType();
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
      mlir::Value memTy = builder.createIntegerConstant(
          loc, builder.getI32Type(), getMemType(op.getDataAttr()));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
      return mlir::success();
    }

    // Convert cuf.free on descriptors.
    mlir::func::FuncOp func =
        fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
    auto fTy = func.getFunctionType();
    mlir::Value sourceLine =
        fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
    llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
        builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
    auto callOp = builder.create<fir::CallOp>(loc, func, args);
    callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
    rewriter.eraseOp(op);
    return mlir::success();
  }
};
524
525static bool isDstGlobal(cuf::DataTransferOp op) {
526 if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
527 if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
528 return true;
529 if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
530 if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
531 return true;
532 return false;
533}
534
535static mlir::Value getShapeFromDecl(mlir::Value src) {
536 if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
537 return declareOp.getShape();
538 if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
539 return declareOp.getShape();
540 return mlir::Value{};
541}
542
// Wrap the transfer source in a descriptor stored in a temporary, so it can
// be passed to the descriptor-based runtime entry points. Trivial constants
// are first materialized in memory (converted to `dstEleTy` when provided and
// different, to avoid an assign type mismatch in the runtime).
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
                            cuf::DataTransferOp op,
                            const mlir::SymbolTable &symtab,
                            mlir::Type dstEleTy = nullptr) {
  auto mod = op->getParentOfType<mlir::ModuleOp>();
  mlir::Location loc = op.getLoc();
  fir::FirOpBuilder builder(rewriter, mod);
  mlir::Value addr;
  mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
  if (fir::isa_trivial(srcTy) &&
      mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
    mlir::Value src = op.getSrc();
    if (srcTy.isInteger(1)) {
      // i1 is not a supported type in the descriptor and it is actually coming
      // from a LOGICAL constant. Store it as a fir.logical.
      srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
      src = createConvertOp(rewriter, loc, srcTy, src);
      addr = builder.createTemporary(loc, srcTy);
      builder.create<fir::StoreOp>(loc, src, addr);
    } else {
      if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
        // Use dstEleTy and convert to avoid assign mismatch.
        addr = builder.createTemporary(loc, dstEleTy);
        auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
        builder.create<fir::StoreOp>(loc, conv, addr);
        srcTy = dstEleTy;
      } else {
        // Put constant in memory if it is not.
        addr = builder.createTemporary(loc, srcTy);
        builder.create<fir::StoreOp>(loc, src, addr);
      }
    }
  } else {
    // Non-constant source: already addressable as-is.
    addr = op.getSrc();
  }
  llvm::SmallVector<mlir::Value> lenParams;
  // Build the descriptor and spill it to a temporary so its address can be
  // passed to the runtime.
  mlir::Type boxTy = fir::BoxType::get(srcTy);
  mlir::Value box =
      builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
                        /*slice=*/nullptr, lenParams,
                        /*tdesc=*/nullptr);
  mlir::Value src = builder.createTemporary(loc, box.getType());
  builder.create<fir::StoreOp>(loc, box, src);
  return src;
}
588
589static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
590 cuf::DataTransferOp op,
591 const mlir::SymbolTable &symtab) {
592 auto mod = op->getParentOfType<mlir::ModuleOp>();
593 mlir::Location loc = op.getLoc();
594 fir::FirOpBuilder builder(rewriter, mod);
595 mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
596 mlir::Value dstAddr = op.getDst();
597 mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
598 llvm::SmallVector<mlir::Value> lenParams;
599 mlir::Value dstBox =
600 builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()),
601 /*slice=*/nullptr, lenParams,
602 /*tdesc=*/nullptr);
603 mlir::Value dst = builder.createTemporary(loc, dstBox.getType());
604 builder.create<fir::StoreOp>(loc, dstBox, dst);
605 return dst;
606}
607
// Convert cuf.data_transfer into calls to the CUF runtime transfer entry
// points. The entry point depends on the operand shapes: pointer/pointer for
// raw data, descriptor/descriptor (with or without reallocation, global or
// not) when at least one side carries a descriptor, and constant/descriptor
// for scalar-to-array initialization.
struct CUFDataTransferOpConversion
    : public mlir::OpRewritePattern<cuf::DataTransferOp> {
  using OpRewritePattern::OpRewritePattern;

  CUFDataTransferOpConversion(mlir::MLIRContext *context,
                              const mlir::SymbolTable &symtab,
                              mlir::DataLayout *dl,
                              const fir::LLVMTypeConverter *typeConverter)
      : OpRewritePattern(context), symtab{symtab}, dl{dl},
        typeConverter{typeConverter} {}

  mlir::LogicalResult
  matchAndRewrite(cuf::DataTransferOp op,
                  mlir::PatternRewriter &rewriter) const override {

    mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
    mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());

    // Map the transfer kind to the runtime's direction constant.
    mlir::Location loc = op.getLoc();
    unsigned mode = 0;
    if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
      mode = kHostToDevice;
    } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
      mode = kDeviceToHost;
    } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
      mode = kDeviceToDevice;
    } else {
      // NOTE(review): emits a diagnostic but still proceeds with mode == 0;
      // confirm whether an unsupported kind should fail the pattern instead.
      mlir::emitError(loc, "unsupported transfer kind\n");
    }

    auto mod = op->getParentOfType<mlir::ModuleOp>();
    fir::FirOpBuilder builder(rewriter, mod);
    fir::KindMapping kindMap{fir::getKindMapping(mod)};
    mlir::Value modeValue =
        builder.createIntegerConstant(loc, builder.getI32Type(), mode);

    // Convert data transfer without any descriptor.
    if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
        !mlir::isa<fir::BaseBoxType>(dstTy)) {

      if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
        // Initialization of an array from a scalar value should be implemented
        // via a kernel launch. Use the flang runtime via the Assign function
        // until we have more infrastructure.
        mlir::Value src = emboxSrc(rewriter, op, symtab);
        mlir::Value dst = emboxDst(rewriter, op, symtab);
        mlir::func::FuncOp func =
            fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
                loc, builder);
        auto fTy = func.getFunctionType();
        mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
        mlir::Value sourceLine =
            fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
        llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
            builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
        builder.create<fir::CallOp>(loc, func, args);
        rewriter.eraseOp(op);
        return mlir::success();
      }

      // Compute the number of elements from the shape operand (fir.shape or
      // fir.shape_shift) or from the destination's constant array size.
      mlir::Type i64Ty = builder.getI64Type();
      mlir::Value nbElement;
      if (op.getShape()) {
        llvm::SmallVector<mlir::Value> extents;
        if (auto shapeOp =
                mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
          extents = shapeOp.getExtents();
        } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
                       op.getShape().getDefiningOp())) {
          // shape_shift pairs are (lower bound, extent); keep the extents
          // (odd indices).
          for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
            if (i.index() & 1)
              extents.push_back(i.value());
        }

        nbElement = rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[0]);
        for (unsigned i = 1; i < extents.size(); ++i) {
          auto operand =
              rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[i]);
          nbElement =
              rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
        }
      } else {
        if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
          nbElement = builder.createIntegerConstant(
              loc, i64Ty, seqTy.getConstantArraySize());
      }
      // Element width in bytes: derived types use the data layout, intrinsic
      // types their kind.
      unsigned width = 0;
      if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
        mlir::Type structTy =
            typeConverter->convertType(fir::unwrapSequenceType(dstTy));
        width = dl->getTypeSizeInBits(structTy) / 8;
      } else {
        width = computeWidth(loc, dstTy, kindMap);
      }
      mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
          loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
      // Scalar transfers have no element count; the width alone is the size.
      mlir::Value bytes =
          nbElement
              ? rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue)
              : widthValue;

      mlir::func::FuncOp func =
          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
                                                                       builder);
      auto fTy = func.getFunctionType();
      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));

      mlir::Value dst = op.getDst();
      mlir::Value src = op.getSrc();
      // Materialize the src if constant.
      if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
        mlir::Value temp = builder.createTemporary(loc, srcTy);
        builder.create<fir::StoreOp>(loc, src, temp);
        src = temp;
      }
      llvm::SmallVector<mlir::Value> args{
          fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
                                        modeValue, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
      return mlir::success();
    }

    auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
      if (mlir::isa<fir::EmboxOp, fir::ReboxOp>(val.getDefiningOp())) {
        // Materialize the box to memory to be able to call the runtime.
        mlir::Value box = builder.createTemporary(loc, val.getType());
        builder.create<fir::StoreOp>(loc, val, box);
        return box;
      }
      return val;
    };

    // Conversion of data transfer involving at least one descriptor.
    if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
      // Transfer to a descriptor. Globals use the synchronized entry point so
      // the host/device descriptor copies stay consistent.
      mlir::func::FuncOp func =
          isDstGlobal(op)
              ? fir::runtime::getRuntimeFunc<mkRTKey(
                    CUFDataTransferGlobalDescDesc)>(loc, builder)
              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
                    loc, builder);
      mlir::Value dst = op.getDst();
      mlir::Value src = op.getSrc();
      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
        // Source has no descriptor: embox it, and for trivial constants use
        // the constant/descriptor entry point instead.
        mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
        src = emboxSrc(rewriter, op, symtab, dstEleTy);
        if (fir::isa_trivial(srcTy))
          func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
              loc, builder);
      }

      src = materializeBoxIfNeeded(src);
      dst = materializeBoxIfNeeded(dst);

      auto fTy = func.getFunctionType();
      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
    } else {
      // Transfer from a descriptor. The destination is emboxed and must not
      // be reallocated by the runtime.
      mlir::Value dst = emboxDst(rewriter, op, symtab);
      mlir::Value src = materializeBoxIfNeeded(op.getSrc());

      mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
          CUFDataTransferDescDescNoRealloc)>(loc, builder);

      auto fTy = func.getFunctionType();
      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
      mlir::Value sourceLine =
          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
          builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
      builder.create<fir::CallOp>(loc, func, args);
      rewriter.eraseOp(op);
    }
    return mlir::success();
  }

private:
  const mlir::SymbolTable &symtab;
  mlir::DataLayout *dl;
  const fir::LLVMTypeConverter *typeConverter;
};
798
/// Convert a cuf.kernel_launch operation into a gpu.launch_func targeting the
/// kernel of the same name in the CUDA device module. Grid/block dimensions
/// are cast to index type, cluster dimensions and the proc attribute are
/// propagated from the kernel func when present, and global-descriptor
/// arguments are rewritten to their device addresses.
struct CUFLaunchOpConversion
    : public mlir::OpRewritePattern<cuf::KernelLaunchOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  CUFLaunchOpConversion(mlir::MLIRContext *context,
                        const mlir::SymbolTable &symTab)
      : OpRewritePattern(context), symTab{symTab} {}

  mlir::LogicalResult
  matchAndRewrite(cuf::KernelLaunchOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    auto idxTy = mlir::IndexType::get(op.getContext());
    // i32 zero used as the dynamic shared-memory size when none is given.
    mlir::Value zero = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
    // gpu.launch_func expects index-typed launch dimensions.
    auto gridSizeX =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
    auto gridSizeY =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridY());
    auto gridSizeZ =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridZ());
    auto blockSizeX =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockX());
    auto blockSizeY =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockY());
    auto blockSizeZ =
        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockZ());
    // Build a nested symbol reference <device module>::<kernel leaf name> so
    // the launch resolves into the CUDA device module.
    auto kernelName = mlir::SymbolRefAttr::get(
        rewriter.getStringAttr(cudaDeviceModuleName),
        {mlir::SymbolRefAttr::get(
            rewriter.getContext(),
            op.getCallee().getLeafReference().getValue())});
    mlir::Value clusterDimX, clusterDimY, clusterDimZ;
    cuf::ProcAttributeAttr procAttr;
    // If the callee func op is visible, pick up its optional cluster_dims
    // and proc attributes so they can be carried over to the launch op.
    if (auto funcOp = symTab.lookup<mlir::func::FuncOp>(
            op.getCallee().getLeafReference())) {
      if (auto clusterDimsAttr = funcOp->getAttrOfType<cuf::ClusterDimsAttr>(
              cuf::getClusterDimsAttrName())) {
        clusterDimX = rewriter.create<mlir::arith::ConstantIndexOp>(
            loc, clusterDimsAttr.getX().getInt());
        clusterDimY = rewriter.create<mlir::arith::ConstantIndexOp>(
            loc, clusterDimsAttr.getY().getInt());
        clusterDimZ = rewriter.create<mlir::arith::ConstantIndexOp>(
            loc, clusterDimsAttr.getZ().getInt());
      }
      procAttr =
          funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName());
    }
    llvm::SmallVector<mlir::Value> args;
    for (mlir::Value arg : op.getArgs()) {
      // If the argument is a global descriptor, make sure we pass the device
      // copy of this descriptor and not the host one.
      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(arg.getType()))) {
        if (auto declareOp =
                mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp())) {
          if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
                  declareOp.getMemref().getDefiningOp())) {
            if (auto global = symTab.lookup<fir::GlobalOp>(
                    addrOfOp.getSymbol().getRootReference().getValue())) {
              if (cuf::isRegisteredDeviceGlobal(global)) {
                arg = rewriter
                          .create<cuf::DeviceAddressOp>(op.getLoc(),
                                                        addrOfOp.getType(),
                                                        addrOfOp.getSymbol())
                          .getResult();
              }
            }
          }
        }
      }
      args.push_back(arg);
    }
    mlir::Value dynamicShmemSize = op.getBytes() ? op.getBytes() : zero;
    auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
        loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
        mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ},
        dynamicShmemSize, args);
    if (clusterDimX && clusterDimY && clusterDimZ) {
      gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
      gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
      gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
    }
    if (op.getStream()) {
      // Materialize the stream cast right before the launch so it can be
      // attached as an async dependency.
      mlir::OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(gpuLaunchOp);
      mlir::Value stream =
          rewriter.create<cuf::StreamCastOp>(loc, op.getStream());
      gpuLaunchOp.getAsyncDependenciesMutable().append(stream);
    }
    if (procAttr)
      gpuLaunchOp->setAttr(cuf::getProcAttrName(), procAttr);
    else
      // Set the default `global` proc attribute if none was found on the
      // original callee.
      gpuLaunchOp->setAttr(cuf::getProcAttrName(),
                           cuf::ProcAttributeAttr::get(
                               op.getContext(), cuf::ProcAttribute::Global));
    rewriter.replaceOp(op, gpuLaunchOp);
    return mlir::success();
  }

private:
  // Module-level symbol table used to resolve the callee and globals.
  const mlir::SymbolTable &symTab;
};
903
904struct CUFSyncDescriptorOpConversion
905 : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
906 using OpRewritePattern::OpRewritePattern;
907
908 mlir::LogicalResult
909 matchAndRewrite(cuf::SyncDescriptorOp op,
910 mlir::PatternRewriter &rewriter) const override {
911 auto mod = op->getParentOfType<mlir::ModuleOp>();
912 fir::FirOpBuilder builder(rewriter, mod);
913 mlir::Location loc = op.getLoc();
914
915 auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
916 if (!globalOp)
917 return mlir::failure();
918
919 auto hostAddr = builder.create<fir::AddrOfOp>(
920 loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
921 fir::runtime::cuda::genSyncGlobalDescriptor(builder, loc, hostAddr);
922 op.erase();
923 return mlir::success();
924 }
925};
926
927struct CUFSetAllocatorIndexOpConversion
928 : public mlir::OpRewritePattern<cuf::SetAllocatorIndexOp> {
929 using OpRewritePattern::OpRewritePattern;
930
931 mlir::LogicalResult
932 matchAndRewrite(cuf::SetAllocatorIndexOp op,
933 mlir::PatternRewriter &rewriter) const override {
934 auto mod = op->getParentOfType<mlir::ModuleOp>();
935 fir::FirOpBuilder builder(rewriter, mod);
936 mlir::Location loc = op.getLoc();
937 int idx = kDefaultAllocator;
938 if (op.getDataAttr() == cuf::DataAttribute::Device) {
939 idx = kDeviceAllocatorPos;
940 } else if (op.getDataAttr() == cuf::DataAttribute::Managed) {
941 idx = kManagedAllocatorPos;
942 } else if (op.getDataAttr() == cuf::DataAttribute::Unified) {
943 idx = kUnifiedAllocatorPos;
944 } else if (op.getDataAttr() == cuf::DataAttribute::Pinned) {
945 idx = kPinnedAllocatorPos;
946 }
947 mlir::Value index =
948 builder.createIntegerConstant(loc, builder.getI32Type(), idx);
949 fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index);
950 op.erase();
951 return mlir::success();
952 }
953};
954
955class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
956public:
957 void runOnOperation() override {
958 auto *ctx = &getContext();
959 mlir::RewritePatternSet patterns(ctx);
960 mlir::ConversionTarget target(*ctx);
961
962 mlir::Operation *op = getOperation();
963 mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
964 if (!module)
965 return signalPassFailure();
966 mlir::SymbolTable symtab(module);
967
968 std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
969 module, /*allowDefaultLayout=*/false);
970 fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
971 /*forceUnifiedTBAATree=*/false, *dl);
972 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
973 mlir::gpu::GPUDialect>();
974 target.addLegalOp<cuf::StreamCastOp>();
975 cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
976 patterns);
977 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
978 std::move(patterns)))) {
979 mlir::emitError(mlir::UnknownLoc::get(ctx),
980 "error in CUF op conversion\n");
981 signalPassFailure();
982 }
983
984 target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
985 if (inDeviceContext(op))
986 return true;
987 if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
988 if (auto global = symtab.lookup<fir::GlobalOp>(
989 addrOfOp.getSymbol().getRootReference().getValue())) {
990 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
991 return true;
992 if (cuf::isRegisteredDeviceGlobal(global))
993 return false;
994 }
995 }
996 return true;
997 });
998
999 patterns.clear();
1000 cuf::populateFIRCUFConversionPatterns(symtab, patterns);
1001 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
1002 std::move(patterns)))) {
1003 mlir::emitError(mlir::UnknownLoc::get(ctx),
1004 "error in CUF op conversion\n");
1005 signalPassFailure();
1006 }
1007 }
1008};
1009} // namespace
1010
1011void cuf::populateCUFToFIRConversionPatterns(
1012 const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
1013 const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
1014 patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
1015 patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
1016 CUFFreeOpConversion, CUFSyncDescriptorOpConversion,
1017 CUFSetAllocatorIndexOpConversion>(patterns.getContext());
1018 patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
1019 &dl, &converter);
1020 patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
1021 patterns.getContext(), symtab);
1022}
1023
1024void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
1025 mlir::RewritePatternSet &patterns) {
1026 patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
1027 patterns.getContext(), symtab);
1028}
1029

source code of flang/lib/Optimizer/Transforms/CUFOpConversion.cpp