//===-- CUFOpConversion.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Transforms/CUFOpConversion.h"
10#include "flang/Optimizer/Builder/CUFCommon.h"
11#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
12#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13#include "flang/Optimizer/CodeGen/TypeConverter.h"
14#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15#include "flang/Optimizer/Dialect/FIRDialect.h"
16#include "flang/Optimizer/Dialect/FIROps.h"
17#include "flang/Optimizer/HLFIR/HLFIROps.h"
18#include "flang/Optimizer/Support/DataLayout.h"
19#include "flang/Runtime/CUDA/allocatable.h"
20#include "flang/Runtime/CUDA/common.h"
21#include "flang/Runtime/CUDA/descriptor.h"
22#include "flang/Runtime/CUDA/memory.h"
23#include "flang/Runtime/CUDA/pointer.h"
24#include "flang/Runtime/allocatable.h"
25#include "flang/Support/Fortran.h"
26#include "mlir/Conversion/LLVMCommon/Pattern.h"
27#include "mlir/Dialect/DLTI/DLTI.h"
28#include "mlir/Dialect/GPU/IR/GPUDialect.h"
29#include "mlir/IR/Matchers.h"
30#include "mlir/Pass/Pass.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
33
34namespace fir {
35#define GEN_PASS_DEF_CUFOPCONVERSION
36#include "flang/Optimizer/Transforms/Passes.h.inc"
37} // namespace fir
38
39using namespace fir;
40using namespace mlir;
41using namespace Fortran::runtime;
42using namespace Fortran::runtime::cuda;
43
44namespace {
45
46static inline unsigned getMemType(cuf::DataAttribute attr) {
47 if (attr == cuf::DataAttribute::Device)
48 return kMemTypeDevice;
49 if (attr == cuf::DataAttribute::Managed)
50 return kMemTypeManaged;
51 if (attr == cuf::DataAttribute::Unified)
52 return kMemTypeUnified;
53 if (attr == cuf::DataAttribute::Pinned)
54 return kMemTypePinned;
55 llvm::report_fatal_error("unsupported memory type");
56}
57
58template <typename OpTy>
59static bool isPinned(OpTy op) {
60 if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
61 return true;
62 return false;
63}
64
65template <typename OpTy>
66static bool hasDoubleDescriptors(OpTy op) {
67 if (auto declareOp =
68 mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
69 if (mlir::isa_and_nonnull<fir::AddrOfOp>(
70 declareOp.getMemref().getDefiningOp())) {
71 if (isPinned(declareOp))
72 return false;
73 return true;
74 }
75 } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
76 op.getBox().getDefiningOp())) {
77 if (mlir::isa_and_nonnull<fir::AddrOfOp>(
78 declareOp.getMemref().getDefiningOp())) {
79 if (isPinned(declareOp))
80 return false;
81 return true;
82 }
83 }
84 return false;
85}
86
87static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
88 mlir::Location loc, mlir::Type toTy,
89 mlir::Value val) {
90 if (val.getType() != toTy)
91 return rewriter.create<fir::ConvertOp>(loc, toTy, val);
92 return val;
93}
94
95template <typename OpTy>
96static mlir::LogicalResult convertOpToCall(OpTy op,
97 mlir::PatternRewriter &rewriter,
98 mlir::func::FuncOp func) {
99 auto mod = op->template getParentOfType<mlir::ModuleOp>();
100 fir::FirOpBuilder builder(rewriter, mod);
101 mlir::Location loc = op.getLoc();
102 auto fTy = func.getFunctionType();
103
104 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
105 mlir::Value sourceLine;
106 if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
107 sourceLine = fir::factory::locationToLineNo(
108 builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
109 else
110 sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
111
112 mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
113 : builder.createBool(loc, false);
114
115 mlir::Value errmsg;
116 if (op.getErrmsg()) {
117 errmsg = op.getErrmsg();
118 } else {
119 mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
120 errmsg = builder.create<fir::AbsentOp>(loc, boxNoneTy).getResult();
121 }
122 llvm::SmallVector<mlir::Value> args;
123 if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
124 mlir::Value pinned =
125 op.getPinned()
126 ? op.getPinned()
127 : builder.createNullConstant(
128 loc, fir::ReferenceType::get(
129 mlir::IntegerType::get(op.getContext(), 1)));
130 if (op.getSource()) {
131 mlir::Value stream =
132 op.getStream() ? op.getStream()
133 : builder.createNullConstant(loc, fTy.getInput(2));
134 args = fir::runtime::createArguments(
135 builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
136 hasStat, errmsg, sourceFile, sourceLine);
137 } else {
138 mlir::Value stream =
139 op.getStream() ? op.getStream()
140 : builder.createNullConstant(loc, fTy.getInput(1));
141 args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
142 stream, pinned, hasStat, errmsg,
143 sourceFile, sourceLine);
144 }
145 } else {
146 args =
147 fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
148 errmsg, sourceFile, sourceLine);
149 }
150 auto callOp = builder.create<fir::CallOp>(loc, func, args);
151 rewriter.replaceOp(op, callOp);
152 return mlir::success();
153}
154
155struct CUFAllocateOpConversion
156 : public mlir::OpRewritePattern<cuf::AllocateOp> {
157 using OpRewritePattern::OpRewritePattern;
158
159 mlir::LogicalResult
160 matchAndRewrite(cuf::AllocateOp op,
161 mlir::PatternRewriter &rewriter) const override {
162 auto mod = op->getParentOfType<mlir::ModuleOp>();
163 fir::FirOpBuilder builder(rewriter, mod);
164 mlir::Location loc = op.getLoc();
165
166 bool isPointer = false;
167
168 if (auto declareOp =
169 mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
170 if (declareOp.getFortranAttrs() &&
171 bitEnumContainsAny(*declareOp.getFortranAttrs(),
172 fir::FortranVariableFlagsEnum::pointer))
173 isPointer = true;
174
175 if (hasDoubleDescriptors(op)) {
176 // Allocation for module variable are done with custom runtime entry point
177 // so the descriptors can be synchronized.
178 mlir::func::FuncOp func;
179 if (op.getSource()) {
180 func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
181 CUFPointerAllocateSourceSync)>(loc, builder)
182 : fir::runtime::getRuntimeFunc<mkRTKey(
183 CUFAllocatableAllocateSourceSync)>(loc, builder);
184 } else {
185 func =
186 isPointer
187 ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
188 loc, builder)
189 : fir::runtime::getRuntimeFunc<mkRTKey(
190 CUFAllocatableAllocateSync)>(loc, builder);
191 }
192 return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
193 }
194
195 mlir::func::FuncOp func;
196 if (op.getSource()) {
197 func =
198 isPointer
199 ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
200 loc, builder)
201 : fir::runtime::getRuntimeFunc<mkRTKey(
202 CUFAllocatableAllocateSource)>(loc, builder);
203 } else {
204 func =
205 isPointer
206 ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
207 loc, builder)
208 : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
209 loc, builder);
210 }
211
212 return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
213 }
214};
215
216struct CUFDeallocateOpConversion
217 : public mlir::OpRewritePattern<cuf::DeallocateOp> {
218 using OpRewritePattern::OpRewritePattern;
219
220 mlir::LogicalResult
221 matchAndRewrite(cuf::DeallocateOp op,
222 mlir::PatternRewriter &rewriter) const override {
223
224 auto mod = op->getParentOfType<mlir::ModuleOp>();
225 fir::FirOpBuilder builder(rewriter, mod);
226 mlir::Location loc = op.getLoc();
227
228 if (hasDoubleDescriptors(op)) {
229 // Deallocation for module variable are done with custom runtime entry
230 // point so the descriptors can be synchronized.
231 mlir::func::FuncOp func =
232 fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
233 loc, builder);
234 return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
235 }
236
237 // Deallocation for local descriptor falls back on the standard runtime
238 // AllocatableDeallocate as the dedicated deallocator is set in the
239 // descriptor before the call.
240 mlir::func::FuncOp func =
241 fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
242 builder);
243 return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
244 }
245};
246
247static bool inDeviceContext(mlir::Operation *op) {
248 if (op->getParentOfType<cuf::KernelOp>())
249 return true;
250 if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
251 return true;
252 if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>())
253 return true;
254 if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
255 if (auto cudaProcAttr =
256 funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
257 cuf::getProcAttrName())) {
258 return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
259 cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
260 }
261 }
262 return false;
263}
264
265static int computeWidth(mlir::Location loc, mlir::Type type,
266 fir::KindMapping &kindMap) {
267 auto eleTy = fir::unwrapSequenceType(type);
268 if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
269 return t.getWidth() / 8;
270 if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
271 return t.getWidth() / 8;
272 if (eleTy.isInteger(1))
273 return 1;
274 if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
275 return kindMap.getLogicalBitsize(t.getFKind()) / 8;
276 if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
277 int elemSize =
278 mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
279 return 2 * elemSize;
280 }
281 if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
282 return kindMap.getCharacterBitsize(t.getFKind()) / 8;
283 mlir::emitError(loc, "unsupported type");
284 return 0;
285}
286
287struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
288 using OpRewritePattern::OpRewritePattern;
289
290 CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
291 const fir::LLVMTypeConverter *typeConverter)
292 : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
293
294 mlir::LogicalResult
295 matchAndRewrite(cuf::AllocOp op,
296 mlir::PatternRewriter &rewriter) const override {
297
298 mlir::Location loc = op.getLoc();
299
300 if (inDeviceContext(op.getOperation())) {
301 // In device context just replace the cuf.alloc operation with a fir.alloc
302 // the cuf.free will be removed.
303 auto allocaOp = rewriter.create<fir::AllocaOp>(
304 loc, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
305 op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
306 op.getShape());
307 allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
308 rewriter.replaceOp(op, allocaOp);
309 return mlir::success();
310 }
311
312 auto mod = op->getParentOfType<mlir::ModuleOp>();
313 fir::FirOpBuilder builder(rewriter, mod);
314 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
315
316 if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
317 // Convert scalar and known size array allocations.
318 mlir::Value bytes;
319 fir::KindMapping kindMap{fir::getKindMapping(mod)};
320 if (fir::isa_trivial(op.getInType())) {
321 int width = computeWidth(loc, op.getInType(), kindMap);
322 bytes =
323 builder.createIntegerConstant(loc, builder.getIndexType(), width);
324 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
325 op.getInType())) {
326 std::size_t size = 0;
327 if (fir::isa_derived(seqTy.getEleTy())) {
328 mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
329 size = dl->getTypeSizeInBits(structTy) / 8;
330 } else {
331 size = computeWidth(loc, seqTy.getEleTy(), kindMap);
332 }
333 mlir::Value width =
334 builder.createIntegerConstant(loc, builder.getIndexType(), size);
335 mlir::Value nbElem;
336 if (fir::sequenceWithNonConstantShape(seqTy)) {
337 assert(!op.getShape().empty() && "expect shape with dynamic arrays");
338 nbElem = builder.loadIfRef(loc, op.getShape()[0]);
339 for (unsigned i = 1; i < op.getShape().size(); ++i) {
340 nbElem = rewriter.create<mlir::arith::MulIOp>(
341 loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
342 }
343 } else {
344 nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
345 seqTy.getConstantArraySize());
346 }
347 bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
348 } else if (fir::isa_derived(op.getInType())) {
349 mlir::Type structTy = typeConverter->convertType(op.getInType());
350 std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
351 bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
352 structSize);
353 } else {
354 mlir::emitError(loc, "unsupported type in cuf.alloc\n");
355 }
356 mlir::func::FuncOp func =
357 fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
358 auto fTy = func.getFunctionType();
359 mlir::Value sourceLine =
360 fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
361 mlir::Value memTy = builder.createIntegerConstant(
362 loc, builder.getI32Type(), getMemType(op.getDataAttr()));
363 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
364 builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
365 auto callOp = builder.create<fir::CallOp>(loc, func, args);
366 callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
367 auto convOp = builder.createConvert(loc, op.getResult().getType(),
368 callOp.getResult(0));
369 rewriter.replaceOp(op, convOp);
370 return mlir::success();
371 }
372
373 // Convert descriptor allocations to function call.
374 auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
375 mlir::func::FuncOp func =
376 fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
377 auto fTy = func.getFunctionType();
378 mlir::Value sourceLine =
379 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
380
381 mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
382 std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
383 mlir::Value sizeInBytes =
384 builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
385
386 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
387 builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
388 auto callOp = builder.create<fir::CallOp>(loc, func, args);
389 callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
390 auto convOp = builder.createConvert(loc, op.getResult().getType(),
391 callOp.getResult(0));
392 rewriter.replaceOp(op, convOp);
393 return mlir::success();
394 }
395
396private:
397 mlir::DataLayout *dl;
398 const fir::LLVMTypeConverter *typeConverter;
399};
400
401struct CUFDeviceAddressOpConversion
402 : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
403 using OpRewritePattern::OpRewritePattern;
404
405 CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
406 const mlir::SymbolTable &symtab)
407 : OpRewritePattern(context), symTab{symtab} {}
408
409 mlir::LogicalResult
410 matchAndRewrite(cuf::DeviceAddressOp op,
411 mlir::PatternRewriter &rewriter) const override {
412 if (auto global = symTab.lookup<fir::GlobalOp>(
413 op.getHostSymbol().getRootReference().getValue())) {
414 auto mod = op->getParentOfType<mlir::ModuleOp>();
415 mlir::Location loc = op.getLoc();
416 auto hostAddr = rewriter.create<fir::AddrOfOp>(
417 loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
418 fir::FirOpBuilder builder(rewriter, mod);
419 mlir::func::FuncOp callee =
420 fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
421 builder);
422 auto fTy = callee.getFunctionType();
423 mlir::Value conv =
424 createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
425 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
426 mlir::Value sourceLine =
427 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
428 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
429 builder, loc, fTy, conv, sourceFile, sourceLine)};
430 auto call = rewriter.create<fir::CallOp>(loc, callee, args);
431 mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
432 call->getResult(0));
433 rewriter.replaceOp(op, addr.getDefiningOp());
434 return success();
435 }
436 return failure();
437 }
438
439private:
440 const mlir::SymbolTable &symTab;
441};
442
443struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
444 using OpRewritePattern::OpRewritePattern;
445
446 DeclareOpConversion(mlir::MLIRContext *context,
447 const mlir::SymbolTable &symtab)
448 : OpRewritePattern(context), symTab{symtab} {}
449
450 mlir::LogicalResult
451 matchAndRewrite(fir::DeclareOp op,
452 mlir::PatternRewriter &rewriter) const override {
453 if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
454 if (auto global = symTab.lookup<fir::GlobalOp>(
455 addrOfOp.getSymbol().getRootReference().getValue())) {
456 if (cuf::isRegisteredDeviceGlobal(global)) {
457 rewriter.setInsertionPointAfter(addrOfOp);
458 mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
459 op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
460 rewriter.startOpModification(op);
461 op.getMemrefMutable().assign(devAddr);
462 rewriter.finalizeOpModification(op);
463 return success();
464 }
465 }
466 }
467 return failure();
468 }
469
470private:
471 const mlir::SymbolTable &symTab;
472};
473
474struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
475 using OpRewritePattern::OpRewritePattern;
476
477 mlir::LogicalResult
478 matchAndRewrite(cuf::FreeOp op,
479 mlir::PatternRewriter &rewriter) const override {
480 if (inDeviceContext(op.getOperation())) {
481 rewriter.eraseOp(op);
482 return mlir::success();
483 }
484
485 if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
486 return failure();
487
488 auto mod = op->getParentOfType<mlir::ModuleOp>();
489 fir::FirOpBuilder builder(rewriter, mod);
490 mlir::Location loc = op.getLoc();
491 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
492
493 auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
494 if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
495 mlir::func::FuncOp func =
496 fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
497 auto fTy = func.getFunctionType();
498 mlir::Value sourceLine =
499 fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
500 mlir::Value memTy = builder.createIntegerConstant(
501 loc, builder.getI32Type(), getMemType(op.getDataAttr()));
502 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
503 builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
504 builder.create<fir::CallOp>(loc, func, args);
505 rewriter.eraseOp(op);
506 return mlir::success();
507 }
508
509 // Convert cuf.free on descriptors.
510 mlir::func::FuncOp func =
511 fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
512 auto fTy = func.getFunctionType();
513 mlir::Value sourceLine =
514 fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
515 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
516 builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
517 auto callOp = builder.create<fir::CallOp>(loc, func, args);
518 callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
519 rewriter.eraseOp(op);
520 return mlir::success();
521 }
522};
523
524static bool isDstGlobal(cuf::DataTransferOp op) {
525 if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
526 if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
527 return true;
528 if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
529 if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
530 return true;
531 return false;
532}
533
534static mlir::Value getShapeFromDecl(mlir::Value src) {
535 if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
536 return declareOp.getShape();
537 if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
538 return declareOp.getShape();
539 return mlir::Value{};
540}
541
542static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
543 cuf::DataTransferOp op,
544 const mlir::SymbolTable &symtab,
545 mlir::Type dstEleTy = nullptr) {
546 auto mod = op->getParentOfType<mlir::ModuleOp>();
547 mlir::Location loc = op.getLoc();
548 fir::FirOpBuilder builder(rewriter, mod);
549 mlir::Value addr;
550 mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
551 if (fir::isa_trivial(srcTy) &&
552 mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
553 mlir::Value src = op.getSrc();
554 if (srcTy.isInteger(1)) {
555 // i1 is not a supported type in the descriptor and it is actually coming
556 // from a LOGICAL constant. Store it as a fir.logical.
557 srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
558 src = createConvertOp(rewriter, loc, srcTy, src);
559 addr = builder.createTemporary(loc, srcTy);
560 builder.create<fir::StoreOp>(loc, src, addr);
561 } else {
562 if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
563 // Use dstEleTy and convert to avoid assign mismatch.
564 addr = builder.createTemporary(loc, dstEleTy);
565 auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
566 builder.create<fir::StoreOp>(loc, conv, addr);
567 srcTy = dstEleTy;
568 } else {
569 // Put constant in memory if it is not.
570 addr = builder.createTemporary(loc, srcTy);
571 builder.create<fir::StoreOp>(loc, src, addr);
572 }
573 }
574 } else {
575 addr = op.getSrc();
576 }
577 llvm::SmallVector<mlir::Value> lenParams;
578 mlir::Type boxTy = fir::BoxType::get(srcTy);
579 mlir::Value box =
580 builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
581 /*slice=*/nullptr, lenParams,
582 /*tdesc=*/nullptr);
583 mlir::Value src = builder.createTemporary(loc, box.getType());
584 builder.create<fir::StoreOp>(loc, box, src);
585 return src;
586}
587
588static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
589 cuf::DataTransferOp op,
590 const mlir::SymbolTable &symtab) {
591 auto mod = op->getParentOfType<mlir::ModuleOp>();
592 mlir::Location loc = op.getLoc();
593 fir::FirOpBuilder builder(rewriter, mod);
594 mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
595 mlir::Value dstAddr = op.getDst();
596 mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
597 llvm::SmallVector<mlir::Value> lenParams;
598 mlir::Value dstBox =
599 builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()),
600 /*slice=*/nullptr, lenParams,
601 /*tdesc=*/nullptr);
602 mlir::Value dst = builder.createTemporary(loc, dstBox.getType());
603 builder.create<fir::StoreOp>(loc, dstBox, dst);
604 return dst;
605}
606
607struct CUFDataTransferOpConversion
608 : public mlir::OpRewritePattern<cuf::DataTransferOp> {
609 using OpRewritePattern::OpRewritePattern;
610
611 CUFDataTransferOpConversion(mlir::MLIRContext *context,
612 const mlir::SymbolTable &symtab,
613 mlir::DataLayout *dl,
614 const fir::LLVMTypeConverter *typeConverter)
615 : OpRewritePattern(context), symtab{symtab}, dl{dl},
616 typeConverter{typeConverter} {}
617
618 mlir::LogicalResult
619 matchAndRewrite(cuf::DataTransferOp op,
620 mlir::PatternRewriter &rewriter) const override {
621
622 mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
623 mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
624
625 mlir::Location loc = op.getLoc();
626 unsigned mode = 0;
627 if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
628 mode = kHostToDevice;
629 } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
630 mode = kDeviceToHost;
631 } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
632 mode = kDeviceToDevice;
633 } else {
634 mlir::emitError(loc, "unsupported transfer kind\n");
635 }
636
637 auto mod = op->getParentOfType<mlir::ModuleOp>();
638 fir::FirOpBuilder builder(rewriter, mod);
639 fir::KindMapping kindMap{fir::getKindMapping(mod)};
640 mlir::Value modeValue =
641 builder.createIntegerConstant(loc, builder.getI32Type(), mode);
642
643 // Convert data transfer without any descriptor.
644 if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
645 !mlir::isa<fir::BaseBoxType>(dstTy)) {
646
647 if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
648 // Initialization of an array from a scalar value should be implemented
649 // via a kernel launch. Use the flan runtime via the Assign function
650 // until we have more infrastructure.
651 mlir::Value src = emboxSrc(rewriter, op, symtab);
652 mlir::Value dst = emboxDst(rewriter, op, symtab);
653 mlir::func::FuncOp func =
654 fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
655 loc, builder);
656 auto fTy = func.getFunctionType();
657 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
658 mlir::Value sourceLine =
659 fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
660 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
661 builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
662 builder.create<fir::CallOp>(loc, func, args);
663 rewriter.eraseOp(op);
664 return mlir::success();
665 }
666
667 mlir::Type i64Ty = builder.getI64Type();
668 mlir::Value nbElement;
669 if (op.getShape()) {
670 llvm::SmallVector<mlir::Value> extents;
671 if (auto shapeOp =
672 mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
673 extents = shapeOp.getExtents();
674 } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
675 op.getShape().getDefiningOp())) {
676 for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
677 if (i.index() & 1)
678 extents.push_back(i.value());
679 }
680
681 nbElement = rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[0]);
682 for (unsigned i = 1; i < extents.size(); ++i) {
683 auto operand =
684 rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[i]);
685 nbElement =
686 rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
687 }
688 } else {
689 if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
690 nbElement = builder.createIntegerConstant(
691 loc, i64Ty, seqTy.getConstantArraySize());
692 }
693 unsigned width = 0;
694 if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
695 mlir::Type structTy =
696 typeConverter->convertType(fir::unwrapSequenceType(dstTy));
697 width = dl->getTypeSizeInBits(structTy) / 8;
698 } else {
699 width = computeWidth(loc, dstTy, kindMap);
700 }
701 mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
702 loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
703 mlir::Value bytes =
704 nbElement
705 ? rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue)
706 : widthValue;
707
708 mlir::func::FuncOp func =
709 fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
710 builder);
711 auto fTy = func.getFunctionType();
712 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
713 mlir::Value sourceLine =
714 fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
715
716 mlir::Value dst = op.getDst();
717 mlir::Value src = op.getSrc();
718 // Materialize the src if constant.
719 if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
720 mlir::Value temp = builder.createTemporary(loc, srcTy);
721 builder.create<fir::StoreOp>(loc, src, temp);
722 src = temp;
723 }
724 llvm::SmallVector<mlir::Value> args{
725 fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
726 modeValue, sourceFile, sourceLine)};
727 builder.create<fir::CallOp>(loc, func, args);
728 rewriter.eraseOp(op);
729 return mlir::success();
730 }
731
732 auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
733 if (mlir::isa<fir::EmboxOp, fir::ReboxOp>(val.getDefiningOp())) {
734 // Materialize the box to memory to be able to call the runtime.
735 mlir::Value box = builder.createTemporary(loc, val.getType());
736 builder.create<fir::StoreOp>(loc, val, box);
737 return box;
738 }
739 return val;
740 };
741
742 // Conversion of data transfer involving at least one descriptor.
743 if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
744 // Transfer to a descriptor.
745 mlir::func::FuncOp func =
746 isDstGlobal(op)
747 ? fir::runtime::getRuntimeFunc<mkRTKey(
748 CUFDataTransferGlobalDescDesc)>(loc, builder)
749 : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
750 loc, builder);
751 mlir::Value dst = op.getDst();
752 mlir::Value src = op.getSrc();
753 if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
754 mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
755 src = emboxSrc(rewriter, op, symtab, dstEleTy);
756 if (fir::isa_trivial(srcTy))
757 func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
758 loc, builder);
759 }
760
761 src = materializeBoxIfNeeded(src);
762 dst = materializeBoxIfNeeded(dst);
763
764 auto fTy = func.getFunctionType();
765 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
766 mlir::Value sourceLine =
767 fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
768 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
769 builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
770 builder.create<fir::CallOp>(loc, func, args);
771 rewriter.eraseOp(op);
772 } else {
773 // Transfer from a descriptor.
774 mlir::Value dst = emboxDst(rewriter, op, symtab);
775 mlir::Value src = materializeBoxIfNeeded(op.getSrc());
776
777 mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
778 CUFDataTransferDescDescNoRealloc)>(loc, builder);
779
780 auto fTy = func.getFunctionType();
781 mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
782 mlir::Value sourceLine =
783 fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
784 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
785 builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
786 builder.create<fir::CallOp>(loc, func, args);
787 rewriter.eraseOp(op);
788 }
789 return mlir::success();
790 }
791
792private:
793 const mlir::SymbolTable &symtab;
794 mlir::DataLayout *dl;
795 const fir::LLVMTypeConverter *typeConverter;
796};
797
798struct CUFLaunchOpConversion
799 : public mlir::OpRewritePattern<cuf::KernelLaunchOp> {
800public:
801 using OpRewritePattern::OpRewritePattern;
802
803 CUFLaunchOpConversion(mlir::MLIRContext *context,
804 const mlir::SymbolTable &symTab)
805 : OpRewritePattern(context), symTab{symTab} {}
806
807 mlir::LogicalResult
808 matchAndRewrite(cuf::KernelLaunchOp op,
809 mlir::PatternRewriter &rewriter) const override {
810 mlir::Location loc = op.getLoc();
811 auto idxTy = mlir::IndexType::get(op.getContext());
812 mlir::Value zero = rewriter.create<mlir::arith::ConstantOp>(
813 loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
814 auto gridSizeX =
815 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
816 auto gridSizeY =
817 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridY());
818 auto gridSizeZ =
819 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridZ());
820 auto blockSizeX =
821 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockX());
822 auto blockSizeY =
823 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockY());
824 auto blockSizeZ =
825 rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockZ());
826 auto kernelName = mlir::SymbolRefAttr::get(
827 rewriter.getStringAttr(cudaDeviceModuleName),
828 {mlir::SymbolRefAttr::get(
829 rewriter.getContext(),
830 op.getCallee().getLeafReference().getValue())});
831 mlir::Value clusterDimX, clusterDimY, clusterDimZ;
832 cuf::ProcAttributeAttr procAttr;
833 if (auto funcOp = symTab.lookup<mlir::func::FuncOp>(
834 op.getCallee().getLeafReference())) {
835 if (auto clusterDimsAttr = funcOp->getAttrOfType<cuf::ClusterDimsAttr>(
836 cuf::getClusterDimsAttrName())) {
837 clusterDimX = rewriter.create<mlir::arith::ConstantIndexOp>(
838 loc, clusterDimsAttr.getX().getInt());
839 clusterDimY = rewriter.create<mlir::arith::ConstantIndexOp>(
840 loc, clusterDimsAttr.getY().getInt());
841 clusterDimZ = rewriter.create<mlir::arith::ConstantIndexOp>(
842 loc, clusterDimsAttr.getZ().getInt());
843 }
844 procAttr =
845 funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName());
846 }
847 llvm::SmallVector<mlir::Value> args;
848 for (mlir::Value arg : op.getArgs()) {
849 // If the argument is a global descriptor, make sure we pass the device
850 // copy of this descriptor and not the host one.
851 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(arg.getType()))) {
852 if (auto declareOp =
853 mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp())) {
854 if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
855 declareOp.getMemref().getDefiningOp())) {
856 if (auto global = symTab.lookup<fir::GlobalOp>(
857 addrOfOp.getSymbol().getRootReference().getValue())) {
858 if (cuf::isRegisteredDeviceGlobal(global)) {
859 arg = rewriter
860 .create<cuf::DeviceAddressOp>(op.getLoc(),
861 addrOfOp.getType(),
862 addrOfOp.getSymbol())
863 .getResult();
864 }
865 }
866 }
867 }
868 }
869 args.push_back(arg);
870 }
871 mlir::Value dynamicShmemSize = op.getBytes() ? op.getBytes() : zero;
872 auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
873 loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
874 mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ},
875 dynamicShmemSize, args);
876 if (clusterDimX && clusterDimY && clusterDimZ) {
877 gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
878 gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
879 gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
880 }
881 if (op.getStream()) {
882 mlir::OpBuilder::InsertionGuard guard(rewriter);
883 rewriter.setInsertionPoint(gpuLaunchOp);
884 mlir::Value stream =
885 rewriter.create<cuf::StreamCastOp>(loc, op.getStream());
886 gpuLaunchOp.getAsyncDependenciesMutable().append(stream);
887 }
888 if (procAttr)
889 gpuLaunchOp->setAttr(cuf::getProcAttrName(), procAttr);
890 else
891 // Set default global attribute of the original was not found.
892 gpuLaunchOp->setAttr(cuf::getProcAttrName(),
893 cuf::ProcAttributeAttr::get(
894 op.getContext(), cuf::ProcAttribute::Global));
895 rewriter.replaceOp(op, gpuLaunchOp);
896 return mlir::success();
897 }
898
899private:
900 const mlir::SymbolTable &symTab;
901};
902
903struct CUFSyncDescriptorOpConversion
904 : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
905 using OpRewritePattern::OpRewritePattern;
906
907 mlir::LogicalResult
908 matchAndRewrite(cuf::SyncDescriptorOp op,
909 mlir::PatternRewriter &rewriter) const override {
910 auto mod = op->getParentOfType<mlir::ModuleOp>();
911 fir::FirOpBuilder builder(rewriter, mod);
912 mlir::Location loc = op.getLoc();
913
914 auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
915 if (!globalOp)
916 return mlir::failure();
917
918 auto hostAddr = builder.create<fir::AddrOfOp>(
919 loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
920 fir::runtime::cuda::genSyncGlobalDescriptor(builder, loc, hostAddr);
921 op.erase();
922 return mlir::success();
923 }
924};
925
926class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
927public:
928 void runOnOperation() override {
929 auto *ctx = &getContext();
930 mlir::RewritePatternSet patterns(ctx);
931 mlir::ConversionTarget target(*ctx);
932
933 mlir::Operation *op = getOperation();
934 mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
935 if (!module)
936 return signalPassFailure();
937 mlir::SymbolTable symtab(module);
938
939 std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
940 module, /*allowDefaultLayout=*/false);
941 fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
942 /*forceUnifiedTBAATree=*/false, *dl);
943 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
944 mlir::gpu::GPUDialect>();
945 target.addLegalOp<cuf::StreamCastOp>();
946 cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
947 patterns);
948 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
949 std::move(patterns)))) {
950 mlir::emitError(mlir::UnknownLoc::get(ctx),
951 "error in CUF op conversion\n");
952 signalPassFailure();
953 }
954
955 target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
956 if (inDeviceContext(op))
957 return true;
958 if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
959 if (auto global = symtab.lookup<fir::GlobalOp>(
960 addrOfOp.getSymbol().getRootReference().getValue())) {
961 if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
962 return true;
963 if (cuf::isRegisteredDeviceGlobal(global))
964 return false;
965 }
966 }
967 return true;
968 });
969
970 patterns.clear();
971 cuf::populateFIRCUFConversionPatterns(symtab, patterns);
972 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
973 std::move(patterns)))) {
974 mlir::emitError(mlir::UnknownLoc::get(ctx),
975 "error in CUF op conversion\n");
976 signalPassFailure();
977 }
978 }
979};
980} // namespace
981
982void cuf::populateCUFToFIRConversionPatterns(
983 const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
984 const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
985 patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
986 patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
987 CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
988 patterns.getContext());
989 patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
990 &dl, &converter);
991 patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
992 patterns.getContext(), symtab);
993}
994
995void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
996 mlir::RewritePatternSet &patterns) {
997 patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
998 patterns.getContext(), symtab);
999}
1000

source code of flang/lib/Optimizer/Transforms/CUFOpConversion.cpp