1 | //===- SetRuntimeCallAttributes.cpp ---------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | //===----------------------------------------------------------------------===// |
10 | /// \file |
/// SetRuntimeCallAttributesPass looks for fir.call operations
/// that call into the Fortran runtime and tries to set various
/// attributes on them to enable more optimizations in the LLVM backend
/// (provided that the attributes are preserved all the way to LLVM IR).
/// This pass currently only attaches fir.call-wide attributes,
/// such as the ones corresponding to llvm.memory, nosync, nocallback, etc.
/// It is not designed to attach attributes to the arguments or the results
/// of a call.
19 | //===----------------------------------------------------------------------===// |
20 | #include "flang/Common/static-multimap-view.h" |
21 | #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" |
22 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
23 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
24 | #include "flang/Optimizer/Support/InternalNames.h" |
25 | #include "flang/Optimizer/Transforms/Passes.h" |
26 | #include "flang/Runtime/io-api.h" |
27 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
28 | |
// Instantiate the tablegen-generated pass base class
// (fir::impl::SetRuntimeCallAttributesBase, used below) for this pass.
namespace fir {
#define GEN_PASS_DEF_SETRUNTIMECALLATTRIBUTES
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

// Debug category for the LLVM_DEBUG output emitted by this file.
#define DEBUG_TYPE "set-runtime-call-attrs"

// Bring the runtime API namespaces into scope so that the runtime entry
// points named through RTNAME/IONAME below can be referenced unqualified.
using namespace Fortran::runtime;
using namespace Fortran::runtime::io;

// Produce the compile-time key type (via FirmkKey, defined outside this
// file -- presumably in RTBuilder.h) identifying an I/O runtime entry point
// (IONAME) or a general runtime entry point (RTNAME).
#define mkIOKey(X) FirmkKey(IONAME(X))
#define mkRTKey(X) FirmkKey(RTNAME(X))
41 | |
42 | // Return LLVM dialect MemoryEffectsAttr for the given Fortran runtime call. |
43 | // This function is computing a generic value of this attribute |
44 | // by analyzing the arguments and their types. |
45 | // It tries to figure out if an "indirect" memory access is possible |
46 | // during this call. If it is not possible, then the memory effects |
47 | // are: |
48 | // * other = NoModRef |
49 | // * argMem = ModRef |
50 | // * inaccessibleMem = ModRef |
51 | // |
52 | // Otherwise, it returns an empty attribute meaning ModRef for all kinds |
53 | // of memory. |
54 | // |
55 | // The attribute deduction is conservative in a sense that it applies |
56 | // to most of the runtime calls, but it may still be incorrect for some |
57 | // runtime calls. |
58 | static mlir::LLVM::MemoryEffectsAttr getGenericMemoryAttr(fir::CallOp callOp) { |
59 | bool maybeIndirectAccess = false; |
60 | for (auto arg : callOp.getArgOperands()) { |
61 | mlir::Type argType = arg.getType(); |
62 | if (mlir::isa<fir::BaseBoxType>(argType)) { |
63 | // If it is a null/absent box, then this particular call |
64 | // cannot access memory indirectly through the box's |
65 | // base_addr. |
66 | auto def = arg.getDefiningOp(); |
67 | if (!mlir::isa_and_nonnull<fir::ZeroOp, fir::AbsentOp>(def)) { |
68 | maybeIndirectAccess = true; |
69 | break; |
70 | } |
71 | } |
72 | if (auto refType = mlir::dyn_cast<fir::ReferenceType>(argType)) { |
73 | if (!fir::isa_trivial(refType.getElementType())) { |
74 | maybeIndirectAccess = true; |
75 | break; |
76 | } |
77 | } |
78 | if (auto ptrType = mlir::dyn_cast<mlir::LLVM::LLVMPointerType>(argType)) { |
79 | maybeIndirectAccess = true; |
80 | break; |
81 | } |
82 | } |
83 | if (!maybeIndirectAccess) { |
84 | return mlir::LLVM::MemoryEffectsAttr::get( |
85 | callOp->getContext(), |
86 | {/*other=*/mlir::LLVM::ModRefInfo::NoModRef, |
87 | /*argMem=*/mlir::LLVM::ModRefInfo::ModRef, |
88 | /*inaccessibleMem=*/mlir::LLVM::ModRefInfo::ModRef}); |
89 | } |
90 | |
91 | return {}; |
92 | } |
93 | |
94 | namespace { |
95 | class SetRuntimeCallAttributesPass |
96 | : public fir::impl::SetRuntimeCallAttributesBase< |
97 | SetRuntimeCallAttributesPass> { |
98 | public: |
99 | void runOnOperation() override; |
100 | }; |
101 | |
102 | // A helper to match a type against a list of types. |
103 | template <typename T, typename... Ts> |
104 | constexpr bool IsAny = std::disjunction_v<std::is_same<T, Ts>...>; |
105 | } // end anonymous namespace |
106 | |
107 | // MemoryAttrDesc type provides get() method for computing |
108 | // mlir::LLVM::MemoryEffectsAttr for the given Fortran runtime call. |
109 | // If needed, add specializations for particular runtime calls. |
110 | namespace { |
111 | // Default implementation just uses getGenericMemoryAttr(). |
112 | // Note that it may be incorrect for some runtime calls. |
113 | template <typename KEY, typename Enable = void> |
114 | struct MemoryAttrDesc { |
115 | static mlir::LLVM::MemoryEffectsAttr get(fir::CallOp callOp) { |
116 | return getGenericMemoryAttr(callOp); |
117 | } |
118 | }; |
119 | } // end anonymous namespace |
120 | |
121 | // NosyncAttrDesc type provides get() method for computing |
122 | // LLVM nosync attribute for the given call. |
123 | namespace { |
124 | // Default implementation always returns LLVM nosync. |
125 | // This should be true for the majority of the Fortran runtime calls. |
126 | template <typename KEY, typename Enable = void> |
127 | struct NosyncAttrDesc { |
128 | static std::optional<mlir::NamedAttribute> get(fir::CallOp callOp) { |
129 | // TODO: replace llvm.nosync with an LLVM dialect callback. |
130 | return mlir::NamedAttribute("llvm.nosync" , |
131 | mlir::UnitAttr::get(callOp->getContext())); |
132 | } |
133 | }; |
134 | } // end anonymous namespace |
135 | |
136 | // NocallbackAttrDesc type provides get() method for computing |
137 | // LLVM nocallback attribute for the given call. |
138 | namespace { |
139 | // Default implementation always returns LLVM nocallback. |
140 | // It must be specialized for Fortran runtime functions that may call |
141 | // user functions during their execution (e.g. defined IO, assignment). |
142 | template <typename KEY, typename Enable = void> |
143 | struct NocallbackAttrDesc { |
144 | static std::optional<mlir::NamedAttribute> get(fir::CallOp callOp) { |
145 | // TODO: replace llvm.nocallback with an LLVM dialect callback. |
146 | return mlir::NamedAttribute("llvm.nocallback" , |
147 | mlir::UnitAttr::get(callOp->getContext())); |
148 | } |
149 | }; |
150 | |
151 | // Derived types IO may call back into a Fortran module. |
152 | // This specialization is conservative for Input/OutputDerivedType, |
153 | // and it might be improved by checking if the NonTbpDefinedIoTable |
154 | // pointer argument is null. |
155 | template <typename KEY> |
156 | struct NocallbackAttrDesc< |
157 | KEY, std::enable_if_t< |
158 | IsAny<KEY, mkIOKey(OutputDerivedType), mkIOKey(InputDerivedType), |
159 | mkIOKey(OutputNamelist), mkIOKey(InputNamelist)>>> { |
160 | static std::optional<mlir::NamedAttribute> get(fir::CallOp) { |
161 | return std::nullopt; |
162 | } |
163 | }; |
164 | } // end anonymous namespace |
165 | |
166 | namespace { |
167 | // RuntimeFunction provides different callbacks that compute values |
168 | // of fir.call attributes for a Fortran runtime function. |
169 | struct RuntimeFunction { |
170 | using MemoryAttrGeneratorTy = mlir::LLVM::MemoryEffectsAttr (*)(fir::CallOp); |
171 | using NamedAttrGeneratorTy = |
172 | std::optional<mlir::NamedAttribute> (*)(fir::CallOp); |
173 | using Key = std::string_view; |
174 | constexpr operator Key() const { return key; } |
175 | Key key; |
176 | MemoryAttrGeneratorTy memoryAttrGenerator; |
177 | NamedAttrGeneratorTy nosyncAttrGenerator; |
178 | NamedAttrGeneratorTy nocallbackAttrGenerator; |
179 | }; |
180 | |
181 | // Helper type to create a RuntimeFunction descriptor given |
182 | // the KEY and a function name. |
183 | template <typename KEY> |
184 | struct RuntimeFactory { |
185 | static constexpr RuntimeFunction create(const char name[]) { |
186 | // GCC 7 does not recognize this as a constant expression: |
187 | // ((const char *)RuntimeFunction<>::name) == nullptr |
188 | // This comparison comes from the basic_string_view(const char *) |
189 | // constructor. We have to use the other constructor |
190 | // that takes explicit length parameter. |
191 | return RuntimeFunction{ |
192 | std::string_view{name, std::char_traits<char>::length(name)}, |
193 | MemoryAttrDesc<KEY>::get, NosyncAttrDesc<KEY>::get, |
194 | NocallbackAttrDesc<KEY>::get}; |
195 | } |
196 | }; |
197 | } // end anonymous namespace |
198 | |
// Build a RuntimeFunction descriptor for an I/O runtime entry point (X is the
// name given to IONAME) or for a general runtime entry point (X is the name
// given to RTNAME). The key type selects the attribute policies, and its
// static `name` member supplies the function name used for lookups.
#define KNOWN_IO_FUNC(X) RuntimeFactory<mkIOKey(X)>::create(mkIOKey(X)::name)
#define KNOWN_RUNTIME_FUNC(X) \
  RuntimeFactory<mkRTKey(X)>::create(mkRTKey(X)::name)

// A table of RuntimeFunction descriptors for all recognized
// Fortran runtime functions. The entries come from RuntimeFunctions.inc
// (presumably a list of KNOWN_* invocations) and must be sorted, as enforced
// by the static_assert below.
static constexpr RuntimeFunction runtimeFuncsTable[] = {
#include "flang/Optimizer/Transforms/RuntimeFunctions.inc"
};

// Compile-time multimap view over runtimeFuncsTable, queried by runtime
// function name in setRuntimeCallAttributes().
static constexpr Fortran::common::StaticMultimapView<RuntimeFunction>
    runtimeFuncs(runtimeFuncsTable);
static_assert(runtimeFuncs.Verify() && "map must be sorted" );
212 | |
213 | // Set attributes for the given Fortran runtime call. |
214 | // The symbolTable is used to cache the name lookups in the module. |
215 | static void setRuntimeCallAttributes(fir::CallOp callOp, |
216 | mlir::SymbolTableCollection &symbolTable) { |
217 | auto iface = mlir::cast<mlir::CallOpInterface>(callOp.getOperation()); |
218 | auto funcOp = mlir::dyn_cast_or_null<mlir::func::FuncOp>( |
219 | iface.resolveCallableInTable(&symbolTable)); |
220 | |
221 | if (!funcOp || !funcOp->hasAttrOfType<mlir::UnitAttr>( |
222 | fir::FIROpsDialect::getFirRuntimeAttrName())) |
223 | return; |
224 | |
225 | llvm::StringRef name = funcOp.getName(); |
226 | if (auto range = runtimeFuncs.equal_range(name); |
227 | range.first != range.second) { |
228 | // There should not be duplicate entries. |
229 | assert(range.first + 1 == range.second); |
230 | const RuntimeFunction &desc = *range.first; |
231 | LLVM_DEBUG(llvm::dbgs() |
232 | << "Identified runtime function call: " << desc.key << '\n'); |
233 | if (mlir::LLVM::MemoryEffectsAttr memoryAttr = |
234 | desc.memoryAttrGenerator(callOp)) |
235 | callOp->setAttr(fir::FIROpsDialect::getFirCallMemoryAttrName(), |
236 | memoryAttr); |
237 | if (auto attr = desc.nosyncAttrGenerator(callOp)) |
238 | callOp->setAttr(attr->getName(), attr->getValue()); |
239 | if (auto attr = desc.nocallbackAttrGenerator(callOp)) |
240 | callOp->setAttr(attr->getName(), attr->getValue()); |
241 | LLVM_DEBUG(llvm::dbgs() << "Operation with attrs: " << callOp << '\n'); |
242 | } |
243 | } |
244 | |
245 | void SetRuntimeCallAttributesPass::runOnOperation() { |
246 | mlir::func::FuncOp funcOp = getOperation(); |
247 | // Exit early for declarations to skip the debug output for them. |
248 | if (funcOp.isDeclaration()) |
249 | return; |
250 | LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n" ); |
251 | LLVM_DEBUG(llvm::dbgs() << "Func-name:" << funcOp.getSymName() << "\n" ); |
252 | |
253 | mlir::SymbolTableCollection symbolTable; |
254 | funcOp.walk([&](fir::CallOp callOp) { |
255 | setRuntimeCallAttributes(callOp, symbolTable); |
256 | }); |
257 | LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n" ); |
258 | } |
259 | |