1 | //===-- FIROpenACCTypeInterfaces.cpp --------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Implementation of external dialect interfaces for FIR. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h" |
14 | #include "flang/Optimizer/Builder/BoxValue.h" |
15 | #include "flang/Optimizer/Builder/DirectivesCommon.h" |
16 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
17 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
18 | #include "flang/Optimizer/Dialect/FIRCG/CGOps.h" |
19 | #include "flang/Optimizer/Dialect/FIROps.h" |
20 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
21 | #include "flang/Optimizer/Dialect/FIRType.h" |
22 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
23 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
24 | #include "mlir/Dialect/Arith/IR/Arith.h" |
25 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
26 | #include "mlir/IR/BuiltinOps.h" |
27 | #include "mlir/Support/LLVM.h" |
28 | #include "llvm/ADT/TypeSwitch.h" |
29 | |
30 | namespace fir::acc { |
31 | |
32 | static mlir::TypedValue<mlir::acc::PointerLikeType> |
33 | getPtrFromVar(mlir::Value var) { |
34 | if (auto ptr = |
35 | mlir::dyn_cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(var)) |
36 | return ptr; |
37 | |
38 | if (auto load = mlir::dyn_cast_if_present<fir::LoadOp>(var.getDefiningOp())) { |
39 | // All FIR reference types implement the PointerLikeType interface. |
40 | return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>( |
41 | load.getMemref()); |
42 | } |
43 | |
44 | return {}; |
45 | } |
46 | |
47 | template <> |
48 | mlir::TypedValue<mlir::acc::PointerLikeType> |
49 | OpenACCMappableModel<fir::SequenceType>::getVarPtr(mlir::Type type, |
50 | mlir::Value var) const { |
51 | return getPtrFromVar(var); |
52 | } |
53 | |
54 | template <> |
55 | mlir::TypedValue<mlir::acc::PointerLikeType> |
56 | OpenACCMappableModel<fir::BaseBoxType>::getVarPtr(mlir::Type type, |
57 | mlir::Value var) const { |
58 | return getPtrFromVar(var); |
59 | } |
60 | |
61 | template <> |
62 | std::optional<llvm::TypeSize> |
63 | OpenACCMappableModel<fir::SequenceType>::getSizeInBytes( |
64 | mlir::Type type, mlir::Value var, mlir::ValueRange accBounds, |
65 | const mlir::DataLayout &dataLayout) const { |
66 | // TODO: Bounds operation affect the total size - add support to take them |
67 | // into account. |
68 | if (!accBounds.empty()) |
69 | return {}; |
70 | |
71 | // Dynamic extents or unknown ranks generally do not have compile-time |
72 | // computable dimensions. |
73 | auto seqType = mlir::cast<fir::SequenceType>(type); |
74 | if (seqType.hasDynamicExtents() || seqType.hasUnknownShape()) |
75 | return {}; |
76 | |
77 | // Attempt to find an operation that a lookup for KindMapping can be done |
78 | // from. |
79 | mlir::Operation *kindMapSrcOp = var.getDefiningOp(); |
80 | if (!kindMapSrcOp) { |
81 | kindMapSrcOp = var.getParentRegion()->getParentOp(); |
82 | if (!kindMapSrcOp) |
83 | return {}; |
84 | } |
85 | auto kindMap = fir::getKindMapping(kindMapSrcOp); |
86 | |
87 | auto sizeAndAlignment = |
88 | fir::getTypeSizeAndAlignment(var.getLoc(), type, dataLayout, kindMap); |
89 | if (!sizeAndAlignment.has_value()) |
90 | return {}; |
91 | |
92 | return {llvm::TypeSize::getFixed(sizeAndAlignment->first)}; |
93 | } |
94 | |
95 | template <> |
96 | std::optional<llvm::TypeSize> |
97 | OpenACCMappableModel<fir::BaseBoxType>::getSizeInBytes( |
98 | mlir::Type type, mlir::Value var, mlir::ValueRange accBounds, |
99 | const mlir::DataLayout &dataLayout) const { |
100 | // If we have a box value instead of box reference, the intent is to |
101 | // get the size of the data not the box itself. |
102 | if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(var.getType())) { |
103 | if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>( |
104 | fir::unwrapRefType(boxTy.getEleTy()))) { |
105 | return mappableTy.getSizeInBytes(var, accBounds, dataLayout); |
106 | } |
107 | } |
108 | // Size for boxes is not computable until it gets materialized. |
109 | return {}; |
110 | } |
111 | |
112 | template <> |
113 | std::optional<int64_t> |
114 | OpenACCMappableModel<fir::SequenceType>::getOffsetInBytes( |
115 | mlir::Type type, mlir::Value var, mlir::ValueRange accBounds, |
116 | const mlir::DataLayout &dataLayout) const { |
117 | // TODO: Bounds operation affect the offset- add support to take them |
118 | // into account. |
119 | if (!accBounds.empty()) |
120 | return {}; |
121 | |
122 | // Dynamic extents (aka descriptor-based arrays) - may have a offset. |
123 | // For example, a negative stride may mean a negative offset to compute the |
124 | // start of array. |
125 | auto seqType = mlir::cast<fir::SequenceType>(type); |
126 | if (seqType.hasDynamicExtents() || seqType.hasUnknownShape()) |
127 | return {}; |
128 | |
129 | // We have non-dynamic extents - but if for some reason the size is not |
130 | // computable - assume offset is not either. Otherwise, it is an offset of |
131 | // zero. |
132 | if (getSizeInBytes(type, var, accBounds, dataLayout).has_value()) { |
133 | return {0}; |
134 | } |
135 | return {}; |
136 | } |
137 | |
138 | template <> |
139 | std::optional<int64_t> OpenACCMappableModel<fir::BaseBoxType>::getOffsetInBytes( |
140 | mlir::Type type, mlir::Value var, mlir::ValueRange accBounds, |
141 | const mlir::DataLayout &dataLayout) const { |
142 | // If we have a box value instead of box reference, the intent is to |
143 | // get the offset of the data not the offset of the box itself. |
144 | if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(var.getType())) { |
145 | if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>( |
146 | fir::unwrapRefType(boxTy.getEleTy()))) { |
147 | return mappableTy.getOffsetInBytes(var, accBounds, dataLayout); |
148 | } |
149 | } |
150 | // Until boxes get materialized, the offset is not evident because it is |
151 | // relative to the pointer being held. |
152 | return {}; |
153 | } |
154 | |
155 | template <> |
156 | llvm::SmallVector<mlir::Value> |
157 | OpenACCMappableModel<fir::SequenceType>::generateAccBounds( |
158 | mlir::Type type, mlir::Value var, mlir::OpBuilder &builder) const { |
159 | assert((mlir::isa<mlir::acc::PointerLikeType>(var.getType()) || |
160 | mlir::isa<mlir::acc::MappableType>(var.getType())) && |
161 | "must be pointer-like or mappable" ); |
162 | |
163 | fir::FirOpBuilder firBuilder(builder, var.getDefiningOp()); |
164 | auto seqType = mlir::cast<fir::SequenceType>(type); |
165 | mlir::Location loc = var.getLoc(); |
166 | |
167 | mlir::Value varPtr = |
168 | mlir::isa<mlir::acc::PointerLikeType>(var.getType()) |
169 | ? var |
170 | : mlir::cast<mlir::acc::MappableType>(var.getType()).getVarPtr(var); |
171 | |
172 | if (seqType.hasDynamicExtents() || seqType.hasUnknownShape()) { |
173 | if (auto boxAddr = |
174 | mlir::dyn_cast_if_present<fir::BoxAddrOp>(varPtr.getDefiningOp())) { |
175 | mlir::Value box = boxAddr.getVal(); |
176 | auto res = |
177 | hlfir::translateToExtendedValue(loc, firBuilder, hlfir::Entity(box)); |
178 | fir::ExtendedValue exv = res.first; |
179 | mlir::Value boxRef = box; |
180 | if (auto boxPtr = getPtrFromVar(box)) { |
181 | boxRef = boxPtr; |
182 | } |
183 | // TODO: Handle Fortran optional. |
184 | const mlir::Value isPresent; |
185 | fir::factory::AddrAndBoundsInfo info(box, boxRef, isPresent, |
186 | box.getType()); |
187 | return fir::factory::genBoundsOpsFromBox<mlir::acc::DataBoundsOp, |
188 | mlir::acc::DataBoundsType>( |
189 | firBuilder, loc, exv, info); |
190 | } |
191 | |
192 | if (mlir::isa<hlfir::DeclareOp, fir::DeclareOp>(varPtr.getDefiningOp())) { |
193 | mlir::Value zero = |
194 | firBuilder.createIntegerConstant(loc, builder.getIndexType(), 0); |
195 | mlir::Value one = |
196 | firBuilder.createIntegerConstant(loc, builder.getIndexType(), 1); |
197 | |
198 | mlir::Value shape; |
199 | if (auto declareOp = |
200 | mlir::dyn_cast_if_present<fir::DeclareOp>(varPtr.getDefiningOp())) |
201 | shape = declareOp.getShape(); |
202 | else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>( |
203 | varPtr.getDefiningOp())) |
204 | shape = declareOp.getShape(); |
205 | |
206 | const bool strideIncludeLowerExtent = true; |
207 | |
208 | llvm::SmallVector<mlir::Value> accBounds; |
209 | if (auto shapeOp = |
210 | mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp())) { |
211 | mlir::Value cummulativeExtent = one; |
212 | for (auto extent : shapeOp.getExtents()) { |
213 | mlir::Value upperbound = |
214 | builder.create<mlir::arith::SubIOp>(loc, extent, one); |
215 | mlir::Value stride = one; |
216 | if (strideIncludeLowerExtent) { |
217 | stride = cummulativeExtent; |
218 | cummulativeExtent = builder.create<mlir::arith::MulIOp>( |
219 | loc, cummulativeExtent, extent); |
220 | } |
221 | auto accBound = builder.create<mlir::acc::DataBoundsOp>( |
222 | loc, mlir::acc::DataBoundsType::get(builder.getContext()), |
223 | /*lowerbound=*/zero, /*upperbound=*/upperbound, |
224 | /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false, |
225 | /*startIdx=*/one); |
226 | accBounds.push_back(accBound); |
227 | } |
228 | } else if (auto shapeShiftOp = |
229 | mlir::dyn_cast_if_present<fir::ShapeShiftOp>( |
230 | shape.getDefiningOp())) { |
231 | mlir::Value lowerbound; |
232 | mlir::Value cummulativeExtent = one; |
233 | for (auto [idx, val] : llvm::enumerate(shapeShiftOp.getPairs())) { |
234 | if (idx % 2 == 0) { |
235 | lowerbound = val; |
236 | } else { |
237 | mlir::Value extent = val; |
238 | mlir::Value upperbound = |
239 | builder.create<mlir::arith::SubIOp>(loc, extent, one); |
240 | upperbound = builder.create<mlir::arith::AddIOp>(loc, lowerbound, |
241 | upperbound); |
242 | mlir::Value stride = one; |
243 | if (strideIncludeLowerExtent) { |
244 | stride = cummulativeExtent; |
245 | cummulativeExtent = builder.create<mlir::arith::MulIOp>( |
246 | loc, cummulativeExtent, extent); |
247 | } |
248 | auto accBound = builder.create<mlir::acc::DataBoundsOp>( |
249 | loc, mlir::acc::DataBoundsType::get(builder.getContext()), |
250 | /*lowerbound=*/zero, /*upperbound=*/upperbound, |
251 | /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false, |
252 | /*startIdx=*/lowerbound); |
253 | accBounds.push_back(accBound); |
254 | } |
255 | } |
256 | } |
257 | |
258 | if (!accBounds.empty()) |
259 | return accBounds; |
260 | } |
261 | |
262 | assert(false && "array with unknown dimension expected to have descriptor" ); |
263 | return {}; |
264 | } |
265 | |
266 | // TODO: Detect assumed-size case. |
267 | const bool isAssumedSize = false; |
268 | auto valToCheck = varPtr; |
269 | if (auto boxAddr = |
270 | mlir::dyn_cast_if_present<fir::BoxAddrOp>(varPtr.getDefiningOp())) { |
271 | valToCheck = boxAddr.getVal(); |
272 | } |
273 | auto res = hlfir::translateToExtendedValue(loc, firBuilder, |
274 | hlfir::Entity(valToCheck)); |
275 | fir::ExtendedValue exv = res.first; |
276 | return fir::factory::genBaseBoundsOps<mlir::acc::DataBoundsOp, |
277 | mlir::acc::DataBoundsType>( |
278 | firBuilder, loc, exv, |
279 | /*isAssumedSize=*/isAssumedSize); |
280 | } |
281 | |
282 | template <> |
283 | llvm::SmallVector<mlir::Value> |
284 | OpenACCMappableModel<fir::BaseBoxType>::generateAccBounds( |
285 | mlir::Type type, mlir::Value var, mlir::OpBuilder &builder) const { |
286 | // If we have a box value instead of box reference, the intent is to |
287 | // get the bounds of the data not the bounds of the box itself. |
288 | if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(var.getType())) { |
289 | if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>( |
290 | fir::unwrapRefType(boxTy.getEleTy()))) { |
291 | mlir::Value data = builder.create<fir::BoxAddrOp>(var.getLoc(), var); |
292 | return mappableTy.generateAccBounds(data, builder); |
293 | } |
294 | } |
295 | // Box references are not arrays - thus generating acc.bounds does not make |
296 | // sense. |
297 | return {}; |
298 | } |
299 | |
300 | static bool isScalarLike(mlir::Type type) { |
301 | return fir::isa_trivial(type) || fir::isa_ref_type(type); |
302 | } |
303 | |
304 | static bool isArrayLike(mlir::Type type) { |
305 | return mlir::isa<fir::SequenceType>(type); |
306 | } |
307 | |
308 | static bool isCompositeLike(mlir::Type type) { |
309 | return mlir::isa<fir::RecordType, fir::ClassType, mlir::TupleType>(type); |
310 | } |
311 | |
312 | template <> |
313 | mlir::acc::VariableTypeCategory |
314 | OpenACCMappableModel<fir::SequenceType>::getTypeCategory( |
315 | mlir::Type type, mlir::Value var) const { |
316 | return mlir::acc::VariableTypeCategory::array; |
317 | } |
318 | |
319 | template <> |
320 | mlir::acc::VariableTypeCategory |
321 | OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type, |
322 | mlir::Value var) const { |
323 | |
324 | mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type); |
325 | |
326 | // If the type enclosed by the box is a mappable type, then have it |
327 | // provide the type category. |
328 | if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy)) |
329 | return mappableTy.getTypeCategory(var); |
330 | |
331 | // For all arrays, despite whether they are allocatable, pointer, assumed, |
332 | // etc, we'd like to categorize them as "array". |
333 | if (isArrayLike(eleTy)) |
334 | return mlir::acc::VariableTypeCategory::array; |
335 | |
336 | // We got here because we don't have an array nor a mappable type. At this |
337 | // point, we know we have a type that fits the "aggregate" definition since it |
338 | // is a type with a descriptor. Try to refine it by checking if it matches the |
339 | // "composite" definition. |
340 | if (isCompositeLike(eleTy)) |
341 | return mlir::acc::VariableTypeCategory::composite; |
342 | |
343 | // Even if we have a scalar type - simply because it is wrapped in a box |
344 | // we want to categorize it as "nonscalar". Anything else would've been |
345 | // non-scalar anyway. |
346 | return mlir::acc::VariableTypeCategory::nonscalar; |
347 | } |
348 | |
349 | static mlir::TypedValue<mlir::acc::PointerLikeType> |
350 | getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) { |
351 | // If there is no defining op - the unwrapped reference is the base one. |
352 | mlir::Operation *op = varPtr.getDefiningOp(); |
353 | if (!op) |
354 | return varPtr; |
355 | |
356 | // Look to find if this value originates from an interior pointer |
357 | // calculation op. |
358 | mlir::Value baseRef = |
359 | llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op) |
360 | .Case<hlfir::DesignateOp>([&](auto op) { |
361 | // Get the base object. |
362 | return op.getMemref(); |
363 | }) |
364 | .Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp>([&](auto op) { |
365 | // Get the base array on which the coordinate is being applied. |
366 | return op.getMemref(); |
367 | }) |
368 | .Case<fir::CoordinateOp>([&](auto op) { |
369 | // For coordinate operation which is applied on derived type |
370 | // object, get the base object. |
371 | return op.getRef(); |
372 | }) |
373 | .Default([&](mlir::Operation *) { return varPtr; }); |
374 | |
375 | return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef); |
376 | } |
377 | |
378 | static mlir::acc::VariableTypeCategory |
379 | categorizePointee(mlir::Type pointer, |
380 | mlir::TypedValue<mlir::acc::PointerLikeType> varPtr, |
381 | mlir::Type varType) { |
382 | // FIR uses operations to compute interior pointers. |
383 | // So for example, an array element or composite field access to a float |
384 | // value would both be represented as !fir.ref<f32>. We do not want to treat |
385 | // such a reference as a scalar. Thus unwrap interior pointer calculations. |
386 | auto baseRef = getBaseRef(varPtr); |
387 | mlir::Type eleTy = baseRef.getType().getElementType(); |
388 | |
389 | if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy)) |
390 | return mappableTy.getTypeCategory(varPtr); |
391 | |
392 | if (isScalarLike(eleTy)) |
393 | return mlir::acc::VariableTypeCategory::scalar; |
394 | if (isArrayLike(eleTy)) |
395 | return mlir::acc::VariableTypeCategory::array; |
396 | if (isCompositeLike(eleTy)) |
397 | return mlir::acc::VariableTypeCategory::composite; |
398 | if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy)) |
399 | return mlir::acc::VariableTypeCategory::nonscalar; |
400 | // "pointers" - in the sense of raw address point-of-view, are considered |
401 | // scalars. However |
402 | if (mlir::isa<fir::LLVMPointerType>(eleTy)) |
403 | return mlir::acc::VariableTypeCategory::scalar; |
404 | |
405 | // Without further checking, this type cannot be categorized. |
406 | return mlir::acc::VariableTypeCategory::uncategorized; |
407 | } |
408 | |
409 | template <> |
410 | mlir::acc::VariableTypeCategory |
411 | OpenACCPointerLikeModel<fir::ReferenceType>::getPointeeTypeCategory( |
412 | mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr, |
413 | mlir::Type varType) const { |
414 | return categorizePointee(pointer, varPtr, varType); |
415 | } |
416 | |
417 | template <> |
418 | mlir::acc::VariableTypeCategory |
419 | OpenACCPointerLikeModel<fir::PointerType>::getPointeeTypeCategory( |
420 | mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr, |
421 | mlir::Type varType) const { |
422 | return categorizePointee(pointer, varPtr, varType); |
423 | } |
424 | |
425 | template <> |
426 | mlir::acc::VariableTypeCategory |
427 | OpenACCPointerLikeModel<fir::HeapType>::getPointeeTypeCategory( |
428 | mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr, |
429 | mlir::Type varType) const { |
430 | return categorizePointee(pointer, varPtr, varType); |
431 | } |
432 | |
433 | template <> |
434 | mlir::acc::VariableTypeCategory |
435 | OpenACCPointerLikeModel<fir::LLVMPointerType>::getPointeeTypeCategory( |
436 | mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr, |
437 | mlir::Type varType) const { |
438 | return categorizePointee(pointer, varPtr, varType); |
439 | } |
440 | |
441 | } // namespace fir::acc |
442 | |