1 | //===- TestOpenACCInterfaces.cpp ------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/BuiltinOps.h" |
12 | #include "mlir/Pass/Pass.h" |
13 | #include "mlir/Support/LLVM.h" |
14 | #include "flang/Optimizer/Support/DataLayout.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | namespace { |
19 | |
20 | struct TestFIROpenACCInterfaces |
21 | : public PassWrapper<TestFIROpenACCInterfaces, OperationPass<ModuleOp>> { |
22 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFIROpenACCInterfaces) |
23 | |
24 | StringRef getArgument() const final { return "test-fir-openacc-interfaces" ; } |
25 | StringRef getDescription() const final { |
26 | return "Test FIR implementation of the OpenACC interfaces." ; |
27 | } |
28 | void runOnOperation() override { |
29 | mlir::ModuleOp mod = getOperation(); |
30 | auto datalayout = |
31 | fir::support::getOrSetMLIRDataLayout(mod, /*allowDefaultLayout=*/true); |
32 | mlir::OpBuilder builder(mod); |
33 | getOperation().walk([&](Operation *op) { |
34 | if (isa<ACC_DATA_ENTRY_OPS>(op)) { |
35 | Value var = acc::getVar(op); |
36 | Type typeOfVar = var.getType(); |
37 | |
38 | // Attempt to determine if the variable is mappable-like or if |
39 | // the pointee itself is mappable-like. For example, if the variable is |
40 | // of type !fir.ref<!fir.box<>>, we want to print both the details about |
41 | // the !fir.ref since it is pointer-like, and about !fir.box since it |
42 | // is mappable. |
43 | auto mappableTy = dyn_cast_if_present<acc::MappableType>(typeOfVar); |
44 | if (!mappableTy) { |
45 | mappableTy = |
46 | dyn_cast_if_present<acc::MappableType>(acc::getVarType(op)); |
47 | } |
48 | |
49 | llvm::errs() << "Visiting: " << *op << "\n" ; |
50 | llvm::errs() << "\tVar: " << var << "\n" ; |
51 | |
52 | if (auto ptrTy = dyn_cast_if_present<acc::PointerLikeType>(typeOfVar)) { |
53 | llvm::errs() << "\tPointer-like: " << typeOfVar << "\n" ; |
54 | // If the pointee is not mappable, print details about it. Otherwise, |
55 | // we defer to the mappable printing below to print those details. |
56 | if (!mappableTy) { |
57 | acc::VariableTypeCategory typeCategory = |
58 | ptrTy.getPointeeTypeCategory( |
59 | cast<TypedValue<acc::PointerLikeType>>(var), |
60 | acc::getVarType(op)); |
61 | llvm::errs() << "\t\tType category: " << typeCategory << "\n" ; |
62 | } |
63 | } |
64 | |
65 | if (mappableTy) { |
66 | llvm::errs() << "\tMappable: " << mappableTy << "\n" ; |
67 | |
68 | acc::VariableTypeCategory typeCategory = |
69 | mappableTy.getTypeCategory(var); |
70 | llvm::errs() << "\t\tType category: " << typeCategory << "\n" ; |
71 | |
72 | if (datalayout.has_value()) { |
73 | auto size = mappableTy.getSizeInBytes( |
74 | acc::getVar(op), acc::getBounds(op), datalayout.value()); |
75 | if (size) { |
76 | llvm::errs() << "\t\tSize: " << size.value() << "\n" ; |
77 | } |
78 | auto offset = mappableTy.getOffsetInBytes( |
79 | acc::getVar(op), acc::getBounds(op), datalayout.value()); |
80 | if (offset) { |
81 | llvm::errs() << "\t\tOffset: " << offset.value() << "\n" ; |
82 | } |
83 | } |
84 | |
85 | builder.setInsertionPoint(op); |
86 | auto bounds = mappableTy.generateAccBounds(acc::getVar(op), builder); |
87 | if (!bounds.empty()) { |
88 | for (auto [idx, bound] : llvm::enumerate(bounds)) { |
89 | llvm::errs() << "\t\tBound[" << idx << "]: " << bound << "\n" ; |
90 | } |
91 | } |
92 | } |
93 | } |
94 | }); |
95 | } |
96 | }; |
97 | } // namespace |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // Pass Registration |
101 | //===----------------------------------------------------------------------===// |
102 | |
103 | namespace fir { |
104 | namespace test { |
105 | void registerTestFIROpenACCInterfacesPass() { |
106 | PassRegistration<TestFIROpenACCInterfaces>(); |
107 | } |
108 | } // namespace test |
109 | } // namespace fir |
110 | |