1 | //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/CodeGen/CodeGen.h" |
14 | |
15 | #include "flang/Optimizer/CodeGen/CodeGenOpenMP.h" |
16 | #include "flang/Optimizer/CodeGen/FIROpPatterns.h" |
17 | #include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" |
18 | #include "flang/Optimizer/CodeGen/TypeConverter.h" |
19 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
20 | #include "flang/Optimizer/Dialect/FIRCG/CGOps.h" |
21 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
22 | #include "flang/Optimizer/Dialect/FIROps.h" |
23 | #include "flang/Optimizer/Dialect/FIRType.h" |
24 | #include "flang/Optimizer/Support/DataLayout.h" |
25 | #include "flang/Optimizer/Support/InternalNames.h" |
26 | #include "flang/Optimizer/Support/TypeCode.h" |
27 | #include "flang/Optimizer/Support/Utils.h" |
28 | #include "flang/Runtime/CUDA/descriptor.h" |
29 | #include "flang/Runtime/CUDA/memory.h" |
30 | #include "flang/Runtime/allocator-registry-consts.h" |
31 | #include "flang/Runtime/descriptor-consts.h" |
32 | #include "flang/Semantics/runtime-type-info.h" |
33 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" |
34 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
35 | #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" |
36 | #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" |
37 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
38 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
39 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
40 | #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" |
41 | #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" |
42 | #include "mlir/Conversion/MathToLibm/MathToLibm.h" |
43 | #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" |
44 | #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" |
45 | #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
46 | #include "mlir/Dialect/Arith/IR/Arith.h" |
47 | #include "mlir/Dialect/DLTI/DLTI.h" |
48 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
49 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
50 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
51 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
52 | #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h" |
53 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
54 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
55 | #include "mlir/IR/BuiltinTypes.h" |
56 | #include "mlir/IR/Matchers.h" |
57 | #include "mlir/Pass/Pass.h" |
58 | #include "mlir/Pass/PassManager.h" |
59 | #include "mlir/Target/LLVMIR/Import.h" |
60 | #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
61 | #include "llvm/ADT/ArrayRef.h" |
62 | #include "llvm/ADT/TypeSwitch.h" |
63 | |
namespace fir {
// Pull in the tablegen-generated base class for the FIRToLLVMLowering pass
// (declared by GEN_PASS_DEF_FIRTOLLVMLOWERING in CGPasses.h.inc).
#define GEN_PASS_DEF_FIRTOLLVMLOWERING
#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
} // namespace fir
68 | |
// Tag used by LLVM_DEBUG output from this file.
#define DEBUG_TYPE "flang-codegen"

// TODO: This should really be recovered from the specified target.
static constexpr unsigned defaultAlign = 8;

/// `fir.box` attribute values as defined for CFI_attribute_t in
/// flang/ISO_Fortran_binding.h. Used below when testing descriptor
/// attribute fields (fir.box_isalloc / fir.box_isptr).
static constexpr unsigned kAttrPointer = CFI_attribute_pointer;
static constexpr unsigned kAttrAllocatable = CFI_attribute_allocatable;
78 | |
79 | static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context, |
80 | unsigned addressSpace = 0) { |
81 | return mlir::LLVM::LLVMPointerType::get(context, addressSpace); |
82 | } |
83 | |
84 | static inline mlir::Type getI8Type(mlir::MLIRContext *context) { |
85 | return mlir::IntegerType::get(context, 8); |
86 | } |
87 | |
88 | static mlir::LLVM::ConstantOp |
89 | genConstantIndex(mlir::Location loc, mlir::Type ity, |
90 | mlir::ConversionPatternRewriter &rewriter, |
91 | std::int64_t offset) { |
92 | auto cattr = rewriter.getI64IntegerAttr(offset); |
93 | return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); |
94 | } |
95 | |
96 | static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter, |
97 | mlir::Block *insertBefore) { |
98 | assert(insertBefore && "expected valid insertion block" ); |
99 | return rewriter.createBlock(insertBefore->getParent(), |
100 | mlir::Region::iterator(insertBefore)); |
101 | } |
102 | |
103 | /// Extract constant from a value that must be the result of one of the |
104 | /// ConstantOp operations. |
105 | static int64_t getConstantIntValue(mlir::Value val) { |
106 | if (auto constVal = fir::getIntIfConstant(val)) |
107 | return *constVal; |
108 | fir::emitFatalError(val.getLoc(), "must be a constant" ); |
109 | } |
110 | |
111 | static unsigned getTypeDescFieldId(mlir::Type ty) { |
112 | auto isArray = mlir::isa<fir::SequenceType>(fir::dyn_cast_ptrOrBoxEleTy(ty)); |
113 | return isArray ? kOptTypePtrPosInBox : kDimsPosInBox; |
114 | } |
115 | static unsigned getLenParamFieldId(mlir::Type ty) { |
116 | return getTypeDescFieldId(ty) + 1; |
117 | } |
118 | |
119 | static llvm::SmallVector<mlir::NamedAttribute> |
120 | addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter, |
121 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
122 | int32_t numCallOperands) { |
123 | llvm::SmallVector<mlir::NamedAttribute> newAttrs; |
124 | newAttrs.reserve(attrs.size() + 2); |
125 | |
126 | for (mlir::NamedAttribute attr : attrs) { |
127 | if (attr.getName() != "operandSegmentSizes" ) |
128 | newAttrs.push_back(attr); |
129 | } |
130 | |
131 | newAttrs.push_back(rewriter.getNamedAttr( |
132 | "operandSegmentSizes" , |
133 | rewriter.getDenseI32ArrayAttr({numCallOperands, 0}))); |
134 | newAttrs.push_back(rewriter.getNamedAttr("op_bundle_sizes" , |
135 | rewriter.getDenseI32ArrayAttr({}))); |
136 | return newAttrs; |
137 | } |
138 | |
139 | namespace { |
140 | /// Lower `fir.address_of` operation to `llvm.address_of` operation. |
141 | struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> { |
142 | using FIROpConversion::FIROpConversion; |
143 | |
144 | llvm::LogicalResult |
145 | matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor, |
146 | mlir::ConversionPatternRewriter &rewriter) const override { |
147 | auto ty = convertType(addr.getType()); |
148 | rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( |
149 | addr, ty, addr.getSymbol().getRootReference().getValue()); |
150 | return mlir::success(); |
151 | } |
152 | }; |
153 | } // namespace |
154 | |
155 | /// Lookup the function to compute the memory size of this parametric derived |
156 | /// type. The size of the object may depend on the LEN type parameters of the |
157 | /// derived type. |
158 | static mlir::LLVM::LLVMFuncOp |
159 | getDependentTypeMemSizeFn(fir::RecordType recTy, fir::AllocaOp op, |
160 | mlir::ConversionPatternRewriter &rewriter) { |
161 | auto module = op->getParentOfType<mlir::ModuleOp>(); |
162 | std::string name = recTy.getName().str() + "P.mem.size" ; |
163 | if (auto memSizeFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(name)) |
164 | return memSizeFunc; |
165 | TODO(op.getLoc(), "did not find allocation function" ); |
166 | } |
167 | |
168 | // Compute the alloc scale size (constant factors encoded in the array type). |
169 | // We do this for arrays without a constant interior or arrays of character with |
170 | // dynamic length arrays, since those are the only ones that get decayed to a |
171 | // pointer to the element type. |
172 | template <typename OP> |
173 | static mlir::Value |
174 | genAllocationScaleSize(OP op, mlir::Type ity, |
175 | mlir::ConversionPatternRewriter &rewriter) { |
176 | mlir::Location loc = op.getLoc(); |
177 | mlir::Type dataTy = op.getInType(); |
178 | auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy); |
179 | fir::SequenceType::Extent constSize = 1; |
180 | if (seqTy) { |
181 | int constRows = seqTy.getConstantRows(); |
182 | const fir::SequenceType::ShapeRef &shape = seqTy.getShape(); |
183 | if (constRows != static_cast<int>(shape.size())) { |
184 | for (auto extent : shape) { |
185 | if (constRows-- > 0) |
186 | continue; |
187 | if (extent != fir::SequenceType::getUnknownExtent()) |
188 | constSize *= extent; |
189 | } |
190 | } |
191 | } |
192 | |
193 | if (constSize != 1) { |
194 | mlir::Value constVal{ |
195 | genConstantIndex(loc, ity, rewriter, constSize).getResult()}; |
196 | return constVal; |
197 | } |
198 | return nullptr; |
199 | } |
200 | |
201 | namespace { |
202 | struct DeclareOpConversion : public fir::FIROpConversion<fir::cg::XDeclareOp> { |
203 | public: |
204 | using FIROpConversion::FIROpConversion; |
205 | llvm::LogicalResult |
206 | matchAndRewrite(fir::cg::XDeclareOp declareOp, OpAdaptor adaptor, |
207 | mlir::ConversionPatternRewriter &rewriter) const override { |
208 | auto memRef = adaptor.getOperands()[0]; |
209 | if (auto fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(declareOp.getLoc())) { |
210 | if (auto varAttr = |
211 | mlir::dyn_cast_or_null<mlir::LLVM::DILocalVariableAttr>( |
212 | fusedLoc.getMetadata())) { |
213 | rewriter.create<mlir::LLVM::DbgDeclareOp>(memRef.getLoc(), memRef, |
214 | varAttr, nullptr); |
215 | } |
216 | } |
217 | rewriter.replaceOp(declareOp, memRef); |
218 | return mlir::success(); |
219 | } |
220 | }; |
221 | } // namespace |
222 | |
namespace {
/// convert to LLVM IR dialect `alloca`
///
/// Computes the element count from LEN type parameters (characters /
/// parametric derived types) and shape operands, outlines constant-sized
/// allocas toward the function entry, and casts the result to the program
/// address space when it differs from the alloca address space.
struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::AllocaOp alloc, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::ValueRange operands = adaptor.getOperands();
    auto loc = alloc.getLoc();
    // Index type used for all size arithmetic below.
    mlir::Type ity = lowerTy().indexType();
    // `i` walks the operand list: LEN parameters first, then shape operands.
    unsigned i = 0;
    // Element count; starts at 1 and is scaled below.
    mlir::Value size = genConstantIndex(loc, ity, rewriter, 1).getResult();
    mlir::Type firObjType = fir::unwrapRefType(alloc.getType());
    mlir::Type llvmObjectType = convertObjectType(firObjType);
    if (alloc.hasLenParams()) {
      unsigned end = alloc.numLenParams();
      llvm::SmallVector<mlir::Value> lenParams;
      for (; i < end; ++i)
        lenParams.push_back(operands[i]);
      mlir::Type scalarType = fir::unwrapSequenceType(alloc.getInType());
      if (auto chrTy = mlir::dyn_cast<fir::CharacterType>(scalarType)) {
        // Dynamic-length character: allocate `len` units of the
        // unknown-length character type of the same kind.
        fir::CharacterType rawCharTy = fir::CharacterType::getUnknownLen(
            chrTy.getContext(), chrTy.getFKind());
        llvmObjectType = convertType(rawCharTy);
        assert(end == 1);
        size = integerCast(loc, rewriter, ity, lenParams[0], /*fold=*/true);
      } else if (auto recTy = mlir::dyn_cast<fir::RecordType>(scalarType)) {
        // Parametric derived type: call the separately generated
        // `<type>P.mem.size` function with the LEN parameters, and allocate
        // that many i8 units.
        mlir::LLVM::LLVMFuncOp memSizeFn =
            getDependentTypeMemSizeFn(recTy, alloc, rewriter);
        if (!memSizeFn)
          emitError(loc, "did not find allocation function" );
        mlir::NamedAttribute attr = rewriter.getNamedAttr(
            "callee" , mlir::SymbolRefAttr::get(memSizeFn));
        auto call = rewriter.create<mlir::LLVM::CallOp>(
            loc, ity, lenParams,
            addLLVMOpBundleAttrs(rewriter, {attr}, lenParams.size()));
        size = call.getResult();
        llvmObjectType = ::getI8Type(alloc.getContext());
      } else {
        return emitError(loc, "unexpected type " )
               << scalarType << " with type parameters" ;
      }
    }
    // Fold in any compile-time-constant array extents.
    if (auto scaleSize = genAllocationScaleSize(alloc, ity, rewriter))
      size =
          rewriter.createOrFold<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
    // Multiply in the runtime shape operands (the remaining operands).
    if (alloc.hasShapeOperands()) {
      unsigned end = operands.size();
      for (; i < end; ++i)
        size = rewriter.createOrFold<mlir::LLVM::MulOp>(
            loc, ity, size,
            integerCast(loc, rewriter, ity, operands[i], /*fold=*/true));
    }

    unsigned allocaAs = getAllocaAddressSpace(rewriter);
    unsigned programAs = getProgramAddressSpace(rewriter);

    // Constant-sized allocas are hoisted so they are not created inside
    // loops or nested regions.
    if (mlir::isa<mlir::LLVM::ConstantOp>(size.getDefiningOp())) {
      // Set the Block in which the llvm alloca should be inserted.
      mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
      mlir::Region *parentRegion = rewriter.getInsertionBlock()->getParent();
      mlir::Block *insertBlock =
          getBlockForAllocaInsert(parentOp, parentRegion);

      // The old size might have had multiple users, some at a broader scope
      // than we can safely outline the alloca to. As it is only an
      // llvm.constant operation, it is faster to clone it than to calculate the
      // dominance to see if it really should be moved.
      mlir::Operation *clonedSize = rewriter.clone(*size.getDefiningOp());
      size = clonedSize->getResult(0);
      clonedSize->moveBefore(&insertBlock->front());
      rewriter.setInsertionPointAfter(size.getDefiningOp());
    }

    // NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
    // pointers! Only propagate pinned and bindc_name to help debugging, but
    // this should have no functional purpose (and passing the operand segment
    // attribute like before is certainly bad).
    auto llvmAlloc = rewriter.create<mlir::LLVM::AllocaOp>(
        loc, ::getLlvmPtrType(alloc.getContext(), allocaAs), llvmObjectType,
        size);
    if (alloc.getPinned())
      llvmAlloc->setDiscardableAttr(alloc.getPinnedAttrName(),
                                    alloc.getPinnedAttr());
    if (alloc.getBindcName())
      llvmAlloc->setDiscardableAttr(alloc.getBindcNameAttrName(),
                                    alloc.getBindcNameAttr());
    if (allocaAs == programAs) {
      rewriter.replaceOp(alloc, llvmAlloc);
    } else {
      // if our allocation address space, is not the same as the program address
      // space, then we must emit a cast to the program address space before
      // use. An example case would be on AMDGPU, where the allocation address
      // space is the numeric value 5 (private), and the program address space
      // is 0 (generic).
      rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
          alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc);
    }

    return mlir::success();
  }
};
} // namespace
327 | |
328 | namespace { |
329 | /// Lower `fir.box_addr` to the sequence of operations to extract the first |
330 | /// element of the box. |
331 | struct BoxAddrOpConversion : public fir::FIROpConversion<fir::BoxAddrOp> { |
332 | using FIROpConversion::FIROpConversion; |
333 | |
334 | llvm::LogicalResult |
335 | matchAndRewrite(fir::BoxAddrOp boxaddr, OpAdaptor adaptor, |
336 | mlir::ConversionPatternRewriter &rewriter) const override { |
337 | mlir::Value a = adaptor.getOperands()[0]; |
338 | auto loc = boxaddr.getLoc(); |
339 | if (auto argty = |
340 | mlir::dyn_cast<fir::BaseBoxType>(boxaddr.getVal().getType())) { |
341 | TypePair boxTyPair = getBoxTypePair(argty); |
342 | rewriter.replaceOp(boxaddr, |
343 | getBaseAddrFromBox(loc, boxTyPair, a, rewriter)); |
344 | } else { |
345 | rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(boxaddr, a, 0); |
346 | } |
347 | return mlir::success(); |
348 | } |
349 | }; |
350 | |
351 | /// Convert `!fir.boxchar_len` to `!llvm.extractvalue` for the 2nd part of the |
352 | /// boxchar. |
353 | struct BoxCharLenOpConversion : public fir::FIROpConversion<fir::BoxCharLenOp> { |
354 | using FIROpConversion::FIROpConversion; |
355 | |
356 | llvm::LogicalResult |
357 | matchAndRewrite(fir::BoxCharLenOp boxCharLen, OpAdaptor adaptor, |
358 | mlir::ConversionPatternRewriter &rewriter) const override { |
359 | mlir::Value boxChar = adaptor.getOperands()[0]; |
360 | mlir::Location loc = boxChar.getLoc(); |
361 | mlir::Type returnValTy = boxCharLen.getResult().getType(); |
362 | |
363 | constexpr int boxcharLenIdx = 1; |
364 | auto len = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, boxChar, |
365 | boxcharLenIdx); |
366 | mlir::Value lenAfterCast = integerCast(loc, rewriter, returnValTy, len); |
367 | rewriter.replaceOp(boxCharLen, lenAfterCast); |
368 | |
369 | return mlir::success(); |
370 | } |
371 | }; |
372 | |
373 | /// Lower `fir.box_dims` to a sequence of operations to extract the requested |
374 | /// dimension information from the boxed value. |
375 | /// Result in a triple set of GEPs and loads. |
376 | struct BoxDimsOpConversion : public fir::FIROpConversion<fir::BoxDimsOp> { |
377 | using FIROpConversion::FIROpConversion; |
378 | |
379 | llvm::LogicalResult |
380 | matchAndRewrite(fir::BoxDimsOp boxdims, OpAdaptor adaptor, |
381 | mlir::ConversionPatternRewriter &rewriter) const override { |
382 | llvm::SmallVector<mlir::Type, 3> resultTypes = { |
383 | convertType(boxdims.getResult(0).getType()), |
384 | convertType(boxdims.getResult(1).getType()), |
385 | convertType(boxdims.getResult(2).getType()), |
386 | }; |
387 | TypePair boxTyPair = getBoxTypePair(boxdims.getVal().getType()); |
388 | auto results = getDimsFromBox(boxdims.getLoc(), resultTypes, boxTyPair, |
389 | adaptor.getOperands()[0], |
390 | adaptor.getOperands()[1], rewriter); |
391 | rewriter.replaceOp(boxdims, results); |
392 | return mlir::success(); |
393 | } |
394 | }; |
395 | |
396 | /// Lower `fir.box_elesize` to a sequence of operations ro extract the size of |
397 | /// an element in the boxed value. |
398 | struct BoxEleSizeOpConversion : public fir::FIROpConversion<fir::BoxEleSizeOp> { |
399 | using FIROpConversion::FIROpConversion; |
400 | |
401 | llvm::LogicalResult |
402 | matchAndRewrite(fir::BoxEleSizeOp boxelesz, OpAdaptor adaptor, |
403 | mlir::ConversionPatternRewriter &rewriter) const override { |
404 | mlir::Value box = adaptor.getOperands()[0]; |
405 | auto loc = boxelesz.getLoc(); |
406 | auto ty = convertType(boxelesz.getType()); |
407 | TypePair boxTyPair = getBoxTypePair(boxelesz.getVal().getType()); |
408 | auto elemSize = getElementSizeFromBox(loc, ty, boxTyPair, box, rewriter); |
409 | rewriter.replaceOp(boxelesz, elemSize); |
410 | return mlir::success(); |
411 | } |
412 | }; |
413 | |
414 | /// Lower `fir.box_isalloc` to a sequence of operations to determine if the |
415 | /// boxed value was from an ALLOCATABLE entity. |
416 | struct BoxIsAllocOpConversion : public fir::FIROpConversion<fir::BoxIsAllocOp> { |
417 | using FIROpConversion::FIROpConversion; |
418 | |
419 | llvm::LogicalResult |
420 | matchAndRewrite(fir::BoxIsAllocOp boxisalloc, OpAdaptor adaptor, |
421 | mlir::ConversionPatternRewriter &rewriter) const override { |
422 | mlir::Value box = adaptor.getOperands()[0]; |
423 | auto loc = boxisalloc.getLoc(); |
424 | TypePair boxTyPair = getBoxTypePair(boxisalloc.getVal().getType()); |
425 | mlir::Value check = |
426 | genBoxAttributeCheck(loc, boxTyPair, box, rewriter, kAttrAllocatable); |
427 | rewriter.replaceOp(boxisalloc, check); |
428 | return mlir::success(); |
429 | } |
430 | }; |
431 | |
432 | /// Lower `fir.box_isarray` to a sequence of operations to determine if the |
433 | /// boxed is an array. |
434 | struct BoxIsArrayOpConversion : public fir::FIROpConversion<fir::BoxIsArrayOp> { |
435 | using FIROpConversion::FIROpConversion; |
436 | |
437 | llvm::LogicalResult |
438 | matchAndRewrite(fir::BoxIsArrayOp boxisarray, OpAdaptor adaptor, |
439 | mlir::ConversionPatternRewriter &rewriter) const override { |
440 | mlir::Value a = adaptor.getOperands()[0]; |
441 | auto loc = boxisarray.getLoc(); |
442 | TypePair boxTyPair = getBoxTypePair(boxisarray.getVal().getType()); |
443 | mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter); |
444 | mlir::Value c0 = genConstantIndex(loc, rank.getType(), rewriter, 0); |
445 | rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( |
446 | boxisarray, mlir::LLVM::ICmpPredicate::ne, rank, c0); |
447 | return mlir::success(); |
448 | } |
449 | }; |
450 | |
451 | /// Lower `fir.box_isptr` to a sequence of operations to determined if the |
452 | /// boxed value was from a POINTER entity. |
453 | struct BoxIsPtrOpConversion : public fir::FIROpConversion<fir::BoxIsPtrOp> { |
454 | using FIROpConversion::FIROpConversion; |
455 | |
456 | llvm::LogicalResult |
457 | matchAndRewrite(fir::BoxIsPtrOp boxisptr, OpAdaptor adaptor, |
458 | mlir::ConversionPatternRewriter &rewriter) const override { |
459 | mlir::Value box = adaptor.getOperands()[0]; |
460 | auto loc = boxisptr.getLoc(); |
461 | TypePair boxTyPair = getBoxTypePair(boxisptr.getVal().getType()); |
462 | mlir::Value check = |
463 | genBoxAttributeCheck(loc, boxTyPair, box, rewriter, kAttrPointer); |
464 | rewriter.replaceOp(boxisptr, check); |
465 | return mlir::success(); |
466 | } |
467 | }; |
468 | |
469 | /// Lower `fir.box_rank` to the sequence of operation to extract the rank from |
470 | /// the box. |
471 | struct BoxRankOpConversion : public fir::FIROpConversion<fir::BoxRankOp> { |
472 | using FIROpConversion::FIROpConversion; |
473 | |
474 | llvm::LogicalResult |
475 | matchAndRewrite(fir::BoxRankOp boxrank, OpAdaptor adaptor, |
476 | mlir::ConversionPatternRewriter &rewriter) const override { |
477 | mlir::Value a = adaptor.getOperands()[0]; |
478 | auto loc = boxrank.getLoc(); |
479 | mlir::Type ty = convertType(boxrank.getType()); |
480 | TypePair boxTyPair = |
481 | getBoxTypePair(fir::unwrapRefType(boxrank.getBox().getType())); |
482 | mlir::Value rank = getRankFromBox(loc, boxTyPair, a, rewriter); |
483 | mlir::Value result = integerCast(loc, rewriter, ty, rank); |
484 | rewriter.replaceOp(boxrank, result); |
485 | return mlir::success(); |
486 | } |
487 | }; |
488 | |
/// Lower `fir.boxproc_host` operation. Extracts the host pointer from the
/// boxproc.
/// TODO: Part of supporting Fortran 2003 procedure pointers.
struct BoxProcHostOpConversion
    : public fir::FIROpConversion<fir::BoxProcHostOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::BoxProcHostOp boxprochost, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Codegen for this op is not implemented yet; report it via the TODO
    // machinery and signal match failure.
    TODO(boxprochost.getLoc(), "fir.boxproc_host codegen" );
    return mlir::failure();
  }
};
503 | |
504 | /// Lower `fir.box_tdesc` to the sequence of operations to extract the type |
505 | /// descriptor from the box. |
506 | struct BoxTypeDescOpConversion |
507 | : public fir::FIROpConversion<fir::BoxTypeDescOp> { |
508 | using FIROpConversion::FIROpConversion; |
509 | |
510 | llvm::LogicalResult |
511 | matchAndRewrite(fir::BoxTypeDescOp boxtypedesc, OpAdaptor adaptor, |
512 | mlir::ConversionPatternRewriter &rewriter) const override { |
513 | mlir::Value box = adaptor.getOperands()[0]; |
514 | TypePair boxTyPair = getBoxTypePair(boxtypedesc.getBox().getType()); |
515 | auto typeDescAddr = |
516 | loadTypeDescAddress(boxtypedesc.getLoc(), boxTyPair, box, rewriter); |
517 | rewriter.replaceOp(boxtypedesc, typeDescAddr); |
518 | return mlir::success(); |
519 | } |
520 | }; |
521 | |
522 | /// Lower `fir.box_typecode` to a sequence of operations to extract the type |
523 | /// code in the boxed value. |
524 | struct BoxTypeCodeOpConversion |
525 | : public fir::FIROpConversion<fir::BoxTypeCodeOp> { |
526 | using FIROpConversion::FIROpConversion; |
527 | |
528 | llvm::LogicalResult |
529 | matchAndRewrite(fir::BoxTypeCodeOp op, OpAdaptor adaptor, |
530 | mlir::ConversionPatternRewriter &rewriter) const override { |
531 | mlir::Value box = adaptor.getOperands()[0]; |
532 | auto loc = box.getLoc(); |
533 | auto ty = convertType(op.getType()); |
534 | TypePair boxTyPair = getBoxTypePair(op.getBox().getType()); |
535 | auto typeCode = |
536 | getValueFromBox(loc, boxTyPair, box, ty, rewriter, kTypePosInBox); |
537 | rewriter.replaceOp(op, typeCode); |
538 | return mlir::success(); |
539 | } |
540 | }; |
541 | |
542 | /// Lower `fir.string_lit` to LLVM IR dialect operation. |
543 | struct StringLitOpConversion : public fir::FIROpConversion<fir::StringLitOp> { |
544 | using FIROpConversion::FIROpConversion; |
545 | |
546 | llvm::LogicalResult |
547 | matchAndRewrite(fir::StringLitOp constop, OpAdaptor adaptor, |
548 | mlir::ConversionPatternRewriter &rewriter) const override { |
549 | auto ty = convertType(constop.getType()); |
550 | auto attr = constop.getValue(); |
551 | if (mlir::isa<mlir::StringAttr>(attr)) { |
552 | rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(constop, ty, attr); |
553 | return mlir::success(); |
554 | } |
555 | |
556 | auto charTy = mlir::cast<fir::CharacterType>(constop.getType()); |
557 | unsigned bits = lowerTy().characterBitsize(charTy); |
558 | mlir::Type intTy = rewriter.getIntegerType(bits); |
559 | mlir::Location loc = constop.getLoc(); |
560 | mlir::Value cst = rewriter.create<mlir::LLVM::UndefOp>(loc, ty); |
561 | if (auto arr = mlir::dyn_cast<mlir::DenseElementsAttr>(attr)) { |
562 | cst = rewriter.create<mlir::LLVM::ConstantOp>(loc, ty, arr); |
563 | } else if (auto arr = mlir::dyn_cast<mlir::ArrayAttr>(attr)) { |
564 | for (auto a : llvm::enumerate(arr.getValue())) { |
565 | // convert each character to a precise bitsize |
566 | auto elemAttr = mlir::IntegerAttr::get( |
567 | intTy, |
568 | mlir::cast<mlir::IntegerAttr>(a.value()).getValue().zextOrTrunc( |
569 | bits)); |
570 | auto elemCst = |
571 | rewriter.create<mlir::LLVM::ConstantOp>(loc, intTy, elemAttr); |
572 | cst = rewriter.create<mlir::LLVM::InsertValueOp>(loc, cst, elemCst, |
573 | a.index()); |
574 | } |
575 | } else { |
576 | return mlir::failure(); |
577 | } |
578 | rewriter.replaceOp(constop, cst); |
579 | return mlir::success(); |
580 | } |
581 | }; |
582 | |
/// `fir.call` -> `llvm.call`
///
/// Converts result types, translates fast-math flags, rebuilds operand
/// segment/bundle attributes, converts per-argument byval/sret type
/// attributes, and re-attaches memory effects.
struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::CallOp call, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<mlir::Type> resultTys;
    // Stash and strip the FIR call-memory attribute before the attribute
    // conversion below, so it is not forwarded under the FIR name; it is
    // re-attached at the end via setMemoryEffectsAttr.
    mlir::Attribute memAttr =
        call->getAttr(fir::FIROpsDialect::getFirCallMemoryAttrName());
    if (memAttr)
      call->removeAttr(fir::FIROpsDialect::getFirCallMemoryAttrName());

    for (auto r : call.getResults())
      resultTys.push_back(convertType(r.getType()));
    // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
    mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
        attrConvert(call);
    auto llvmCall = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
        call, resultTys, adaptor.getOperands(),
        addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                             adaptor.getOperands().size()));
    if (mlir::ArrayAttr argAttrsArray = call.getArgAttrsAttr()) {
      // sret and byval type needs to be converted.
      auto convertTypeAttr = [&](const mlir::NamedAttribute &attr) {
        return mlir::TypeAttr::get(convertType(
            llvm::cast<mlir::TypeAttr>(attr.getValue()).getValue()));
      };
      llvm::SmallVector<mlir::Attribute> newArgAttrsArray;
      // One dictionary of attributes per call argument.
      for (auto argAttrs : argAttrsArray) {
        llvm::SmallVector<mlir::NamedAttribute> convertedAttrs;
        for (const mlir::NamedAttribute &attr :
             llvm::cast<mlir::DictionaryAttr>(argAttrs)) {
          if (attr.getName().getValue() ==
              mlir::LLVM::LLVMDialect::getByValAttrName()) {
            convertedAttrs.push_back(rewriter.getNamedAttr(
                mlir::LLVM::LLVMDialect::getByValAttrName(),
                convertTypeAttr(attr)));
          } else if (attr.getName().getValue() ==
                     mlir::LLVM::LLVMDialect::getStructRetAttrName()) {
            convertedAttrs.push_back(rewriter.getNamedAttr(
                mlir::LLVM::LLVMDialect::getStructRetAttrName(),
                convertTypeAttr(attr)));
          } else {
            // All other per-argument attributes are kept as-is.
            convertedAttrs.push_back(attr);
          }
        }
        newArgAttrsArray.emplace_back(
            mlir::DictionaryAttr::get(rewriter.getContext(), convertedAttrs));
      }
      llvmCall.setArgAttrsAttr(rewriter.getArrayAttr(newArgAttrsArray));
    }
    // Result attributes need no type conversion; forward them unchanged.
    if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
      llvmCall.setResAttrsAttr(resAttrs);

    // Re-attach the memory effects saved above on the new llvm.call.
    if (memAttr)
      llvmCall.setMemoryEffectsAttr(
          mlir::cast<mlir::LLVM::MemoryEffectsAttr>(memAttr));
    return mlir::success();
  }
};
644 | } // namespace |
645 | |
646 | static mlir::Type getComplexEleTy(mlir::Type complex) { |
647 | return mlir::cast<mlir::ComplexType>(complex).getElementType(); |
648 | } |
649 | |
650 | namespace { |
651 | /// Compare complex values |
652 | /// |
653 | /// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une). |
654 | /// |
655 | /// For completeness, all other comparison are done on the real component only. |
656 | struct CmpcOpConversion : public fir::FIROpConversion<fir::CmpcOp> { |
657 | using FIROpConversion::FIROpConversion; |
658 | |
659 | llvm::LogicalResult |
660 | matchAndRewrite(fir::CmpcOp cmp, OpAdaptor adaptor, |
661 | mlir::ConversionPatternRewriter &rewriter) const override { |
662 | mlir::ValueRange operands = adaptor.getOperands(); |
663 | mlir::Type resTy = convertType(cmp.getType()); |
664 | mlir::Location loc = cmp.getLoc(); |
665 | mlir::LLVM::FastmathFlags fmf = |
666 | mlir::arith::convertArithFastMathFlagsToLLVM(cmp.getFastmath()); |
667 | mlir::LLVM::FCmpPredicate pred = |
668 | static_cast<mlir::LLVM::FCmpPredicate>(cmp.getPredicate()); |
669 | auto rcp = rewriter.create<mlir::LLVM::FCmpOp>( |
670 | loc, resTy, pred, |
671 | rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[0], 0), |
672 | rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[1], 0), fmf); |
673 | auto icp = rewriter.create<mlir::LLVM::FCmpOp>( |
674 | loc, resTy, pred, |
675 | rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[0], 1), |
676 | rewriter.create<mlir::LLVM::ExtractValueOp>(loc, operands[1], 1), fmf); |
677 | llvm::SmallVector<mlir::Value, 2> cp = {rcp, icp}; |
678 | switch (cmp.getPredicate()) { |
679 | case mlir::arith::CmpFPredicate::OEQ: // .EQ. |
680 | rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmp, resTy, cp); |
681 | break; |
682 | case mlir::arith::CmpFPredicate::UNE: // .NE. |
683 | rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmp, resTy, cp); |
684 | break; |
685 | default: |
686 | rewriter.replaceOp(cmp, rcp.getResult()); |
687 | break; |
688 | } |
689 | return mlir::success(); |
690 | } |
691 | }; |
692 | |
693 | /// fir.volatile_cast is only useful at the fir level. Once we lower to LLVM, |
694 | /// volatility is described by setting volatile attributes on the LLVM ops. |
695 | struct VolatileCastOpConversion |
696 | : public fir::FIROpConversion<fir::VolatileCastOp> { |
697 | using FIROpConversion::FIROpConversion; |
698 | |
699 | llvm::LogicalResult |
700 | matchAndRewrite(fir::VolatileCastOp volatileCast, OpAdaptor adaptor, |
701 | mlir::ConversionPatternRewriter &rewriter) const override { |
702 | rewriter.replaceOp(volatileCast, adaptor.getOperands()[0]); |
703 | return mlir::success(); |
704 | } |
705 | }; |
706 | |
707 | /// convert value of from-type to value of to-type |
708 | struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { |
709 | using FIROpConversion::FIROpConversion; |
710 | |
711 | static bool isFloatingPointTy(mlir::Type ty) { |
712 | return mlir::isa<mlir::FloatType>(ty); |
713 | } |
714 | |
715 | llvm::LogicalResult |
716 | matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor, |
717 | mlir::ConversionPatternRewriter &rewriter) const override { |
718 | auto fromFirTy = convert.getValue().getType(); |
719 | auto toFirTy = convert.getRes().getType(); |
720 | auto fromTy = convertType(fromFirTy); |
721 | auto toTy = convertType(toFirTy); |
722 | mlir::Value op0 = adaptor.getOperands()[0]; |
723 | |
724 | if (fromFirTy == toFirTy) { |
725 | rewriter.replaceOp(convert, op0); |
726 | return mlir::success(); |
727 | } |
728 | |
729 | auto loc = convert.getLoc(); |
730 | auto i1Type = mlir::IntegerType::get(convert.getContext(), 1); |
731 | |
732 | if (mlir::isa<fir::RecordType>(toFirTy)) { |
733 | // Convert to compatible BIND(C) record type. |
734 | // Double check that the record types are compatible (it should have |
735 | // already been checked by the verifier). |
736 | assert(mlir::cast<fir::RecordType>(fromFirTy).getTypeList() == |
737 | mlir::cast<fir::RecordType>(toFirTy).getTypeList() && |
738 | "incompatible record types" ); |
739 | |
740 | auto toStTy = mlir::cast<mlir::LLVM::LLVMStructType>(toTy); |
741 | mlir::Value val = rewriter.create<mlir::LLVM::UndefOp>(loc, toStTy); |
742 | auto indexTypeMap = toStTy.getSubelementIndexMap(); |
743 | assert(indexTypeMap.has_value() && "invalid record type" ); |
744 | |
745 | for (auto [attr, type] : indexTypeMap.value()) { |
746 | int64_t index = mlir::cast<mlir::IntegerAttr>(attr).getInt(); |
747 | auto extVal = |
748 | rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, index); |
749 | val = |
750 | rewriter.create<mlir::LLVM::InsertValueOp>(loc, val, extVal, index); |
751 | } |
752 | |
753 | rewriter.replaceOp(convert, val); |
754 | return mlir::success(); |
755 | } |
756 | |
757 | if (mlir::isa<fir::LogicalType>(fromFirTy) || |
758 | mlir::isa<fir::LogicalType>(toFirTy)) { |
759 | // By specification fir::LogicalType value may be any number, |
760 | // where non-zero value represents .true. and zero value represents |
761 | // .false. |
762 | // |
763 | // integer<->logical conversion requires value normalization. |
764 | // Conversion from wide logical to narrow logical must set the result |
765 | // to non-zero iff the input is non-zero - the easiest way to implement |
766 | // it is to compare the input agains zero and set the result to |
767 | // the canonical 0/1. |
768 | // Conversion from narrow logical to wide logical may be implemented |
769 | // as a zero or sign extension of the input, but it may use value |
770 | // normalization as well. |
771 | if (!mlir::isa<mlir::IntegerType>(fromTy) || |
772 | !mlir::isa<mlir::IntegerType>(toTy)) |
773 | return mlir::emitError(loc) |
774 | << "unsupported types for logical conversion: " << fromTy |
775 | << " -> " << toTy; |
776 | |
777 | // Do folding for constant inputs. |
778 | if (auto constVal = fir::getIntIfConstant(op0)) { |
779 | mlir::Value normVal = |
780 | genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0); |
781 | rewriter.replaceOp(convert, normVal); |
782 | return mlir::success(); |
783 | } |
784 | |
785 | // If the input is i1, then we can just zero extend it, and |
786 | // the result will be normalized. |
787 | if (fromTy == i1Type) { |
788 | rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0); |
789 | return mlir::success(); |
790 | } |
791 | |
792 | // Compare the input with zero. |
793 | mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0); |
794 | auto isTrue = rewriter.create<mlir::LLVM::ICmpOp>( |
795 | loc, mlir::LLVM::ICmpPredicate::ne, op0, zero); |
796 | |
797 | // Zero extend the i1 isTrue result to the required type (unless it is i1 |
798 | // itself). |
799 | if (toTy != i1Type) |
800 | rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, isTrue); |
801 | else |
802 | rewriter.replaceOp(convert, isTrue.getResult()); |
803 | |
804 | return mlir::success(); |
805 | } |
806 | |
807 | if (fromTy == toTy) { |
808 | rewriter.replaceOp(convert, op0); |
809 | return mlir::success(); |
810 | } |
811 | auto convertFpToFp = [&](mlir::Value val, unsigned fromBits, |
812 | unsigned toBits, mlir::Type toTy) -> mlir::Value { |
813 | if (fromBits == toBits) { |
814 | // TODO: Converting between two floating-point representations with the |
815 | // same bitwidth is not allowed for now. |
816 | mlir::emitError(loc, |
817 | "cannot implicitly convert between two floating-point " |
818 | "representations of the same bitwidth" ); |
819 | return {}; |
820 | } |
821 | if (fromBits > toBits) |
822 | return rewriter.create<mlir::LLVM::FPTruncOp>(loc, toTy, val); |
823 | return rewriter.create<mlir::LLVM::FPExtOp>(loc, toTy, val); |
824 | }; |
825 | // Complex to complex conversion. |
826 | if (fir::isa_complex(fromFirTy) && fir::isa_complex(toFirTy)) { |
827 | // Special case: handle the conversion of a complex such that both the |
828 | // real and imaginary parts are converted together. |
829 | auto ty = convertType(getComplexEleTy(convert.getValue().getType())); |
830 | auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, 0); |
831 | auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, 1); |
832 | auto nt = convertType(getComplexEleTy(convert.getRes().getType())); |
833 | auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); |
834 | auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt); |
835 | auto rc = convertFpToFp(rp, fromBits, toBits, nt); |
836 | auto ic = convertFpToFp(ip, fromBits, toBits, nt); |
837 | auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy); |
838 | auto i1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, un, rc, 0); |
839 | rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(convert, i1, ic, |
840 | 1); |
841 | return mlir::success(); |
842 | } |
843 | |
844 | // Floating point to floating point conversion. |
845 | if (isFloatingPointTy(fromTy)) { |
846 | if (isFloatingPointTy(toTy)) { |
847 | auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); |
848 | auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); |
849 | auto v = convertFpToFp(op0, fromBits, toBits, toTy); |
850 | rewriter.replaceOp(convert, v); |
851 | return mlir::success(); |
852 | } |
853 | if (mlir::isa<mlir::IntegerType>(toTy)) { |
854 | // NOTE: We are checking the fir type here because toTy is an LLVM type |
855 | // which is signless, and we need to use the intrinsic that matches the |
856 | // sign of the output in fir. |
857 | if (toFirTy.isUnsignedInteger()) { |
858 | auto intrinsicName = |
859 | mlir::StringAttr::get(convert.getContext(), "llvm.fptoui.sat" ); |
860 | rewriter.replaceOpWithNewOp<mlir::LLVM::CallIntrinsicOp>( |
861 | convert, toTy, intrinsicName, op0); |
862 | } else { |
863 | auto intrinsicName = |
864 | mlir::StringAttr::get(convert.getContext(), "llvm.fptosi.sat" ); |
865 | rewriter.replaceOpWithNewOp<mlir::LLVM::CallIntrinsicOp>( |
866 | convert, toTy, intrinsicName, op0); |
867 | } |
868 | return mlir::success(); |
869 | } |
870 | } else if (mlir::isa<mlir::IntegerType>(fromTy)) { |
871 | // Integer to integer conversion. |
872 | if (mlir::isa<mlir::IntegerType>(toTy)) { |
873 | auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); |
874 | auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); |
875 | assert(fromBits != toBits); |
876 | if (fromBits > toBits) { |
877 | rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0); |
878 | return mlir::success(); |
879 | } |
880 | if (fromFirTy == i1Type || fromFirTy.isUnsignedInteger()) { |
881 | rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0); |
882 | return mlir::success(); |
883 | } |
884 | rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0); |
885 | return mlir::success(); |
886 | } |
887 | // Integer to floating point conversion. |
888 | if (isFloatingPointTy(toTy)) { |
889 | if (fromTy.isUnsignedInteger()) |
890 | rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(convert, toTy, op0); |
891 | else |
892 | rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0); |
893 | return mlir::success(); |
894 | } |
895 | // Integer to pointer conversion. |
896 | if (mlir::isa<mlir::LLVM::LLVMPointerType>(toTy)) { |
897 | rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0); |
898 | return mlir::success(); |
899 | } |
900 | } else if (mlir::isa<mlir::LLVM::LLVMPointerType>(fromTy)) { |
901 | // Pointer to integer conversion. |
902 | if (mlir::isa<mlir::IntegerType>(toTy)) { |
903 | rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0); |
904 | return mlir::success(); |
905 | } |
906 | // Pointer to pointer conversion. |
907 | if (mlir::isa<mlir::LLVM::LLVMPointerType>(toTy)) { |
908 | rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0); |
909 | return mlir::success(); |
910 | } |
911 | } |
912 | return emitError(loc) << "cannot convert " << fromTy << " to " << toTy; |
913 | } |
914 | }; |
915 | |
916 | /// `fir.type_info` operation has no specific CodeGen. The operation is |
917 | /// only used to carry information during FIR to FIR passes. It may be used |
918 | /// in the future to generate the runtime type info data structures instead |
919 | /// of generating them in lowering. |
920 | struct TypeInfoOpConversion : public fir::FIROpConversion<fir::TypeInfoOp> { |
921 | using FIROpConversion::FIROpConversion; |
922 | |
923 | llvm::LogicalResult |
924 | matchAndRewrite(fir::TypeInfoOp op, OpAdaptor, |
925 | mlir::ConversionPatternRewriter &rewriter) const override { |
926 | rewriter.eraseOp(op); |
927 | return mlir::success(); |
928 | } |
929 | }; |
930 | |
931 | /// `fir.dt_entry` operation has no specific CodeGen. The operation is only used |
932 | /// to carry information during FIR to FIR passes. |
933 | struct DTEntryOpConversion : public fir::FIROpConversion<fir::DTEntryOp> { |
934 | using FIROpConversion::FIROpConversion; |
935 | |
936 | llvm::LogicalResult |
937 | matchAndRewrite(fir::DTEntryOp op, OpAdaptor, |
938 | mlir::ConversionPatternRewriter &rewriter) const override { |
939 | rewriter.eraseOp(op); |
940 | return mlir::success(); |
941 | } |
942 | }; |
943 | |
/// Lower `fir.global_len` operation.
struct GlobalLenOpConversion : public fir::FIROpConversion<fir::GlobalLenOp> {
  using FIROpConversion::FIROpConversion;

  // Not implemented yet: TODO() reports the missing feature to the user.
  llvm::LogicalResult
  matchAndRewrite(fir::GlobalLenOp globalLen, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    TODO(globalLen.getLoc(), "fir.global_len codegen" );
    return mlir::failure();
  }
};
955 | |
/// Lower fir.len_param_index
struct LenParamIndexOpConversion
    : public fir::FIROpConversion<fir::LenParamIndexOp> {
  using FIROpConversion::FIROpConversion;

  // FIXME: this should be specialized by the runtime target
  // Not implemented yet: TODO() reports the missing feature to the user.
  llvm::LogicalResult
  matchAndRewrite(fir::LenParamIndexOp lenp, OpAdaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    TODO(lenp.getLoc(), "fir.len_param_index codegen" );
  }
};
968 | |
969 | /// Convert `!fir.emboxchar<!fir.char<KIND, ?>, #n>` into a sequence of |
970 | /// instructions that generate `!llvm.struct<(ptr<ik>, i64)>`. The 1st element |
971 | /// in this struct is a pointer. Its type is determined from `KIND`. The 2nd |
972 | /// element is the length of the character buffer (`#n`). |
973 | struct EmboxCharOpConversion : public fir::FIROpConversion<fir::EmboxCharOp> { |
974 | using FIROpConversion::FIROpConversion; |
975 | |
976 | llvm::LogicalResult |
977 | matchAndRewrite(fir::EmboxCharOp emboxChar, OpAdaptor adaptor, |
978 | mlir::ConversionPatternRewriter &rewriter) const override { |
979 | mlir::ValueRange operands = adaptor.getOperands(); |
980 | |
981 | mlir::Value charBuffer = operands[0]; |
982 | mlir::Value charBufferLen = operands[1]; |
983 | |
984 | mlir::Location loc = emboxChar.getLoc(); |
985 | mlir::Type llvmStructTy = convertType(emboxChar.getType()); |
986 | auto llvmStruct = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmStructTy); |
987 | |
988 | mlir::Type lenTy = |
989 | mlir::cast<mlir::LLVM::LLVMStructType>(llvmStructTy).getBody()[1]; |
990 | mlir::Value lenAfterCast = integerCast(loc, rewriter, lenTy, charBufferLen); |
991 | |
992 | mlir::Type addrTy = |
993 | mlir::cast<mlir::LLVM::LLVMStructType>(llvmStructTy).getBody()[0]; |
994 | if (addrTy != charBuffer.getType()) |
995 | charBuffer = |
996 | rewriter.create<mlir::LLVM::BitcastOp>(loc, addrTy, charBuffer); |
997 | |
998 | auto insertBufferOp = rewriter.create<mlir::LLVM::InsertValueOp>( |
999 | loc, llvmStruct, charBuffer, 0); |
1000 | rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( |
1001 | emboxChar, insertBufferOp, lenAfterCast, 1); |
1002 | |
1003 | return mlir::success(); |
1004 | } |
1005 | }; |
1006 | } // namespace |
1007 | |
1008 | template <typename ModuleOp> |
1009 | static mlir::SymbolRefAttr |
1010 | getMallocInModule(ModuleOp mod, fir::AllocMemOp op, |
1011 | mlir::ConversionPatternRewriter &rewriter, |
1012 | mlir::Type indexType) { |
1013 | static constexpr char mallocName[] = "malloc" ; |
1014 | if (auto mallocFunc = |
1015 | mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(mallocName)) |
1016 | return mlir::SymbolRefAttr::get(mallocFunc); |
1017 | if (auto userMalloc = |
1018 | mod.template lookupSymbol<mlir::func::FuncOp>(mallocName)) |
1019 | return mlir::SymbolRefAttr::get(userMalloc); |
1020 | |
1021 | mlir::OpBuilder moduleBuilder(mod.getBodyRegion()); |
1022 | auto mallocDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>( |
1023 | op.getLoc(), mallocName, |
1024 | mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()), |
1025 | indexType, |
1026 | /*isVarArg=*/false)); |
1027 | return mlir::SymbolRefAttr::get(mallocDecl); |
1028 | } |
1029 | |
1030 | /// Return the LLVMFuncOp corresponding to the standard malloc call. |
1031 | static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op, |
1032 | mlir::ConversionPatternRewriter &rewriter, |
1033 | mlir::Type indexType) { |
1034 | if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>()) |
1035 | return getMallocInModule(mod, op, rewriter, indexType); |
1036 | auto mod = op->getParentOfType<mlir::ModuleOp>(); |
1037 | return getMallocInModule(mod, op, rewriter, indexType); |
1038 | } |
1039 | |
1040 | /// Helper function for generating the LLVM IR that computes the distance |
1041 | /// in bytes between adjacent elements pointed to by a pointer |
1042 | /// of type \p ptrTy. The result is returned as a value of \p idxTy integer |
1043 | /// type. |
1044 | static mlir::Value |
1045 | computeElementDistance(mlir::Location loc, mlir::Type llvmObjectType, |
1046 | mlir::Type idxTy, |
1047 | mlir::ConversionPatternRewriter &rewriter, |
1048 | const mlir::DataLayout &dataLayout) { |
1049 | llvm::TypeSize size = dataLayout.getTypeSize(llvmObjectType); |
1050 | unsigned short alignment = dataLayout.getTypeABIAlignment(llvmObjectType); |
1051 | std::int64_t distance = llvm::alignTo(size, alignment); |
1052 | return genConstantIndex(loc, idxTy, rewriter, distance); |
1053 | } |
1054 | |
1055 | /// Return value of the stride in bytes between adjacent elements |
1056 | /// of LLVM type \p llTy. The result is returned as a value of |
1057 | /// \p idxTy integer type. |
1058 | static mlir::Value |
1059 | genTypeStrideInBytes(mlir::Location loc, mlir::Type idxTy, |
1060 | mlir::ConversionPatternRewriter &rewriter, mlir::Type llTy, |
1061 | const mlir::DataLayout &dataLayout) { |
1062 | // Create a pointer type and use computeElementDistance(). |
1063 | return computeElementDistance(loc, llTy, idxTy, rewriter, dataLayout); |
1064 | } |
1065 | |
1066 | namespace { |
/// Lower a `fir.allocmem` instruction into `llvm.call @malloc`
struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type heapTy = heap.getType();
    mlir::Location loc = heap.getLoc();
    auto ity = lowerTy().indexType();
    mlir::Type dataTy = fir::unwrapRefType(heapTy);
    mlir::Type llvmObjectTy = convertObjectType(dataTy);
    if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy)))
      TODO(loc, "fir.allocmem codegen of derived type with length parameters" );
    // Allocation size = element size, scaled by the type-parameter factor
    // (if any) and by every extent operand of the allocation.
    mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
    if (auto scaleSize = genAllocationScaleSize(heap, ity, rewriter))
      size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
    for (mlir::Value opnd : adaptor.getOperands())
      size = rewriter.create<mlir::LLVM::MulOp>(
          loc, ity, size, integerCast(loc, rewriter, ity, opnd));
    // malloc's parameter width may differ from the index type used for the
    // size computation above; cast the byte count if necessary.
    auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
    auto mallocTy =
        mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
    if (mallocTyWidth != ity.getIntOrFloatBitWidth())
      size = integerCast(loc, rewriter, mallocTy, size);
    // Stash the callee on the op so it is carried over (together with the
    // op's other attributes) into the llvm.call built below.
    heap->setAttr("callee" , getMalloc(heap, rewriter, mallocTy));
    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
        heap, ::getLlvmPtrType(heap.getContext()), size,
        addLLVMOpBundleAttrs(rewriter, heap->getAttrs(), 1));
    return mlir::success();
  }

  /// Compute the allocation size in bytes of the element type of
  /// \p llTy pointer type. The result is returned as a value of \p idxTy
  /// integer type.
  mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy,
                                 mlir::ConversionPatternRewriter &rewriter,
                                 mlir::Type llTy) const {
    return computeElementDistance(loc, llTy, idxTy, rewriter, getDataLayout());
  }
};
1108 | } // namespace |
1109 | |
1110 | /// Return the LLVMFuncOp corresponding to the standard free call. |
1111 | template <typename ModuleOp> |
1112 | static mlir::SymbolRefAttr |
1113 | getFreeInModule(ModuleOp mod, fir::FreeMemOp op, |
1114 | mlir::ConversionPatternRewriter &rewriter) { |
1115 | static constexpr char freeName[] = "free" ; |
1116 | // Check if free already defined in the module. |
1117 | if (auto freeFunc = |
1118 | mod.template lookupSymbol<mlir::LLVM::LLVMFuncOp>(freeName)) |
1119 | return mlir::SymbolRefAttr::get(freeFunc); |
1120 | if (auto freeDefinedByUser = |
1121 | mod.template lookupSymbol<mlir::func::FuncOp>(freeName)) |
1122 | return mlir::SymbolRefAttr::get(freeDefinedByUser); |
1123 | // Create llvm declaration for free. |
1124 | mlir::OpBuilder moduleBuilder(mod.getBodyRegion()); |
1125 | auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext()); |
1126 | auto freeDecl = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>( |
1127 | rewriter.getUnknownLoc(), freeName, |
1128 | mlir::LLVM::LLVMFunctionType::get(voidType, |
1129 | getLlvmPtrType(op.getContext()), |
1130 | /*isVarArg=*/false)); |
1131 | return mlir::SymbolRefAttr::get(freeDecl); |
1132 | } |
1133 | |
1134 | static mlir::SymbolRefAttr getFree(fir::FreeMemOp op, |
1135 | mlir::ConversionPatternRewriter &rewriter) { |
1136 | if (auto mod = op->getParentOfType<mlir::gpu::GPUModuleOp>()) |
1137 | return getFreeInModule(mod, op, rewriter); |
1138 | auto mod = op->getParentOfType<mlir::ModuleOp>(); |
1139 | return getFreeInModule(mod, op, rewriter); |
1140 | } |
1141 | |
1142 | static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) { |
1143 | unsigned result = 1; |
1144 | for (auto eleTy = |
1145 | mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(ty.getElementType()); |
1146 | eleTy; eleTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>( |
1147 | eleTy.getElementType())) |
1148 | ++result; |
1149 | return result; |
1150 | } |
1151 | |
1152 | namespace { |
1153 | /// Lower a `fir.freemem` instruction into `llvm.call @free` |
1154 | struct FreeMemOpConversion : public fir::FIROpConversion<fir::FreeMemOp> { |
1155 | using FIROpConversion::FIROpConversion; |
1156 | |
1157 | llvm::LogicalResult |
1158 | matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor, |
1159 | mlir::ConversionPatternRewriter &rewriter) const override { |
1160 | mlir::Location loc = freemem.getLoc(); |
1161 | freemem->setAttr("callee" , getFree(freemem, rewriter)); |
1162 | rewriter.create<mlir::LLVM::CallOp>( |
1163 | loc, mlir::TypeRange{}, mlir::ValueRange{adaptor.getHeapref()}, |
1164 | addLLVMOpBundleAttrs(rewriter, freemem->getAttrs(), 1)); |
1165 | rewriter.eraseOp(freemem); |
1166 | return mlir::success(); |
1167 | } |
1168 | }; |
1169 | } // namespace |
1170 | |
1171 | // Convert subcomponent array indices from column-major to row-major ordering. |
1172 | static llvm::SmallVector<mlir::Value> |
1173 | convertSubcomponentIndices(mlir::Location loc, mlir::Type eleTy, |
1174 | mlir::ValueRange indices, |
1175 | mlir::Type *retTy = nullptr) { |
1176 | llvm::SmallVector<mlir::Value> result; |
1177 | llvm::SmallVector<mlir::Value> arrayIndices; |
1178 | |
1179 | auto appendArrayIndices = [&] { |
1180 | if (arrayIndices.empty()) |
1181 | return; |
1182 | std::reverse(arrayIndices.begin(), arrayIndices.end()); |
1183 | result.append(arrayIndices.begin(), arrayIndices.end()); |
1184 | arrayIndices.clear(); |
1185 | }; |
1186 | |
1187 | for (mlir::Value index : indices) { |
1188 | // Component indices can be field index to select a component, or array |
1189 | // index, to select an element in an array component. |
1190 | if (auto structTy = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(eleTy)) { |
1191 | std::int64_t cstIndex = getConstantIntValue(index); |
1192 | assert(cstIndex < (int64_t)structTy.getBody().size() && |
1193 | "out-of-bounds struct field index" ); |
1194 | eleTy = structTy.getBody()[cstIndex]; |
1195 | appendArrayIndices(); |
1196 | result.push_back(index); |
1197 | } else if (auto arrayTy = |
1198 | mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(eleTy)) { |
1199 | eleTy = arrayTy.getElementType(); |
1200 | arrayIndices.push_back(index); |
1201 | } else |
1202 | fir::emitFatalError(loc, "Unexpected subcomponent type" ); |
1203 | } |
1204 | appendArrayIndices(); |
1205 | if (retTy) |
1206 | *retTy = eleTy; |
1207 | return result; |
1208 | } |
1209 | |
/// Materialize a pointer to a NUL-terminated string holding the source file
/// name of \p loc (a null pointer when \p loc has no file information).
/// The string is stored in a link-once module-level global that is reused
/// across calls for the same file name.
static mlir::Value genSourceFile(mlir::Location loc, mlir::ModuleOp mod,
                                 mlir::ConversionPatternRewriter &rewriter) {
  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
  if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) {
    // NUL-terminate so the runtime can treat the data as a C string.
    auto fn = flc.getFilename().str() + '\0';
    std::string globalName = fir::factory::uniqueCGIdent("cl" , fn);

    // Reuse the global for this file name if it already exists; it may still
    // be a fir.global or may already have been converted to llvm.global.
    if (auto g = mod.lookupSymbol<fir::GlobalOp>(globalName)) {
      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
    } else if (auto g = mod.lookupSymbol<mlir::LLVM::GlobalOp>(globalName)) {
      return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy, g.getName());
    }

    // Otherwise create a constant link-once i8 array global initialized with
    // the file name; the insertion point is saved and restored around the
    // module-level edits.
    auto crtInsPt = rewriter.saveInsertionPoint();
    rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
    auto arrayTy = mlir::LLVM::LLVMArrayType::get(
        mlir::IntegerType::get(rewriter.getContext(), 8), fn.size());
    mlir::LLVM::GlobalOp globalOp = rewriter.create<mlir::LLVM::GlobalOp>(
        loc, arrayTy, /*constant=*/true, mlir::LLVM::Linkage::Linkonce,
        globalName, mlir::Attribute());

    // The string content is produced by the global's initializer region.
    mlir::Region &region = globalOp.getInitializerRegion();
    mlir::Block *block = rewriter.createBlock(&region);
    rewriter.setInsertionPoint(block, block->begin());
    mlir::Value constValue = rewriter.create<mlir::LLVM::ConstantOp>(
        loc, arrayTy, rewriter.getStringAttr(fn));
    rewriter.create<mlir::LLVM::ReturnOp>(loc, constValue);
    rewriter.restoreInsertionPoint(crtInsPt);
    return rewriter.create<mlir::LLVM::AddressOfOp>(loc, ptrTy,
                                                    globalOp.getName());
  }
  // No file information available: return a null pointer.
  return rewriter.create<mlir::LLVM::ZeroOp>(loc, ptrTy);
}
1243 | |
1244 | static mlir::Value genSourceLine(mlir::Location loc, |
1245 | mlir::ConversionPatternRewriter &rewriter) { |
1246 | if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(loc)) |
1247 | return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), |
1248 | flc.getLine()); |
1249 | return rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); |
1250 | } |
1251 | |
/// Emit a call to the CUDA Fortran runtime (CUFAllocDescriptor) that
/// allocates managed memory for a descriptor of box type \p boxTy, and
/// return the pointer produced by the call.
static mlir::Value
genCUFAllocDescriptor(mlir::Location loc,
                      mlir::ConversionPatternRewriter &rewriter,
                      mlir::ModuleOp mod, fir::BaseBoxType boxTy,
                      const fir::LLVMTypeConverter &typeConverter) {
  // A data layout is required below to size the descriptor.
  std::optional<mlir::DataLayout> dl =
      fir::support::getOrSetMLIRDataLayout(mod, /*allowDefaultLayout=*/true);
  if (!dl)
    mlir::emitError(mod.getLoc(),
                    "module operation must carry a data layout attribute "
                    "to generate llvm IR from FIR" );

  // Source position arguments forwarded to the runtime for error reporting.
  mlir::Value sourceFile = genSourceFile(loc, mod, rewriter);
  mlir::Value sourceLine = genSourceLine(loc, rewriter);

  mlir::MLIRContext *ctx = mod.getContext();

  // Runtime signature: ptr CUFAllocDescriptor(intptr, ptr, i32).
  mlir::LLVM::LLVMPointerType llvmPointerType =
      mlir::LLVM::LLVMPointerType::get(ctx);
  mlir::Type llvmInt32Type = mlir::IntegerType::get(ctx, 32);
  mlir::Type llvmIntPtrType =
      mlir::IntegerType::get(ctx, typeConverter.getPointerBitwidth(0));
  auto fctTy = mlir::LLVM::LLVMFunctionType::get(
      llvmPointerType, {llvmIntPtrType, llvmPointerType, llvmInt32Type});

  // Declare the runtime entry point unless it is already in the module
  // (either as an llvm.func or as a not-yet-converted func.func).
  auto llvmFunc = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
      RTNAME_STRING(CUFAllocDescriptor));
  auto funcFunc =
      mod.lookupSymbol<mlir::func::FuncOp>(RTNAME_STRING(CUFAllocDescriptor));
  if (!llvmFunc && !funcFunc)
    mlir::OpBuilder::atBlockEnd(mod.getBody())
        .create<mlir::LLVM::LLVMFuncOp>(loc, RTNAME_STRING(CUFAllocDescriptor),
                                        fctTy);

  // The size argument is the descriptor size in bytes, computed from the
  // box's struct lowering and the module data layout.
  mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
  std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
  mlir::Value sizeInBytes =
      genConstantIndex(loc, llvmIntPtrType, rewriter, boxSize);
  llvm::SmallVector args = {sizeInBytes, sourceFile, sourceLine};
  return rewriter
      .create<mlir::LLVM::CallOp>(loc, fctTy, RTNAME_STRING(CUFAllocDescriptor),
                                  args)
      .getResult();
}
1296 | |
1297 | /// Common base class for embox to descriptor conversion. |
1298 | template <typename OP> |
1299 | struct EmboxCommonConversion : public fir::FIROpConversion<OP> { |
1300 | using fir::FIROpConversion<OP>::FIROpConversion; |
1301 | using TypePair = typename fir::FIROpConversion<OP>::TypePair; |
1302 | |
1303 | static int getCFIAttr(fir::BaseBoxType boxTy) { |
1304 | auto eleTy = boxTy.getEleTy(); |
1305 | if (mlir::isa<fir::PointerType>(eleTy)) |
1306 | return CFI_attribute_pointer; |
1307 | if (mlir::isa<fir::HeapType>(eleTy)) |
1308 | return CFI_attribute_allocatable; |
1309 | return CFI_attribute_other; |
1310 | } |
1311 | |
1312 | mlir::Value getCharacterByteSize(mlir::Location loc, |
1313 | mlir::ConversionPatternRewriter &rewriter, |
1314 | fir::CharacterType charTy, |
1315 | mlir::ValueRange lenParams) const { |
1316 | auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64); |
1317 | mlir::Value size = genTypeStrideInBytes( |
1318 | loc, i64Ty, rewriter, this->convertType(charTy), this->getDataLayout()); |
1319 | if (charTy.hasConstantLen()) |
1320 | return size; // Length accounted for in the genTypeStrideInBytes GEP. |
1321 | // Otherwise, multiply the single character size by the length. |
1322 | assert(!lenParams.empty()); |
1323 | auto len64 = fir::FIROpConversion<OP>::integerCast(loc, rewriter, i64Ty, |
1324 | lenParams.back()); |
1325 | return rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, size, len64); |
1326 | } |
1327 | |
  // Get the element size and CFI type code of the boxed value.
  std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
      mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
      mlir::Type boxEleTy, mlir::ValueRange lenParams = {}) const {
    const mlir::DataLayout &dataLayout = this->getDataLayout();
    auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
    // Strip the pointer/heap/reference wrapper to reach the value type.
    if (auto eleTy = fir::dyn_cast_ptrEleTy(boxEleTy))
      boxEleTy = eleTy;
    // For arrays, the descriptor records the size of one array element.
    if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxEleTy))
      return getSizeAndTypeCode(loc, rewriter, seqTy.getEleTy(), lenParams);
    if (mlir::isa<mlir::NoneType>(
            boxEleTy)) // unlimited polymorphic or assumed type
      return {rewriter.create<mlir::LLVM::ConstantOp>(loc, i64Ty, 0),
              this->genConstantOffset(loc, rewriter, CFI_type_other)};
    mlir::Value typeCodeVal = this->genConstantOffset(
        loc, rewriter,
        fir::getTypeCode(boxEleTy, this->lowerTy().getKindMap()));
    // Intrinsic numeric/logical types: size comes from the data layout.
    if (fir::isa_integer(boxEleTy) ||
        mlir::dyn_cast<fir::LogicalType>(boxEleTy) || fir::isa_real(boxEleTy) ||
        fir::isa_complex(boxEleTy))
      return {genTypeStrideInBytes(loc, i64Ty, rewriter,
                                   this->convertType(boxEleTy), dataLayout),
              typeCodeVal};
    // Characters: the byte size accounts for a possibly dynamic length.
    if (auto charTy = mlir::dyn_cast<fir::CharacterType>(boxEleTy))
      return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
              typeCodeVal};
    // References boxed as data: size of a pointer.
    if (fir::isa_ref_type(boxEleTy)) {
      auto ptrTy = ::getLlvmPtrType(rewriter.getContext());
      return {genTypeStrideInBytes(loc, i64Ty, rewriter, ptrTy, dataLayout),
              typeCodeVal};
    }
    // Derived types: size of the whole record.
    if (mlir::isa<fir::RecordType>(boxEleTy))
      return {genTypeStrideInBytes(loc, i64Ty, rewriter,
                                   this->convertType(boxEleTy), dataLayout),
              typeCodeVal};
    fir::emitFatalError(loc, "unhandled type in fir.box code generation" );
  }
1365 | |
1366 | /// Basic pattern to write a field in the descriptor |
1367 | mlir::Value insertField(mlir::ConversionPatternRewriter &rewriter, |
1368 | mlir::Location loc, mlir::Value dest, |
1369 | llvm::ArrayRef<std::int64_t> fldIndexes, |
1370 | mlir::Value value, bool bitcast = false) const { |
1371 | auto boxTy = dest.getType(); |
1372 | auto fldTy = this->getBoxEleTy(boxTy, fldIndexes); |
1373 | if (!bitcast) |
1374 | value = this->integerCast(loc, rewriter, fldTy, value); |
1375 | // bitcast are no-ops with LLVM opaque pointers. |
1376 | return rewriter.create<mlir::LLVM::InsertValueOp>(loc, dest, value, |
1377 | fldIndexes); |
1378 | } |
1379 | |
1380 | inline mlir::Value |
1381 | insertBaseAddress(mlir::ConversionPatternRewriter &rewriter, |
1382 | mlir::Location loc, mlir::Value dest, |
1383 | mlir::Value base) const { |
1384 | return insertField(rewriter, loc, dest, {kAddrPosInBox}, base, |
1385 | /*bitCast=*/true); |
1386 | } |
1387 | |
1388 | inline mlir::Value insertLowerBound(mlir::ConversionPatternRewriter &rewriter, |
1389 | mlir::Location loc, mlir::Value dest, |
1390 | unsigned dim, mlir::Value lb) const { |
1391 | return insertField(rewriter, loc, dest, |
1392 | {kDimsPosInBox, dim, kDimLowerBoundPos}, lb); |
1393 | } |
1394 | |
1395 | inline mlir::Value insertExtent(mlir::ConversionPatternRewriter &rewriter, |
1396 | mlir::Location loc, mlir::Value dest, |
1397 | unsigned dim, mlir::Value extent) const { |
1398 | return insertField(rewriter, loc, dest, {kDimsPosInBox, dim, kDimExtentPos}, |
1399 | extent); |
1400 | } |
1401 | |
1402 | inline mlir::Value insertStride(mlir::ConversionPatternRewriter &rewriter, |
1403 | mlir::Location loc, mlir::Value dest, |
1404 | unsigned dim, mlir::Value stride) const { |
1405 | return insertField(rewriter, loc, dest, {kDimsPosInBox, dim, kDimStridePos}, |
1406 | stride); |
1407 | } |
1408 | |
1409 | /// Get the address of the type descriptor global variable that was created by |
1410 | /// lowering for derived type \p recType. |
1411 | template <typename ModOpTy> |
1412 | mlir::Value |
1413 | getTypeDescriptor(ModOpTy mod, mlir::ConversionPatternRewriter &rewriter, |
1414 | mlir::Location loc, fir::RecordType recType) const { |
1415 | std::string name = |
1416 | this->options.typeDescriptorsRenamedForAssembly |
1417 | ? fir::NameUniquer::getTypeDescriptorAssemblyName(recType.getName()) |
1418 | : fir::NameUniquer::getTypeDescriptorName(recType.getName()); |
1419 | mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext()); |
1420 | if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name)) { |
1421 | return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy, |
1422 | global.getSymName()); |
1423 | } |
1424 | if (auto global = mod.template lookupSymbol<mlir::LLVM::GlobalOp>(name)) { |
1425 | // The global may have already been translated to LLVM. |
1426 | return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy, |
1427 | global.getSymName()); |
1428 | } |
1429 | // Type info derived types do not have type descriptors since they are the |
1430 | // types defining type descriptors. |
1431 | if (!this->options.ignoreMissingTypeDescriptors && |
1432 | !fir::NameUniquer::belongsToModule( |
1433 | name, Fortran::semantics::typeInfoBuiltinModule)) |
1434 | fir::emitFatalError( |
1435 | loc, "runtime derived type info descriptor was not generated" ); |
1436 | return rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPtrTy); |
1437 | } |
1438 | |
1439 | template <typename ModOpTy> |
1440 | mlir::Value populateDescriptor(mlir::Location loc, ModOpTy mod, |
1441 | fir::BaseBoxType boxTy, mlir::Type inputType, |
1442 | mlir::ConversionPatternRewriter &rewriter, |
1443 | unsigned rank, mlir::Value eleSize, |
1444 | mlir::Value cfiTy, mlir::Value typeDesc, |
1445 | int allocatorIdx = kDefaultAllocator, |
1446 | mlir::Value = {}) const { |
1447 | auto llvmBoxTy = this->lowerTy().convertBoxTypeAsStruct(boxTy, rank); |
1448 | bool isUnlimitedPolymorphic = fir::isUnlimitedPolymorphicType(boxTy); |
1449 | bool useInputType = fir::isPolymorphicType(boxTy) || isUnlimitedPolymorphic; |
1450 | mlir::Value descriptor = |
1451 | rewriter.create<mlir::LLVM::UndefOp>(loc, llvmBoxTy); |
1452 | descriptor = |
1453 | insertField(rewriter, loc, descriptor, {kElemLenPosInBox}, eleSize); |
1454 | descriptor = insertField(rewriter, loc, descriptor, {kVersionPosInBox}, |
1455 | this->genI32Constant(loc, rewriter, CFI_VERSION)); |
1456 | descriptor = insertField(rewriter, loc, descriptor, {kRankPosInBox}, |
1457 | this->genI32Constant(loc, rewriter, rank)); |
1458 | descriptor = insertField(rewriter, loc, descriptor, {kTypePosInBox}, cfiTy); |
1459 | descriptor = |
1460 | insertField(rewriter, loc, descriptor, {kAttributePosInBox}, |
1461 | this->genI32Constant(loc, rewriter, getCFIAttr(boxTy))); |
1462 | |
1463 | const bool hasAddendum = fir::boxHasAddendum(boxTy); |
1464 | |
1465 | if (extraField) { |
1466 | // Make sure to set the addendum presence flag according to the |
1467 | // destination box. |
1468 | if (hasAddendum) { |
1469 | auto maskAttr = mlir::IntegerAttr::get( |
1470 | rewriter.getIntegerType(8, /*isSigned=*/false), |
1471 | llvm::APInt(8, (uint64_t)_CFI_ADDENDUM_FLAG, /*isSigned=*/false)); |
1472 | mlir::LLVM::ConstantOp mask = rewriter.create<mlir::LLVM::ConstantOp>( |
1473 | loc, rewriter.getI8Type(), maskAttr); |
1474 | extraField = rewriter.create<mlir::LLVM::OrOp>(loc, extraField, mask); |
1475 | } else { |
1476 | auto maskAttr = mlir::IntegerAttr::get( |
1477 | rewriter.getIntegerType(8, /*isSigned=*/false), |
1478 | llvm::APInt(8, (uint64_t)~_CFI_ADDENDUM_FLAG, /*isSigned=*/true)); |
1479 | mlir::LLVM::ConstantOp mask = rewriter.create<mlir::LLVM::ConstantOp>( |
1480 | loc, rewriter.getI8Type(), maskAttr); |
1481 | extraField = rewriter.create<mlir::LLVM::AndOp>(loc, extraField, mask); |
1482 | } |
1483 | // Extra field value is provided so just use it. |
1484 | descriptor = |
1485 | insertField(rewriter, loc, descriptor, {kExtraPosInBox}, extraField); |
1486 | } else { |
1487 | // Compute the value of the extra field based on allocator_idx and |
1488 | // addendum present. |
1489 | unsigned = allocatorIdx << _CFI_ALLOCATOR_IDX_SHIFT; |
1490 | if (hasAddendum) |
1491 | extra |= _CFI_ADDENDUM_FLAG; |
1492 | descriptor = insertField(rewriter, loc, descriptor, {kExtraPosInBox}, |
1493 | this->genI32Constant(loc, rewriter, extra)); |
1494 | } |
1495 | |
1496 | if (hasAddendum) { |
1497 | unsigned typeDescFieldId = getTypeDescFieldId(boxTy); |
1498 | if (!typeDesc) { |
1499 | if (useInputType) { |
1500 | mlir::Type innerType = fir::unwrapInnerType(inputType); |
1501 | if (innerType && mlir::isa<fir::RecordType>(innerType)) { |
1502 | auto recTy = mlir::dyn_cast<fir::RecordType>(innerType); |
1503 | typeDesc = getTypeDescriptor(mod, rewriter, loc, recTy); |
1504 | } else { |
1505 | // Unlimited polymorphic type descriptor with no record type. Set |
1506 | // type descriptor address to a clean state. |
1507 | typeDesc = rewriter.create<mlir::LLVM::ZeroOp>( |
1508 | loc, ::getLlvmPtrType(mod.getContext())); |
1509 | } |
1510 | } else { |
1511 | typeDesc = getTypeDescriptor(mod, rewriter, loc, |
1512 | fir::unwrapIfDerived(boxTy)); |
1513 | } |
1514 | } |
1515 | if (typeDesc) |
1516 | descriptor = |
1517 | insertField(rewriter, loc, descriptor, {typeDescFieldId}, typeDesc, |
1518 | /*bitCast=*/true); |
1519 | // Always initialize the length parameter field to zero to avoid issues |
1520 | // with uninitialized values in Fortran code trying to compare physical |
1521 | // representation of derived types with pointer/allocatable components. |
1522 | // This has been seen in hashing algorithms using TRANSFER. |
1523 | mlir::Value zero = |
1524 | genConstantIndex(loc, rewriter.getI64Type(), rewriter, 0); |
1525 | descriptor = insertField(rewriter, loc, descriptor, |
1526 | {getLenParamFieldId(boxTy), 0}, zero); |
1527 | } |
1528 | return descriptor; |
1529 | } |
1530 | |
1531 | // Template used for fir::EmboxOp and fir::cg::XEmboxOp |
1532 | template <typename BOX> |
1533 | std::tuple<fir::BaseBoxType, mlir::Value, mlir::Value> |
1534 | consDescriptorPrefix(BOX box, mlir::Type inputType, |
1535 | mlir::ConversionPatternRewriter &rewriter, unsigned rank, |
1536 | [[maybe_unused]] mlir::ValueRange substrParams, |
1537 | mlir::ValueRange lenParams, mlir::Value sourceBox = {}, |
1538 | mlir::Type sourceBoxType = {}) const { |
1539 | auto loc = box.getLoc(); |
1540 | auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(box.getType()); |
1541 | bool useInputType = fir::isPolymorphicType(boxTy) && |
1542 | !fir::isUnlimitedPolymorphicType(inputType); |
1543 | llvm::SmallVector<mlir::Value> typeparams = lenParams; |
1544 | if constexpr (!std::is_same_v<BOX, fir::EmboxOp>) { |
1545 | if (!box.getSubstr().empty() && fir::hasDynamicSize(boxTy.getEleTy())) |
1546 | typeparams.push_back(substrParams[1]); |
1547 | } |
1548 | |
1549 | int allocatorIdx = 0; |
1550 | if constexpr (std::is_same_v<BOX, fir::EmboxOp> || |
1551 | std::is_same_v<BOX, fir::cg::XEmboxOp>) { |
1552 | if (box.getAllocatorIdx()) |
1553 | allocatorIdx = *box.getAllocatorIdx(); |
1554 | } |
1555 | |
1556 | // Write each of the fields with the appropriate values. |
1557 | // When emboxing an element to a polymorphic descriptor, use the |
1558 | // input type since the destination descriptor type has not the exact |
1559 | // information. |
1560 | auto [eleSize, cfiTy] = getSizeAndTypeCode( |
1561 | loc, rewriter, useInputType ? inputType : boxTy.getEleTy(), typeparams); |
1562 | |
1563 | mlir::Value typeDesc; |
1564 | mlir::Value ; |
1565 | // When emboxing to a polymorphic box, get the type descriptor, type code |
1566 | // and element size from the source box if any. |
1567 | if (fir::isPolymorphicType(boxTy) && sourceBox) { |
1568 | TypePair sourceBoxTyPair = this->getBoxTypePair(sourceBoxType); |
1569 | typeDesc = |
1570 | this->loadTypeDescAddress(loc, sourceBoxTyPair, sourceBox, rewriter); |
1571 | mlir::Type idxTy = this->lowerTy().indexType(); |
1572 | eleSize = this->getElementSizeFromBox(loc, idxTy, sourceBoxTyPair, |
1573 | sourceBox, rewriter); |
1574 | cfiTy = this->getValueFromBox(loc, sourceBoxTyPair, sourceBox, |
1575 | cfiTy.getType(), rewriter, kTypePosInBox); |
1576 | extraField = |
1577 | this->getExtraFromBox(loc, sourceBoxTyPair, sourceBox, rewriter); |
1578 | } |
1579 | |
1580 | mlir::Value descriptor; |
1581 | if (auto gpuMod = box->template getParentOfType<mlir::gpu::GPUModuleOp>()) |
1582 | descriptor = populateDescriptor(loc, gpuMod, boxTy, inputType, rewriter, |
1583 | rank, eleSize, cfiTy, typeDesc, |
1584 | allocatorIdx, extraField); |
1585 | else if (auto mod = box->template getParentOfType<mlir::ModuleOp>()) |
1586 | descriptor = populateDescriptor(loc, mod, boxTy, inputType, rewriter, |
1587 | rank, eleSize, cfiTy, typeDesc, |
1588 | allocatorIdx, extraField); |
1589 | |
1590 | return {boxTy, descriptor, eleSize}; |
1591 | } |
1592 | |
1593 | std::tuple<fir::BaseBoxType, mlir::Value, mlir::Value> |
1594 | consDescriptorPrefix(fir::cg::XReboxOp box, mlir::Value loweredBox, |
1595 | mlir::ConversionPatternRewriter &rewriter, unsigned rank, |
1596 | mlir::ValueRange substrParams, |
1597 | mlir::ValueRange lenParams, |
1598 | mlir::Value typeDesc = {}) const { |
1599 | auto loc = box.getLoc(); |
1600 | auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(box.getType()); |
1601 | auto inputBoxTy = mlir::dyn_cast<fir::BaseBoxType>(box.getBox().getType()); |
1602 | auto inputBoxTyPair = this->getBoxTypePair(inputBoxTy); |
1603 | llvm::SmallVector<mlir::Value> typeparams = lenParams; |
1604 | if (!box.getSubstr().empty() && fir::hasDynamicSize(boxTy.getEleTy())) |
1605 | typeparams.push_back(substrParams[1]); |
1606 | |
1607 | auto [eleSize, cfiTy] = |
1608 | getSizeAndTypeCode(loc, rewriter, boxTy.getEleTy(), typeparams); |
1609 | |
1610 | // Reboxing to a polymorphic entity. eleSize and type code need to |
1611 | // be retrieved from the initial box and propagated to the new box. |
1612 | // If the initial box has an addendum, the type desc must be propagated as |
1613 | // well. |
1614 | if (fir::isPolymorphicType(boxTy)) { |
1615 | mlir::Type idxTy = this->lowerTy().indexType(); |
1616 | eleSize = this->getElementSizeFromBox(loc, idxTy, inputBoxTyPair, |
1617 | loweredBox, rewriter); |
1618 | cfiTy = this->getValueFromBox(loc, inputBoxTyPair, loweredBox, |
1619 | cfiTy.getType(), rewriter, kTypePosInBox); |
1620 | // TODO: For initial box that are unlimited polymorphic entities, this |
1621 | // code must be made conditional because unlimited polymorphic entities |
1622 | // with intrinsic type spec does not have addendum. |
1623 | if (fir::boxHasAddendum(inputBoxTy)) |
1624 | typeDesc = this->loadTypeDescAddress(loc, inputBoxTyPair, loweredBox, |
1625 | rewriter); |
1626 | } |
1627 | |
1628 | mlir::Value = |
1629 | this->getExtraFromBox(loc, inputBoxTyPair, loweredBox, rewriter); |
1630 | |
1631 | mlir::Value descriptor; |
1632 | if (auto gpuMod = box->template getParentOfType<mlir::gpu::GPUModuleOp>()) |
1633 | descriptor = |
1634 | populateDescriptor(loc, gpuMod, boxTy, box.getBox().getType(), |
1635 | rewriter, rank, eleSize, cfiTy, typeDesc, |
1636 | /*allocatorIdx=*/kDefaultAllocator, extraField); |
1637 | else if (auto mod = box->template getParentOfType<mlir::ModuleOp>()) |
1638 | descriptor = |
1639 | populateDescriptor(loc, mod, boxTy, box.getBox().getType(), rewriter, |
1640 | rank, eleSize, cfiTy, typeDesc, |
1641 | /*allocatorIdx=*/kDefaultAllocator, extraField); |
1642 | |
1643 | return {boxTy, descriptor, eleSize}; |
1644 | } |
1645 | |
1646 | // Compute the base address of a fir.box given the indices from the slice. |
1647 | // The indices from the "outer" dimensions (every dimension after the first |
1648 | // one (included) that is not a compile time constant) must have been |
1649 | // multiplied with the related extents and added together into \p outerOffset. |
1650 | mlir::Value |
1651 | genBoxOffsetGep(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, |
1652 | mlir::Value base, mlir::Type llvmBaseObjectType, |
1653 | mlir::Value outerOffset, mlir::ValueRange cstInteriorIndices, |
1654 | mlir::ValueRange componentIndices, |
1655 | std::optional<mlir::Value> substringOffset) const { |
1656 | llvm::SmallVector<mlir::LLVM::GEPArg> gepArgs{outerOffset}; |
1657 | mlir::Type resultTy = llvmBaseObjectType; |
1658 | // Fortran is column major, llvm GEP is row major: reverse the indices here. |
1659 | for (mlir::Value interiorIndex : llvm::reverse(cstInteriorIndices)) { |
1660 | auto arrayTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(resultTy); |
1661 | if (!arrayTy) |
1662 | fir::emitFatalError( |
1663 | loc, |
1664 | "corrupted GEP generated being generated in fir.embox/fir.rebox" ); |
1665 | resultTy = arrayTy.getElementType(); |
1666 | gepArgs.push_back(interiorIndex); |
1667 | } |
1668 | llvm::SmallVector<mlir::Value> gepIndices = |
1669 | convertSubcomponentIndices(loc, resultTy, componentIndices, &resultTy); |
1670 | gepArgs.append(gepIndices.begin(), gepIndices.end()); |
1671 | if (substringOffset) { |
1672 | if (auto arrayTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(resultTy)) { |
1673 | gepArgs.push_back(*substringOffset); |
1674 | resultTy = arrayTy.getElementType(); |
1675 | } else { |
1676 | // If the CHARACTER length is dynamic, the whole base type should have |
1677 | // degenerated to an llvm.ptr<i[width]>, and there should not be any |
1678 | // cstInteriorIndices/componentIndices. The substring offset can be |
1679 | // added to the outterOffset since it applies on the same LLVM type. |
1680 | if (gepArgs.size() != 1) |
1681 | fir::emitFatalError(loc, |
1682 | "corrupted substring GEP in fir.embox/fir.rebox" ); |
1683 | mlir::Type outterOffsetTy = |
1684 | llvm::cast<mlir::Value>(gepArgs[0]).getType(); |
1685 | mlir::Value cast = |
1686 | this->integerCast(loc, rewriter, outterOffsetTy, *substringOffset); |
1687 | |
1688 | gepArgs[0] = rewriter.create<mlir::LLVM::AddOp>( |
1689 | loc, outterOffsetTy, llvm::cast<mlir::Value>(gepArgs[0]), cast); |
1690 | } |
1691 | } |
1692 | mlir::Type llvmPtrTy = ::getLlvmPtrType(resultTy.getContext()); |
1693 | return rewriter.create<mlir::LLVM::GEPOp>( |
1694 | loc, llvmPtrTy, llvmBaseObjectType, base, gepArgs); |
1695 | } |
1696 | |
1697 | template <typename BOX> |
1698 | void |
1699 | getSubcomponentIndices(BOX xbox, mlir::Value memref, |
1700 | mlir::ValueRange operands, |
1701 | mlir::SmallVectorImpl<mlir::Value> &indices) const { |
1702 | // For each field in the path add the offset to base via the args list. |
1703 | // In the most general case, some offsets must be computed since |
1704 | // they are not be known until runtime. |
1705 | if (fir::hasDynamicSize(fir::unwrapSequenceType( |
1706 | fir::unwrapPassByRefType(memref.getType())))) |
1707 | TODO(xbox.getLoc(), |
1708 | "fir.embox codegen dynamic size component in derived type" ); |
1709 | indices.append(operands.begin() + xbox.getSubcomponentOperandIndex(), |
1710 | operands.begin() + xbox.getSubcomponentOperandIndex() + |
1711 | xbox.getSubcomponent().size()); |
1712 | } |
1713 | |
1714 | static bool isInGlobalOp(mlir::ConversionPatternRewriter &rewriter) { |
1715 | auto *thisBlock = rewriter.getInsertionBlock(); |
1716 | return thisBlock && |
1717 | mlir::isa<mlir::LLVM::GlobalOp>(thisBlock->getParentOp()); |
1718 | } |
1719 | |
1720 | /// If the embox is not in a globalOp body, allocate storage for the box; |
1721 | /// store the value inside and return the generated alloca. Return the input |
1722 | /// value otherwise. |
1723 | mlir::Value |
1724 | placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter, |
1725 | mlir::Location loc, mlir::Type boxTy, |
1726 | mlir::Value boxValue, |
1727 | bool needDeviceAllocation = false) const { |
1728 | if (isInGlobalOp(rewriter)) |
1729 | return boxValue; |
1730 | mlir::Type llvmBoxTy = boxValue.getType(); |
1731 | mlir::Value storage; |
1732 | if (needDeviceAllocation) { |
1733 | auto mod = boxValue.getDefiningOp()->getParentOfType<mlir::ModuleOp>(); |
1734 | auto baseBoxTy = mlir::dyn_cast<fir::BaseBoxType>(boxTy); |
1735 | storage = |
1736 | genCUFAllocDescriptor(loc, rewriter, mod, baseBoxTy, this->lowerTy()); |
1737 | } else { |
1738 | storage = this->genAllocaAndAddrCastWithType(loc, llvmBoxTy, defaultAlign, |
1739 | rewriter); |
1740 | } |
1741 | auto storeOp = rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, storage); |
1742 | this->attachTBAATag(storeOp, boxTy, boxTy, nullptr); |
1743 | return storage; |
1744 | } |
1745 | |
1746 | /// Compute the extent of a triplet slice (lb:ub:step). |
1747 | mlir::Value computeTripletExtent(mlir::ConversionPatternRewriter &rewriter, |
1748 | mlir::Location loc, mlir::Value lb, |
1749 | mlir::Value ub, mlir::Value step, |
1750 | mlir::Value zero, mlir::Type type) const { |
1751 | lb = this->integerCast(loc, rewriter, type, lb); |
1752 | ub = this->integerCast(loc, rewriter, type, ub); |
1753 | step = this->integerCast(loc, rewriter, type, step); |
1754 | zero = this->integerCast(loc, rewriter, type, zero); |
1755 | mlir::Value extent = rewriter.create<mlir::LLVM::SubOp>(loc, type, ub, lb); |
1756 | extent = rewriter.create<mlir::LLVM::AddOp>(loc, type, extent, step); |
1757 | extent = rewriter.create<mlir::LLVM::SDivOp>(loc, type, extent, step); |
1758 | // If the resulting extent is negative (`ub-lb` and `step` have different |
1759 | // signs), zero must be returned instead. |
1760 | auto cmp = rewriter.create<mlir::LLVM::ICmpOp>( |
1761 | loc, mlir::LLVM::ICmpPredicate::sgt, extent, zero); |
1762 | return rewriter.create<mlir::LLVM::SelectOp>(loc, cmp, extent, zero); |
1763 | } |
1764 | }; |
1765 | |
1766 | /// Create a generic box on a memory reference. This conversions lowers the |
1767 | /// abstract box to the appropriate, initialized descriptor. |
1768 | struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> { |
1769 | using EmboxCommonConversion::EmboxCommonConversion; |
1770 | |
1771 | llvm::LogicalResult |
1772 | matchAndRewrite(fir::EmboxOp embox, OpAdaptor adaptor, |
1773 | mlir::ConversionPatternRewriter &rewriter) const override { |
1774 | mlir::ValueRange operands = adaptor.getOperands(); |
1775 | mlir::Value sourceBox; |
1776 | mlir::Type sourceBoxType; |
1777 | if (embox.getSourceBox()) { |
1778 | sourceBox = operands[embox.getSourceBoxOperandIndex()]; |
1779 | sourceBoxType = embox.getSourceBox().getType(); |
1780 | } |
1781 | assert(!embox.getShape() && "There should be no dims on this embox op" ); |
1782 | auto [boxTy, dest, eleSize] = consDescriptorPrefix( |
1783 | embox, fir::unwrapRefType(embox.getMemref().getType()), rewriter, |
1784 | /*rank=*/0, /*substrParams=*/mlir::ValueRange{}, |
1785 | adaptor.getTypeparams(), sourceBox, sourceBoxType); |
1786 | dest = insertBaseAddress(rewriter, embox.getLoc(), dest, operands[0]); |
1787 | if (fir::isDerivedTypeWithLenParams(boxTy)) { |
1788 | TODO(embox.getLoc(), |
1789 | "fir.embox codegen of derived with length parameters" ); |
1790 | return mlir::failure(); |
1791 | } |
1792 | auto result = |
1793 | placeInMemoryIfNotGlobalInit(rewriter, embox.getLoc(), boxTy, dest); |
1794 | rewriter.replaceOp(embox, result); |
1795 | return mlir::success(); |
1796 | } |
1797 | }; |
1798 | |
1799 | static bool isDeviceAllocation(mlir::Value val, mlir::Value adaptorVal) { |
1800 | if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp())) |
1801 | return isDeviceAllocation(loadOp.getMemref(), {}); |
1802 | if (auto boxAddrOp = |
1803 | mlir::dyn_cast_or_null<fir::BoxAddrOp>(val.getDefiningOp())) |
1804 | return isDeviceAllocation(boxAddrOp.getVal(), {}); |
1805 | if (auto convertOp = |
1806 | mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp())) |
1807 | return isDeviceAllocation(convertOp.getValue(), {}); |
1808 | if (!val.getDefiningOp() && adaptorVal) { |
1809 | if (auto blockArg = llvm::cast<mlir::BlockArgument>(adaptorVal)) { |
1810 | if (blockArg.getOwner() && blockArg.getOwner()->getParentOp() && |
1811 | blockArg.getOwner()->isEntryBlock()) { |
1812 | if (auto func = mlir::dyn_cast_or_null<mlir::FunctionOpInterface>( |
1813 | *blockArg.getOwner()->getParentOp())) { |
1814 | auto argAttrs = func.getArgAttrs(blockArg.getArgNumber()); |
1815 | for (auto attr : argAttrs) { |
1816 | if (attr.getName().getValue().ends_with(cuf::getDataAttrName())) { |
1817 | auto dataAttr = |
1818 | mlir::dyn_cast<cuf::DataAttributeAttr>(attr.getValue()); |
1819 | if (dataAttr.getValue() != cuf::DataAttribute::Pinned && |
1820 | dataAttr.getValue() != cuf::DataAttribute::Unified) |
1821 | return true; |
1822 | } |
1823 | } |
1824 | } |
1825 | } |
1826 | } |
1827 | } |
1828 | if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp())) |
1829 | if (callOp.getCallee() && |
1830 | (callOp.getCallee().value().getRootReference().getValue().starts_with( |
1831 | RTNAME_STRING(CUFMemAlloc)) || |
1832 | callOp.getCallee().value().getRootReference().getValue().starts_with( |
1833 | RTNAME_STRING(CUFAllocDescriptor)) || |
1834 | callOp.getCallee().value().getRootReference().getValue() == |
1835 | "__tgt_acc_get_deviceptr" )) |
1836 | return true; |
1837 | return false; |
1838 | } |
1839 | |
/// Create a generic box on a memory reference.
/// Lowers fir.cg.x_embox: builds a full descriptor (base address, element
/// size, type code, dims triples) from the shape/shift/slice operands.
struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
  using EmboxCommonConversion::EmboxCommonConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::cg::XEmboxOp xbox, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::ValueRange operands = adaptor.getOperands();
    // Dynamic type information propagated from a source box, if any.
    mlir::Value sourceBox;
    mlir::Type sourceBoxType;
    if (xbox.getSourceBox()) {
      sourceBox = operands[xbox.getSourceBoxOperandIndex()];
      sourceBoxType = xbox.getSourceBox().getType();
    }
    // Fill the non-dimension descriptor fields; resultEleSize is the size in
    // bytes of one element of the resulting box.
    auto [boxTy, dest, resultEleSize] = consDescriptorPrefix(
        xbox, fir::unwrapRefType(xbox.getMemref().getType()), rewriter,
        xbox.getOutRank(), adaptor.getSubstr(), adaptor.getLenParams(),
        sourceBox, sourceBoxType);
    // Generate the triples in the dims field of the descriptor
    auto i64Ty = mlir::IntegerType::get(xbox.getContext(), 64);
    assert(!xbox.getShape().empty() && "must have a shape" );
    unsigned shapeOffset = xbox.getShapeOperandIndex();
    bool hasShift = !xbox.getShift().empty();
    unsigned shiftOffset = xbox.getShiftOperandIndex();
    bool hasSlice = !xbox.getSlice().empty();
    unsigned sliceOffset = xbox.getSliceOperandIndex();
    mlir::Location loc = xbox.getLoc();
    mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0);
    mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1);
    // Running scale factor, in elements, for the base address offset.
    mlir::Value prevPtrOff = one;
    mlir::Type eleTy = boxTy.getEleTy();
    const unsigned rank = xbox.getRank();
    llvm::SmallVector<mlir::Value> cstInteriorIndices;
    unsigned constRows = 0;
    mlir::Value ptrOffset = zero;
    mlir::Type memEleTy = fir::dyn_cast_ptrEleTy(xbox.getMemref().getType());
    assert(mlir::isa<fir::SequenceType>(memEleTy));
    auto seqTy = mlir::cast<fir::SequenceType>(memEleTy);
    mlir::Type seqEleTy = seqTy.getEleTy();
    // Adjust the element scaling factor if the element is a dependent type.
    if (fir::hasDynamicSize(seqEleTy)) {
      if (auto charTy = mlir::dyn_cast<fir::CharacterType>(seqEleTy)) {
        // The GEP pointer type decays to llvm.ptr<i[width]>.
        // The scaling factor is the runtime value of the length.
        assert(!adaptor.getLenParams().empty());
        prevPtrOff = FIROpConversion::integerCast(
            loc, rewriter, i64Ty, adaptor.getLenParams().back());
      } else if (mlir::isa<fir::RecordType>(seqEleTy)) {
        // prevPtrOff would be the runtime size of the PDT; not implemented.
        TODO(loc, "generate call to calculate size of PDT" );
      } else {
        fir::emitFatalError(loc, "unexpected dynamic type" );
      }
    } else {
      constRows = seqTy.getConstantRows();
    }

    const auto hasSubcomp = !xbox.getSubcomponent().empty();
    const bool hasSubstr = !xbox.getSubstr().empty();
    // Initial element stride that will be use to compute the step in
    // each dimension. Initially, this is the size of the input element.
    // Note that when there are no components/substring, the resultEleSize
    // that was previously computed matches the input element size.
    mlir::Value prevDimByteStride = resultEleSize;
    if (hasSubcomp) {
      // We have a subcomponent. The step value needs to be the number of
      // bytes per element (which is a derived type).
      prevDimByteStride = genTypeStrideInBytes(
          loc, i64Ty, rewriter, convertType(seqEleTy), getDataLayout());
    } else if (hasSubstr) {
      // We have a substring. The step value needs to be the number of bytes
      // per CHARACTER element.
      auto charTy = mlir::cast<fir::CharacterType>(seqEleTy);
      if (fir::hasDynamicSize(charTy)) {
        prevDimByteStride =
            getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
      } else {
        prevDimByteStride = genConstantIndex(
            loc, i64Ty, rewriter,
            charTy.getLen() * lowerTy().characterBitsize(charTy) / 8);
      }
    }

    // Process the array subspace arguments (shape, shift, etc.), if any,
    // translating everything to values in the descriptor wherever the entity
    // has a dynamic array dimension.
    // `di` walks the input dimensions; `descIdx` walks the descriptor
    // dimensions (sliced-out scalar dimensions are dropped).
    for (unsigned di = 0, descIdx = 0; di < rank; ++di) {
      mlir::Value extent =
          integerCast(loc, rewriter, i64Ty, operands[shapeOffset]);
      mlir::Value outerExtent = extent;
      bool skipNext = false;
      if (hasSlice) {
        // Accumulate (lb - shift) scaled by the element count of the inner
        // dimensions into either the constant GEP indices or ptrOffset.
        mlir::Value off =
            integerCast(loc, rewriter, i64Ty, operands[sliceOffset]);
        mlir::Value adj = one;
        if (hasShift)
          adj = integerCast(loc, rewriter, i64Ty, operands[shiftOffset]);
        auto ao = rewriter.create<mlir::LLVM::SubOp>(loc, i64Ty, off, adj);
        if (constRows > 0) {
          cstInteriorIndices.push_back(ao);
        } else {
          auto dimOff =
              rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, ao, prevPtrOff);
          ptrOffset =
              rewriter.create<mlir::LLVM::AddOp>(loc, i64Ty, dimOff, ptrOffset);
        }
        if (mlir::isa_and_nonnull<fir::UndefOp>(
                xbox.getSlice()[3 * di + 1].getDefiningOp())) {
          // This dimension contains a scalar expression in the array slice op.
          // The dimension is loop invariant, will be dropped, and will not
          // appear in the descriptor.
          skipNext = true;
        }
      }
      if (!skipNext) {
        // store extent
        if (hasSlice)
          extent = computeTripletExtent(rewriter, loc, operands[sliceOffset],
                                        operands[sliceOffset + 1],
                                        operands[sliceOffset + 2], zero, i64Ty);
        // Lower bound is normalized to 0 for BIND(C) interoperability.
        mlir::Value lb = zero;
        const bool isaPointerOrAllocatable =
            mlir::isa<fir::PointerType, fir::HeapType>(eleTy);
        // Lower bound is defaults to 1 for POINTER, ALLOCATABLE, and
        // denormalized descriptors.
        if (isaPointerOrAllocatable || !normalizedLowerBound(xbox))
          lb = one;
        // If there is a shifted origin, and no fir.slice, and this is not
        // a normalized descriptor then use the value from the shift op as
        // the lower bound.
        if (hasShift && !(hasSlice || hasSubcomp || hasSubstr) &&
            (isaPointerOrAllocatable || !normalizedLowerBound(xbox))) {
          lb = integerCast(loc, rewriter, i64Ty, operands[shiftOffset]);
          // An empty dimension keeps the default lower bound of 1.
          auto extentIsEmpty = rewriter.create<mlir::LLVM::ICmpOp>(
              loc, mlir::LLVM::ICmpPredicate::eq, extent, zero);
          lb = rewriter.create<mlir::LLVM::SelectOp>(loc, extentIsEmpty, one,
                                                     lb);
        }
        dest = insertLowerBound(rewriter, loc, dest, descIdx, lb);

        dest = insertExtent(rewriter, loc, dest, descIdx, extent);

        // store step (scaled by shaped extent)
        mlir::Value step = prevDimByteStride;
        if (hasSlice) {
          mlir::Value sliceStep =
              integerCast(loc, rewriter, i64Ty, operands[sliceOffset + 2]);
          step =
              rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, step, sliceStep);
        }
        dest = insertStride(rewriter, loc, dest, descIdx, step);
        ++descIdx;
      }

      // compute the stride and offset for the next natural dimension
      prevDimByteStride = rewriter.create<mlir::LLVM::MulOp>(
          loc, i64Ty, prevDimByteStride, outerExtent);
      if (constRows == 0)
        prevPtrOff = rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, prevPtrOff,
                                                        outerExtent);
      else
        --constRows;

      // increment iterators
      ++shapeOffset;
      if (hasShift)
        ++shiftOffset;
      if (hasSlice)
        sliceOffset += 3;
    }
    mlir::Value base = adaptor.getMemref();
    if (hasSlice || hasSubcomp || hasSubstr) {
      // Shift the base address.
      llvm::SmallVector<mlir::Value> fieldIndices;
      std::optional<mlir::Value> substringOffset;
      if (hasSubcomp)
        getSubcomponentIndices(xbox, xbox.getMemref(), operands, fieldIndices);
      if (hasSubstr)
        substringOffset = operands[xbox.getSubstrOperandIndex()];
      mlir::Type llvmBaseType =
          convertType(fir::unwrapRefType(xbox.getMemref().getType()));
      base = genBoxOffsetGep(rewriter, loc, base, llvmBaseType, ptrOffset,
                             cstInteriorIndices, fieldIndices, substringOffset);
    }
    dest = insertBaseAddress(rewriter, loc, dest, base);
    if (fir::isDerivedTypeWithLenParams(boxTy))
      TODO(loc, "fir.embox codegen of derived with length parameters" );
    // CUDA device allocations require the descriptor storage itself to be
    // allocated in device memory.
    mlir::Value result = placeInMemoryIfNotGlobalInit(
        rewriter, loc, boxTy, dest,
        isDeviceAllocation(xbox.getMemref(), adaptor.getMemref()));
    rewriter.replaceOp(xbox, result);
    return mlir::success();
  }

  /// Return true if `xbox` has a normalized lower bounds attribute. A box value
  /// that is neither a POINTER nor an ALLOCATABLE should be normalized to a
  /// zero origin lower bound for interoperability with BIND(C).
  inline static bool normalizedLowerBound(fir::cg::XEmboxOp xbox) {
    return xbox->hasAttr(fir::getNormalizedLowerBoundAttrName());
  }
};
2042 | |
/// Create a new box given a box reference.
/// Lowers fir.cg.x_rebox to LLVM by building a new descriptor from the input
/// descriptor, applying any slice, subcomponent, substring, shift, or reshape
/// carried by the operation.
struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
  using EmboxCommonConversion::EmboxCommonConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Location loc = rebox.getLoc();
    mlir::Type idxTy = lowerTy().indexType();
    mlir::Value loweredBox = adaptor.getOperands()[0];
    mlir::ValueRange operands = adaptor.getOperands();

    // Inside a fir.global, the input box was produced as an llvm.struct<>
    // because objects cannot be handled in memory inside a fir.global body that
    // must be constant foldable. However, the type translation are not
    // contextual, so the fir.box<T> type of the operation that produced the
    // fir.box was translated to an llvm.ptr<llvm.struct<>> and the MLIR pass
    // manager inserted a builtin.unrealized_conversion_cast that was inserted
    // and needs to be removed here.
    if (isInGlobalOp(rewriter))
      if (auto unrealizedCast =
              loweredBox.getDefiningOp<mlir::UnrealizedConversionCastOp>())
        loweredBox = unrealizedCast.getInputs()[0];

    TypePair inputBoxTyPair = getBoxTypePair(rebox.getBox().getType());

    // Create new descriptor and fill its non-shape related data.
    llvm::SmallVector<mlir::Value, 2> lenParams;
    mlir::Type inputEleTy = getInputEleTy(rebox);
    if (auto charTy = mlir::dyn_cast<fir::CharacterType>(inputEleTy)) {
      if (charTy.hasConstantLen()) {
        mlir::Value len =
            genConstantIndex(loc, idxTy, rewriter, charTy.getLen());
        lenParams.emplace_back(len);
      } else {
        // Dynamic length: recover it from the input descriptor element size.
        mlir::Value len = getElementSizeFromBox(loc, idxTy, inputBoxTyPair,
                                                loweredBox, rewriter);
        if (charTy.getFKind() != 1) {
          // The element size is in bytes; divide by the character width to
          // get the length in characters for non-default character kinds.
          assert(!isInGlobalOp(rewriter) &&
                 "character target in global op must have constant length");
          mlir::Value width =
              genConstantIndex(loc, idxTy, rewriter, charTy.getFKind());
          len = rewriter.create<mlir::LLVM::SDivOp>(loc, idxTy, len, width);
        }
        lenParams.emplace_back(len);
      }
    } else if (auto recTy = mlir::dyn_cast<fir::RecordType>(inputEleTy)) {
      if (recTy.getNumLenParams() != 0)
        TODO(loc, "reboxing descriptor of derived type with length parameters");
    }

    // Rebox on polymorphic entities needs to carry over the dynamic type.
    mlir::Value typeDescAddr;
    if (mlir::isa<fir::ClassType>(inputBoxTyPair.fir) &&
        mlir::isa<fir::ClassType>(rebox.getType()))
      typeDescAddr =
          loadTypeDescAddress(loc, inputBoxTyPair, loweredBox, rewriter);

    auto [boxTy, dest, eleSize] =
        consDescriptorPrefix(rebox, loweredBox, rewriter, rebox.getOutRank(),
                             adaptor.getSubstr(), lenParams, typeDescAddr);

    // Read input extents, strides, and base address
    llvm::SmallVector<mlir::Value> inputExtents;
    llvm::SmallVector<mlir::Value> inputStrides;
    const unsigned inputRank = rebox.getRank();
    for (unsigned dim = 0; dim < inputRank; ++dim) {
      // dimInfo is {lower bound, extent, stride}; only extent and stride are
      // needed here.
      llvm::SmallVector<mlir::Value, 3> dimInfo =
          getDimsFromBox(loc, {idxTy, idxTy, idxTy}, inputBoxTyPair, loweredBox,
                         dim, rewriter);
      inputExtents.emplace_back(dimInfo[1]);
      inputStrides.emplace_back(dimInfo[2]);
    }

    mlir::Value baseAddr =
        getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter);

    if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
      return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
                      inputStrides, operands, rewriter);
    return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
                      inputStrides, operands, rewriter);
  }

private:
  /// Write resulting shape and base address in descriptor, and replace rebox
  /// op.
  llvm::LogicalResult
  finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
                mlir::Type destBoxTy, mlir::Value dest, mlir::Value base,
                mlir::ValueRange lbounds, mlir::ValueRange extents,
                mlir::ValueRange strides,
                mlir::ConversionPatternRewriter &rewriter) const {
    mlir::Location loc = rebox.getLoc();
    mlir::Value zero =
        genConstantIndex(loc, lowerTy().indexType(), rewriter, 0);
    mlir::Value one = genConstantIndex(loc, lowerTy().indexType(), rewriter, 1);
    for (auto iter : llvm::enumerate(llvm::zip(extents, strides))) {
      mlir::Value extent = std::get<0>(iter.value());
      unsigned dim = iter.index();
      mlir::Value lb = one;
      if (!lbounds.empty()) {
        lb = integerCast(loc, rewriter, lowerTy().indexType(), lbounds[dim]);
        // Use a lower bound of 1 for a dimension whose extent is zero.
        auto extentIsEmpty = rewriter.create<mlir::LLVM::ICmpOp>(
            loc, mlir::LLVM::ICmpPredicate::eq, extent, zero);
        lb = rewriter.create<mlir::LLVM::SelectOp>(loc, extentIsEmpty, one, lb);
      };
      dest = insertLowerBound(rewriter, loc, dest, dim, lb);
      dest = insertExtent(rewriter, loc, dest, dim, extent);
      dest = insertStride(rewriter, loc, dest, dim, std::get<1>(iter.value()));
    }
    dest = insertBaseAddress(rewriter, loc, dest, base);
    mlir::Value result = placeInMemoryIfNotGlobalInit(
        rewriter, rebox.getLoc(), destBoxTy, dest,
        isDeviceAllocation(rebox.getBox(), adaptor.getBox()));
    rewriter.replaceOp(rebox, result);
    return mlir::success();
  }

  // Apply slice given the base address, extents and strides of the input box.
  llvm::LogicalResult
  sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
           mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
           mlir::ValueRange inputStrides, mlir::ValueRange operands,
           mlir::ConversionPatternRewriter &rewriter) const {
    mlir::Location loc = rebox.getLoc();
    mlir::Type byteTy = ::getI8Type(rebox.getContext());
    mlir::Type idxTy = lowerTy().indexType();
    mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
    // Apply subcomponent and substring shift on base address.
    if (!rebox.getSubcomponent().empty() || !rebox.getSubstr().empty()) {
      // Cast to inputEleTy* so that a GEP can be used.
      mlir::Type inputEleTy = getInputEleTy(rebox);
      mlir::Type llvmBaseObjectType = convertType(inputEleTy);
      llvm::SmallVector<mlir::Value> fieldIndices;
      std::optional<mlir::Value> substringOffset;
      if (!rebox.getSubcomponent().empty())
        getSubcomponentIndices(rebox, rebox.getBox(), operands, fieldIndices);
      if (!rebox.getSubstr().empty())
        substringOffset = operands[rebox.getSubstrOperandIndex()];
      base = genBoxOffsetGep(rewriter, loc, base, llvmBaseObjectType, zero,
                             /*cstInteriorIndices=*/std::nullopt, fieldIndices,
                             substringOffset);
    }

    if (rebox.getSlice().empty())
      // The array section is of the form array[%component][substring], keep
      // the input array extents and strides.
      return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
                           /*lbounds*/ std::nullopt, inputExtents, inputStrides,
                           rewriter);

    // The slice is of the form array(i:j:k)[%component]. Compute new extents
    // and strides.
    llvm::SmallVector<mlir::Value> slicedExtents;
    llvm::SmallVector<mlir::Value> slicedStrides;
    mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
    const bool sliceHasOrigins = !rebox.getShift().empty();
    // Slice operands come in triplets (lb, ub, step) per dimension; shift
    // operands come one per dimension.
    unsigned sliceOps = rebox.getSliceOperandIndex();
    unsigned shiftOps = rebox.getShiftOperandIndex();
    auto strideOps = inputStrides.begin();
    const unsigned inputRank = inputStrides.size();
    for (unsigned i = 0; i < inputRank;
         ++i, ++strideOps, ++shiftOps, sliceOps += 3) {
      mlir::Value sliceLb =
          integerCast(loc, rewriter, idxTy, operands[sliceOps]);
      mlir::Value inputStride = *strideOps; // already idxTy
      // Apply origin shift: base += (lb-shift)*input_stride
      mlir::Value sliceOrigin =
          sliceHasOrigins
              ? integerCast(loc, rewriter, idxTy, operands[shiftOps])
              : one;
      mlir::Value diff =
          rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, sliceOrigin);
      mlir::Value offset =
          rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, inputStride);
      // Strides from the fir.box are in bytes.
      base = genGEP(loc, byteTy, rewriter, base, offset);
      // Apply upper bound and step if this is a triplet. Otherwise, the
      // dimension is dropped and no extents/strides are computed.
      mlir::Value upper = operands[sliceOps + 1];
      const bool isTripletSlice =
          !mlir::isa_and_nonnull<mlir::LLVM::UndefOp>(upper.getDefiningOp());
      if (isTripletSlice) {
        mlir::Value step =
            integerCast(loc, rewriter, idxTy, operands[sliceOps + 2]);
        // extent = ub-lb+step/step
        mlir::Value sliceUb = integerCast(loc, rewriter, idxTy, upper);
        mlir::Value extent = computeTripletExtent(rewriter, loc, sliceLb,
                                                  sliceUb, step, zero, idxTy);
        slicedExtents.emplace_back(extent);
        // stride = step*input_stride
        mlir::Value stride =
            rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, step, inputStride);
        slicedStrides.emplace_back(stride);
      }
    }
    return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
                         /*lbounds*/ std::nullopt, slicedExtents, slicedStrides,
                         rewriter);
  }

  /// Apply a new shape to the data described by a box given the base address,
  /// extents and strides of the box.
  llvm::LogicalResult
  reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
             mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
             mlir::ValueRange inputStrides, mlir::ValueRange operands,
             mlir::ConversionPatternRewriter &rewriter) const {
    mlir::ValueRange reboxShifts{
        operands.begin() + rebox.getShiftOperandIndex(),
        operands.begin() + rebox.getShiftOperandIndex() +
            rebox.getShift().size()};
    if (rebox.getShape().empty()) {
      // Only setting new lower bounds.
      return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
                           inputExtents, inputStrides, rewriter);
    }

    mlir::Location loc = rebox.getLoc();

    llvm::SmallVector<mlir::Value> newStrides;
    llvm::SmallVector<mlir::Value> newExtents;
    mlir::Type idxTy = lowerTy().indexType();
    // First stride from input box is kept. The rest is assumed contiguous
    // (it is not possible to reshape otherwise). If the input is scalar,
    // which may be OK if all new extents are ones, the stride does not
    // matter, use one.
    mlir::Value stride = inputStrides.empty()
                             ? genConstantIndex(loc, idxTy, rewriter, 1)
                             : inputStrides[0];
    for (unsigned i = 0; i < rebox.getShape().size(); ++i) {
      mlir::Value rawExtent = operands[rebox.getShapeOperandIndex() + i];
      mlir::Value extent = integerCast(loc, rewriter, idxTy, rawExtent);
      newExtents.emplace_back(extent);
      newStrides.emplace_back(stride);
      // nextStride = extent * stride;
      stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
    }
    return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
                         newExtents, newStrides, rewriter);
  }

  /// Return scalar element type of the input box.
  static mlir::Type getInputEleTy(fir::cg::XReboxOp rebox) {
    auto ty = fir::dyn_cast_ptrOrBoxEleTy(rebox.getBox().getType());
    if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
      return seqTy.getEleTy();
    return ty;
  }
};
2294 | |
2295 | /// Lower `fir.emboxproc` operation. Creates a procedure box. |
2296 | /// TODO: Part of supporting Fortran 2003 procedure pointers. |
2297 | struct EmboxProcOpConversion : public fir::FIROpConversion<fir::EmboxProcOp> { |
2298 | using FIROpConversion::FIROpConversion; |
2299 | |
2300 | llvm::LogicalResult |
2301 | matchAndRewrite(fir::EmboxProcOp emboxproc, OpAdaptor adaptor, |
2302 | mlir::ConversionPatternRewriter &rewriter) const override { |
2303 | TODO(emboxproc.getLoc(), "fir.emboxproc codegen" ); |
2304 | return mlir::failure(); |
2305 | } |
2306 | }; |
2307 | |
2308 | // Code shared between insert_value and extract_value Ops. |
2309 | struct ValueOpCommon { |
2310 | // Translate the arguments pertaining to any multidimensional array to |
2311 | // row-major order for LLVM-IR. |
2312 | static void toRowMajor(llvm::SmallVectorImpl<int64_t> &indices, |
2313 | mlir::Type ty) { |
2314 | assert(ty && "type is null" ); |
2315 | const auto end = indices.size(); |
2316 | for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) { |
2317 | if (auto seq = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(ty)) { |
2318 | const auto dim = getDimension(seq); |
2319 | if (dim > 1) { |
2320 | auto ub = std::min(i + dim, end); |
2321 | std::reverse(indices.begin() + i, indices.begin() + ub); |
2322 | i += dim - 1; |
2323 | } |
2324 | ty = getArrayElementType(seq); |
2325 | } else if (auto st = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(ty)) { |
2326 | ty = st.getBody()[indices[i]]; |
2327 | } else { |
2328 | llvm_unreachable("index into invalid type" ); |
2329 | } |
2330 | } |
2331 | } |
2332 | |
2333 | static llvm::SmallVector<int64_t> |
2334 | collectIndices(mlir::ConversionPatternRewriter &rewriter, |
2335 | mlir::ArrayAttr arrAttr) { |
2336 | llvm::SmallVector<int64_t> indices; |
2337 | for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) { |
2338 | if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(*i)) { |
2339 | indices.push_back(Elt: intAttr.getInt()); |
2340 | } else { |
2341 | auto fieldName = mlir::cast<mlir::StringAttr>(*i).getValue(); |
2342 | ++i; |
2343 | auto ty = mlir::cast<mlir::TypeAttr>(*i).getValue(); |
2344 | auto index = mlir::cast<fir::RecordType>(ty).getFieldIndex(fieldName); |
2345 | indices.push_back(Elt: index); |
2346 | } |
2347 | } |
2348 | return indices; |
2349 | } |
2350 | |
2351 | private: |
2352 | static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) { |
2353 | auto eleTy = ty.getElementType(); |
2354 | while (auto arrTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(eleTy)) |
2355 | eleTy = arrTy.getElementType(); |
2356 | return eleTy; |
2357 | } |
2358 | }; |
2359 | |
2360 | namespace { |
2361 | /// Extract a subobject value from an ssa-value of aggregate type |
2362 | struct |
2363 | : public fir::FIROpAndTypeConversion<fir::ExtractValueOp>, |
2364 | public ValueOpCommon { |
2365 | using FIROpAndTypeConversion::FIROpAndTypeConversion; |
2366 | |
2367 | llvm::LogicalResult |
2368 | (fir::ExtractValueOp , mlir::Type ty, OpAdaptor adaptor, |
2369 | mlir::ConversionPatternRewriter &rewriter) const override { |
2370 | mlir::ValueRange operands = adaptor.getOperands(); |
2371 | auto indices = collectIndices(rewriter, extractVal.getCoor()); |
2372 | toRowMajor(indices, operands[0].getType()); |
2373 | rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>( |
2374 | extractVal, operands[0], indices); |
2375 | return mlir::success(); |
2376 | } |
2377 | }; |
2378 | |
2379 | /// InsertValue is the generalized instruction for the composition of new |
2380 | /// aggregate type values. |
2381 | struct InsertValueOpConversion |
2382 | : public mlir::OpConversionPattern<fir::InsertValueOp>, |
2383 | public ValueOpCommon { |
2384 | using OpConversionPattern::OpConversionPattern; |
2385 | |
2386 | llvm::LogicalResult |
2387 | matchAndRewrite(fir::InsertValueOp insertVal, OpAdaptor adaptor, |
2388 | mlir::ConversionPatternRewriter &rewriter) const override { |
2389 | mlir::ValueRange operands = adaptor.getOperands(); |
2390 | auto indices = collectIndices(rewriter, insertVal.getCoor()); |
2391 | toRowMajor(indices, operands[0].getType()); |
2392 | rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( |
2393 | insertVal, operands[0], operands[1], indices); |
2394 | return mlir::success(); |
2395 | } |
2396 | }; |
2397 | |
2398 | /// InsertOnRange inserts a value into a sequence over a range of offsets. |
2399 | struct InsertOnRangeOpConversion |
2400 | : public fir::FIROpAndTypeConversion<fir::InsertOnRangeOp> { |
2401 | using FIROpAndTypeConversion::FIROpAndTypeConversion; |
2402 | |
2403 | // Increments an array of subscripts in a row major fasion. |
2404 | void incrementSubscripts(llvm::ArrayRef<int64_t> dims, |
2405 | llvm::SmallVectorImpl<int64_t> &subscripts) const { |
2406 | for (size_t i = dims.size(); i > 0; --i) { |
2407 | if (++subscripts[i - 1] < dims[i - 1]) { |
2408 | return; |
2409 | } |
2410 | subscripts[i - 1] = 0; |
2411 | } |
2412 | } |
2413 | |
2414 | llvm::LogicalResult |
2415 | doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, |
2416 | mlir::ConversionPatternRewriter &rewriter) const override { |
2417 | |
2418 | auto arrayType = adaptor.getSeq().getType(); |
2419 | |
2420 | // Iteratively extract the array dimensions from the type. |
2421 | llvm::SmallVector<std::int64_t> dims; |
2422 | mlir::Type type = arrayType; |
2423 | while (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) { |
2424 | dims.push_back(Elt: t.getNumElements()); |
2425 | type = t.getElementType(); |
2426 | } |
2427 | |
2428 | // Avoid generating long insert chain that are very slow to fold back |
2429 | // (which is required in globals when later generating LLVM IR). Attempt to |
2430 | // fold the inserted element value to an attribute and build an ArrayAttr |
2431 | // for the resulting array. |
2432 | if (range.isFullRange()) { |
2433 | llvm::FailureOr<mlir::Attribute> cst = |
2434 | fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter); |
2435 | if (llvm::succeeded(cst)) { |
2436 | mlir::Attribute dimVal = *cst; |
2437 | for (auto dim : llvm::reverse(C&: dims)) { |
2438 | // Use std::vector in case the number of elements is big. |
2439 | std::vector<mlir::Attribute> elements(dim, dimVal); |
2440 | dimVal = mlir::ArrayAttr::get(range.getContext(), elements); |
2441 | } |
2442 | // Replace insert chain with constant. |
2443 | rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(range, arrayType, |
2444 | dimVal); |
2445 | return mlir::success(); |
2446 | } |
2447 | } |
2448 | |
2449 | // The inserted value cannot be folded to an attribute, turn the |
2450 | // insert_range into an llvm.insertvalue chain. |
2451 | llvm::SmallVector<std::int64_t> lBounds; |
2452 | llvm::SmallVector<std::int64_t> uBounds; |
2453 | |
2454 | // Unzip the upper and lower bound and convert to a row major format. |
2455 | mlir::DenseIntElementsAttr coor = range.getCoor(); |
2456 | auto reversedCoor = llvm::reverse(coor.getValues<int64_t>()); |
2457 | for (auto i = reversedCoor.begin(), e = reversedCoor.end(); i != e; ++i) { |
2458 | uBounds.push_back(Elt: *i++); |
2459 | lBounds.push_back(Elt: *i); |
2460 | } |
2461 | |
2462 | auto &subscripts = lBounds; |
2463 | auto loc = range.getLoc(); |
2464 | mlir::Value lastOp = adaptor.getSeq(); |
2465 | mlir::Value insertVal = adaptor.getVal(); |
2466 | |
2467 | while (subscripts != uBounds) { |
2468 | lastOp = rewriter.create<mlir::LLVM::InsertValueOp>( |
2469 | loc, lastOp, insertVal, subscripts); |
2470 | |
2471 | incrementSubscripts(dims, subscripts); |
2472 | } |
2473 | |
2474 | rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( |
2475 | range, lastOp, insertVal, subscripts); |
2476 | |
2477 | return mlir::success(); |
2478 | } |
2479 | }; |
2480 | } // namespace |
2481 | |
2482 | namespace { |
/// XArrayCoor is the address arithmetic on a dynamically shaped, sliced,
/// shifted etc. array.
/// (See the static restriction on coordinate_of.) array_coor determines the
/// coordinate (location) of a specific element.
struct XArrayCoorOpConversion
    : public fir::FIROpAndTypeConversion<fir::cg::XArrayCoorOp> {
  using FIROpAndTypeConversion::FIROpAndTypeConversion;

  llvm::LogicalResult
  doRewrite(fir::cg::XArrayCoorOp coor, mlir::Type llvmPtrTy, OpAdaptor adaptor,
            mlir::ConversionPatternRewriter &rewriter) const override {
    auto loc = coor.getLoc();
    mlir::ValueRange operands = adaptor.getOperands();
    unsigned rank = coor.getRank();
    assert(coor.getIndices().size() == rank);
    assert(coor.getShape().empty() || coor.getShape().size() == rank);
    assert(coor.getShift().empty() || coor.getShift().size() == rank);
    assert(coor.getSlice().empty() || coor.getSlice().size() == 3 * rank);
    mlir::Type idxTy = lowerTy().indexType();
    // Running offsets of the operation's variadic operand groups within
    // `operands`; each is advanced as the per-dimension loop below runs.
    unsigned indexOffset = coor.getIndicesOperandIndex();
    unsigned shapeOffset = coor.getShapeOperandIndex();
    unsigned shiftOffset = coor.getShiftOperandIndex();
    unsigned sliceOffset = coor.getSliceOperandIndex();
    auto sliceOps = coor.getSlice().begin();
    mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
    mlir::Value prevExt = one;
    mlir::Value offset = genConstantIndex(loc, idxTy, rewriter, 0);
    const bool isShifted = !coor.getShift().empty();
    const bool isSliced = !coor.getSlice().empty();
    const bool baseIsBoxed =
        mlir::isa<fir::BaseBoxType>(coor.getMemref().getType());
    TypePair baseBoxTyPair =
        baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
    // Index arithmetic below is tagged nsw (no signed wrap).
    mlir::LLVM::IntegerOverflowFlags nsw =
        mlir::LLVM::IntegerOverflowFlags::nsw;

    // For each dimension of the array, generate the offset calculation.
    for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
                  ++shiftOffset, sliceOffset += 3, sliceOps += 3) {
      mlir::Value index =
          integerCast(loc, rewriter, idxTy, operands[indexOffset]);
      // Lower bound defaults to 1 when no shift operand is present.
      mlir::Value lb =
          isShifted ? integerCast(loc, rewriter, idxTy, operands[shiftOffset])
                    : one;
      mlir::Value step = one;
      bool normalSlice = isSliced;
      // Compute zero based index in dimension i of the element, applying
      // potential triplets and lower bounds.
      if (isSliced) {
        // An undef upper bound marks a dropped (non-triplet) dimension; the
        // step is only meaningful for a real triplet.
        mlir::Value originalUb = *(sliceOps + 1);
        normalSlice =
            !mlir::isa_and_nonnull<fir::UndefOp>(originalUb.getDefiningOp());
        if (normalSlice)
          step = integerCast(loc, rewriter, idxTy, operands[sliceOffset + 2]);
      }
      auto idx = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, index, lb, nsw);
      mlir::Value diff =
          rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, idx, step, nsw);
      if (normalSlice) {
        // Add the (sliceLb - lb) origin adjustment of the triplet.
        mlir::Value sliceLb =
            integerCast(loc, rewriter, idxTy, operands[sliceOffset]);
        auto adj =
            rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, lb, nsw);
        diff = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, diff, adj, nsw);
      }
      // Update the offset given the stride and the zero based index `diff`
      // that was just computed.
      if (baseIsBoxed) {
        // Use stride in bytes from the descriptor.
        mlir::Value stride =
            getStrideFromBox(loc, baseBoxTyPair, operands[0], i, rewriter);
        auto sc =
            rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, stride, nsw);
        offset =
            rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset, nsw);
      } else {
        // Use stride computed at last iteration.
        auto sc =
            rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, prevExt, nsw);
        offset =
            rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset, nsw);
        // Compute next stride assuming contiguity of the base array
        // (in element number).
        auto nextExt = integerCast(loc, rewriter, idxTy, operands[shapeOffset]);
        prevExt = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, prevExt,
                                                     nextExt, nsw);
      }
    }

    // Add computed offset to the base address.
    if (baseIsBoxed) {
      // Working with byte offsets. The base address is read from the fir.box.
      // and used in i8* GEP to do the pointer arithmetic.
      mlir::Type byteTy = ::getI8Type(coor.getContext());
      mlir::Value base =
          getBaseAddrFromBox(loc, baseBoxTyPair, operands[0], rewriter);
      llvm::SmallVector<mlir::LLVM::GEPArg> args{offset};
      auto addr = rewriter.create<mlir::LLVM::GEPOp>(loc, llvmPtrTy, byteTy,
                                                     base, args);
      if (coor.getSubcomponent().empty()) {
        rewriter.replaceOp(coor, addr);
        return mlir::success();
      }
      // Cast the element address from void* to the derived type so that the
      // derived type members can be addresses via a GEP using the index of
      // components.
      mlir::Type elementType =
          getLlvmObjectTypeFromBoxType(coor.getMemref().getType());
      while (auto arrayTy =
                 mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(elementType))
        elementType = arrayTy.getElementType();
      args.clear();
      args.push_back(0);
      if (!coor.getLenParams().empty()) {
        // If type parameters are present, then we don't want to use a GEPOp
        // as below, as the LLVM struct type cannot be statically defined.
        TODO(loc, "derived type with type parameters");
      }
      llvm::SmallVector<mlir::Value> indices = convertSubcomponentIndices(
          loc, elementType,
          operands.slice(coor.getSubcomponentOperandIndex(),
                         coor.getSubcomponent().size()));
      args.append(indices.begin(), indices.end());
      rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(coor, llvmPtrTy,
                                                     elementType, addr, args);
      return mlir::success();
    }

    // The array was not boxed, so it must be contiguous. offset is therefore an
    // element offset and the base type is kept in the GEP unless the element
    // type size is itself dynamic.
    mlir::Type objectTy = fir::unwrapRefType(coor.getMemref().getType());
    mlir::Type eleType = fir::unwrapSequenceType(objectTy);
    mlir::Type gepObjectType = convertType(eleType);
    llvm::SmallVector<mlir::LLVM::GEPArg> args;
    if (coor.getSubcomponent().empty()) {
      // No subcomponent.
      if (!coor.getLenParams().empty()) {
        // Type parameters. Adjust element size explicitly.
        auto eleTy = fir::dyn_cast_ptrEleTy(coor.getType());
        assert(eleTy && "result must be a reference-like type");
        if (fir::characterWithDynamicLen(eleTy)) {
          assert(coor.getLenParams().size() == 1);
          // Scale the element offset by the dynamic character length.
          auto length = integerCast(loc, rewriter, idxTy,
                                    operands[coor.getLenParamsOperandIndex()]);
          offset = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, offset,
                                                      length, nsw);
        } else {
          TODO(loc, "compute size of derived type with type parameters");
        }
      }
      args.push_back(offset);
    } else {
      // There are subcomponents.
      args.push_back(offset);
      llvm::SmallVector<mlir::Value> indices = convertSubcomponentIndices(
          loc, gepObjectType,
          operands.slice(coor.getSubcomponentOperandIndex(),
                         coor.getSubcomponent().size()));
      args.append(indices.begin(), indices.end());
    }
    rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
        coor, llvmPtrTy, gepObjectType, adaptor.getMemref(), args);
    return mlir::success();
  }
};
2649 | } // namespace |
2650 | |
2651 | /// Convert to (memory) reference to a reference to a subobject. |
2652 | /// The coordinate_of op is a Swiss army knife operation that can be used on |
2653 | /// (memory) references to records, arrays, complex, etc. as well as boxes. |
2654 | /// With unboxed arrays, there is the restriction that the array have a static |
2655 | /// shape in all but the last column. |
2656 | struct CoordinateOpConversion |
2657 | : public fir::FIROpAndTypeConversion<fir::CoordinateOp> { |
2658 | using FIROpAndTypeConversion::FIROpAndTypeConversion; |
2659 | |
2660 | llvm::LogicalResult |
2661 | doRewrite(fir::CoordinateOp coor, mlir::Type ty, OpAdaptor adaptor, |
2662 | mlir::ConversionPatternRewriter &rewriter) const override { |
2663 | mlir::ValueRange operands = adaptor.getOperands(); |
2664 | |
2665 | mlir::Location loc = coor.getLoc(); |
2666 | mlir::Value base = operands[0]; |
2667 | mlir::Type baseObjectTy = coor.getBaseType(); |
2668 | mlir::Type objectTy = fir::dyn_cast_ptrOrBoxEleTy(baseObjectTy); |
2669 | assert(objectTy && "fir.coordinate_of expects a reference type" ); |
2670 | mlir::Type llvmObjectTy = convertType(objectTy); |
2671 | |
2672 | // Complex type - basically, extract the real or imaginary part |
2673 | // FIXME: double check why this is done before the fir.box case below. |
2674 | if (fir::isa_complex(objectTy)) { |
2675 | mlir::Value gep = |
2676 | genGEP(loc, llvmObjectTy, rewriter, base, 0, operands[1]); |
2677 | rewriter.replaceOp(coor, gep); |
2678 | return mlir::success(); |
2679 | } |
2680 | |
2681 | // Boxed type - get the base pointer from the box |
2682 | if (mlir::dyn_cast<fir::BaseBoxType>(baseObjectTy)) |
2683 | return doRewriteBox(coor, operands, loc, rewriter); |
2684 | |
2685 | // Reference, pointer or a heap type |
2686 | if (mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType>( |
2687 | baseObjectTy)) |
2688 | return doRewriteRefOrPtr(coor, llvmObjectTy, operands, loc, rewriter); |
2689 | |
2690 | return rewriter.notifyMatchFailure( |
2691 | coor, "fir.coordinate_of base operand has unsupported type" ); |
2692 | } |
2693 | |
2694 | static unsigned getFieldNumber(fir::RecordType ty, mlir::Value op) { |
2695 | return fir::hasDynamicSize(ty) |
2696 | ? op.getDefiningOp() |
2697 | ->getAttrOfType<mlir::IntegerAttr>("field" ) |
2698 | .getInt() |
2699 | : getConstantIntValue(op); |
2700 | } |
2701 | |
  /// Return true if \p type can be indexed into further (array, derived
  /// type, or tuple), i.e. the coordinate path may continue below it.
  static bool hasSubDimensions(mlir::Type type) {
    return mlir::isa<fir::SequenceType, fir::RecordType, mlir::TupleType>(type);
  }
2705 | |
  // Helper structure to analyze the CoordinateOp path and decide if and how
  // the GEP should be generated for it.
  struct ShapeAnalysis {
    // True when no array type traversed on the path has an unknown extent.
    bool hasKnownShape;
    // True when the only dynamic extent found is the last dimension of the
    // first array addressed (or an unknown character length), which GEP
    // generation can still handle without a fir.box.
    bool columnIsDeferred;
  };
2712 | |
2713 | /// Walk the abstract memory layout and determine if the path traverses any |
2714 | /// array types with unknown shape. Return true iff all the array types have a |
2715 | /// constant shape along the path. |
2716 | /// TODO: move the verification logic into the verifier. |
2717 | static std::optional<ShapeAnalysis> |
2718 | arraysHaveKnownShape(mlir::Type type, fir::CoordinateOp coor) { |
2719 | fir::CoordinateIndicesAdaptor indices = coor.getIndices(); |
2720 | auto begin = indices.begin(); |
2721 | bool hasKnownShape = true; |
2722 | bool columnIsDeferred = false; |
2723 | for (auto it = begin, end = indices.end(); it != end;) { |
2724 | if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(type)) { |
2725 | bool addressingStart = (it == begin); |
2726 | unsigned arrayDim = arrTy.getDimension(); |
2727 | for (auto dimExtent : llvm::enumerate(arrTy.getShape())) { |
2728 | if (dimExtent.value() == fir::SequenceType::getUnknownExtent()) { |
2729 | hasKnownShape = false; |
2730 | if (addressingStart && dimExtent.index() + 1 == arrayDim) { |
2731 | // If this point was reached, the raws of the first array have |
2732 | // constant extents. |
2733 | columnIsDeferred = true; |
2734 | } else { |
2735 | // One of the array dimension that is not the column of the first |
2736 | // array has dynamic extent. It will not possible to do |
2737 | // code generation for the CoordinateOp if the base is not a |
2738 | // fir.box containing the value of that extent. |
2739 | return ShapeAnalysis{false, false}; |
2740 | } |
2741 | } |
2742 | // There may be less operands than the array size if the |
2743 | // fir.coordinate_of result is not an element but a sub-array. |
2744 | if (it != end) |
2745 | ++it; |
2746 | } |
2747 | type = arrTy.getEleTy(); |
2748 | continue; |
2749 | } |
2750 | if (auto strTy = mlir::dyn_cast<fir::RecordType>(type)) { |
2751 | auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(*it); |
2752 | if (!intAttr) { |
2753 | mlir::emitError(coor.getLoc(), |
2754 | "expected field name in fir.coordinate_of" ); |
2755 | return std::nullopt; |
2756 | } |
2757 | type = strTy.getType(intAttr.getInt()); |
2758 | } else if (auto strTy = mlir::dyn_cast<mlir::TupleType>(type)) { |
2759 | auto value = llvm::dyn_cast<mlir::Value>(*it); |
2760 | if (!value) { |
2761 | mlir::emitError( |
2762 | coor.getLoc(), |
2763 | "expected constant value to address tuple in fir.coordinate_of" ); |
2764 | return std::nullopt; |
2765 | } |
2766 | type = strTy.getType(getConstantIntValue(value)); |
2767 | } else if (auto charType = mlir::dyn_cast<fir::CharacterType>(type)) { |
2768 | // Addressing character in string. Fortran strings degenerate to arrays |
2769 | // in LLVM, so they are handled like arrays of characters here. |
2770 | if (charType.getLen() == fir::CharacterType::unknownLen()) |
2771 | return ShapeAnalysis{.hasKnownShape: false, .columnIsDeferred: true}; |
2772 | type = fir::CharacterType::getSingleton(charType.getContext(), |
2773 | charType.getFKind()); |
2774 | } |
2775 | ++it; |
2776 | } |
2777 | return ShapeAnalysis{.hasKnownShape: hasKnownShape, .columnIsDeferred: columnIsDeferred}; |
2778 | } |
2779 | |
2780 | private: |
  /// Lower a `fir.coordinate_of` whose base is a `fir.box`. The descriptor
  /// provides the base address and per-dimension byte strides, so array
  /// addressing is done with byte (i8) GEPs computed from descriptor values.
  llvm::LogicalResult
  doRewriteBox(fir::CoordinateOp coor, mlir::ValueRange operands,
               mlir::Location loc,
               mlir::ConversionPatternRewriter &rewriter) const {
    mlir::Type boxObjTy = coor.getBaseType();
    assert(mlir::dyn_cast<fir::BaseBoxType>(boxObjTy) &&
           "This is not a `fir.box`");
    TypePair boxTyPair = getBoxTypePair(boxObjTy);

    mlir::Value boxBaseAddr = operands[0];

    // 1. SPECIAL CASE (uses `fir.len_param_index`):
    //    %box = ... : !fir.box<!fir.type<derived{len1:i32}>>
    //    %lenp = fir.len_param_index len1, !fir.type<derived{len1:i32}>
    //    %addr = coordinate_of %box, %lenp
    if (coor.getNumOperands() == 2) {
      mlir::Operation *coordinateDef =
          (*coor.getCoor().begin()).getDefiningOp();
      if (mlir::isa_and_nonnull<fir::LenParamIndexOp>(coordinateDef))
        TODO(loc,
             "fir.coordinate_of - fir.len_param_index is not supported yet");
    }

    // 2. GENERAL CASE:
    // 2.1. (`fir.array`)
    //   %box = ... : !fir.box<!fir.array<?xU>>
    //   %idx = ... : index
    //   %resultAddr = coordinate_of %box, %idx : !fir.ref<U>
    // 2.2 (`fir.derived`)
    //   %box = ... : !fir.box<!fir.type<derived_type{field_1:i32}>>
    //   %idx = ... : i32
    //   %resultAddr = coordinate_of %box, %idx : !fir.ref<i32>
    // 2.3 (`fir.derived` inside `fir.array`)
    //   %box = ... : !fir.box<!fir.array<10 x !fir.type<derived_1{field_1:f32,
    //   field_2:f32}>>> %idx1 = ... : index %idx2 = ... : i32 %resultAddr =
    //   coordinate_of %box, %idx1, %idx2 : !fir.ref<f32>
    // 2.4. TODO: Either document or disable any other case that the following
    //  implementation might convert.
    mlir::Value resultAddr =
        getBaseAddrFromBox(loc, boxTyPair, boxBaseAddr, rewriter);
    // Component Type
    auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
    mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
    mlir::Type byteTy = ::getI8Type(coor.getContext());
    mlir::LLVM::IntegerOverflowFlags nsw =
        mlir::LLVM::IntegerOverflowFlags::nsw;

    // operands[0] is the box itself; value indices start at operands[1].
    int nextIndexValue = 1;
    fir::CoordinateIndicesAdaptor indices = coor.getIndices();
    for (auto it = indices.begin(), end = indices.end(); it != end;) {
      if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(cpnTy)) {
        if (it != indices.begin())
          TODO(loc, "fir.array nested inside other array and/or derived type");
        // Applies byte strides from the box. Ignore lower bound from box
        // since fir.coordinate_of indexes are zero based. Lowering takes care
        // of lower bound aspects. This both accounts for dynamically sized
        // types and non contiguous arrays.
        auto idxTy = lowerTy().indexType();
        mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0);
        unsigned arrayDim = arrTy.getDimension();
        for (unsigned dim = 0; dim < arrayDim && it != end; ++dim, ++it) {
          mlir::Value stride =
              getStrideFromBox(loc, boxTyPair, operands[0], dim, rewriter);
          // Accumulate this dimension's byte offset: index * stride.
          auto sc = rewriter.create<mlir::LLVM::MulOp>(
              loc, idxTy, operands[nextIndexValue + dim], stride, nsw);
          off = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, off, nsw);
        }
        nextIndexValue += arrayDim;
        // Apply the accumulated byte offset with a single i8 GEP.
        resultAddr = rewriter.create<mlir::LLVM::GEPOp>(
            loc, llvmPtrTy, byteTy, resultAddr,
            llvm::ArrayRef<mlir::LLVM::GEPArg>{off});
        cpnTy = arrTy.getEleTy();
      } else if (auto recTy = mlir::dyn_cast<fir::RecordType>(cpnTy)) {
        // Derived-type component: addressed with a typed struct GEP.
        auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(*it);
        if (!intAttr)
          return mlir::emitError(loc,
                                 "expected field name in fir.coordinate_of");
        int fieldIndex = intAttr.getInt();
        ++it;
        cpnTy = recTy.getType(fieldIndex);
        auto llvmRecTy = lowerTy().convertType(recTy);
        resultAddr = rewriter.create<mlir::LLVM::GEPOp>(
            loc, llvmPtrTy, llvmRecTy, resultAddr,
            llvm::ArrayRef<mlir::LLVM::GEPArg>{0, fieldIndex});
      } else {
        fir::emitFatalError(loc, "unexpected type in coordinate_of");
      }
    }

    rewriter.replaceOp(coor, resultAddr);
    return mlir::success();
  }
2873 | |
  /// Lower a `fir.coordinate_of` whose base is a raw memory reference
  /// (fir.ref/fir.ptr/fir.heap). All array shapes on the path must be known at
  /// compile time (except possibly the column of the first array), allowing a
  /// single typed GEP to be emitted.
  llvm::LogicalResult
  doRewriteRefOrPtr(fir::CoordinateOp coor, mlir::Type llvmObjectTy,
                    mlir::ValueRange operands, mlir::Location loc,
                    mlir::ConversionPatternRewriter &rewriter) const {
    mlir::Type baseObjectTy = coor.getBaseType();

    // Component Type
    mlir::Type cpnTy = fir::dyn_cast_ptrOrBoxEleTy(baseObjectTy);

    const std::optional<ShapeAnalysis> shapeAnalysis =
        arraysHaveKnownShape(cpnTy, coor);
    if (!shapeAnalysis)
      return mlir::failure();

    // Dynamic element sizes would require runtime stride computations that a
    // plain typed GEP cannot express.
    if (fir::hasDynamicSize(fir::unwrapSequenceType(cpnTy)))
      return mlir::emitError(
          loc, "fir.coordinate_of with a dynamic element size is unsupported");

    if (shapeAnalysis->hasKnownShape || shapeAnalysis->columnIsDeferred) {
      llvm::SmallVector<mlir::LLVM::GEPArg> offs;
      if (shapeAnalysis->hasKnownShape) {
        offs.push_back(0);
      }
      // Else, only the column is `?` and we can simply place the column value
      // in the 0-th GEP position.

      // `dims` counts remaining dimensions of the array currently being
      // addressed; `arrIdx` buffers its indices so they can be appended in
      // reverse order (see below).
      std::optional<int> dims;
      llvm::SmallVector<mlir::Value> arrIdx;
      // operands[0] is the base address; value indices start at operands[1].
      int nextIndexValue = 1;
      for (auto index : coor.getIndices()) {
        if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(index)) {
          // Addressing derived type component.
          auto recordType = llvm::dyn_cast<fir::RecordType>(cpnTy);
          if (!recordType)
            return mlir::emitError(
                loc,
                "fir.coordinate base type is not consistent with operands");
          int fieldId = intAttr.getInt();
          cpnTy = recordType.getType(fieldId);
          offs.push_back(fieldId);
          continue;
        }
        // Value index (addressing array, tuple, or complex part).
        mlir::Value indexValue = operands[nextIndexValue++];
        if (auto tupTy = mlir::dyn_cast<mlir::TupleType>(cpnTy)) {
          cpnTy = tupTy.getType(getConstantIntValue(indexValue));
          offs.push_back(indexValue);
        } else {
          if (!dims) {
            if (auto arrayType = llvm::dyn_cast<fir::SequenceType>(cpnTy)) {
              // Starting addressing array or array component.
              dims = arrayType.getDimension();
              cpnTy = arrayType.getElementType();
            }
          }
          if (dims) {
            arrIdx.push_back(indexValue);
            if (--(*dims) == 0) {
              // Append array range in reverse (FIR arrays are column-major).
              offs.append(arrIdx.rbegin(), arrIdx.rend());
              arrIdx.clear();
              dims.reset();
            }
          } else {
            offs.push_back(indexValue);
          }
        }
      }
      // It is possible the fir.coordinate_of result is a sub-array, in which
      // case there may be some "unfinished" array indices to reverse and push.
      if (!arrIdx.empty())
        offs.append(arrIdx.rbegin(), arrIdx.rend());

      mlir::Value base = operands[0];
      mlir::Value retval = genGEP(loc, llvmObjectTy, rewriter, base, offs);
      rewriter.replaceOp(coor, retval);
      return mlir::success();
    }

    return mlir::emitError(
        loc, "fir.coordinate_of base operand has unsupported type");
  }
2956 | }; |
2957 | |
/// Convert `fir.field_index`. The conversion depends on whether the size of
/// the record is static or dynamic.
struct FieldIndexOpConversion : public fir::FIROpConversion<fir::FieldIndexOp> {
  using FIROpConversion::FIROpConversion;

  // NB: most field references should be resolved by this point
  llvm::LogicalResult
  matchAndRewrite(fir::FieldIndexOp field, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto recTy = mlir::cast<fir::RecordType>(field.getOnType());
    unsigned index = recTy.getFieldIndex(field.getFieldId());

    if (!fir::hasDynamicSize(recTy)) {
      // Derived type has compile-time constant layout. Return index of the
      // component type in the parent type (to be used in GEP).
      rewriter.replaceOp(field, mlir::ValueRange{genConstantOffset(
                                    field.getLoc(), rewriter, index)});
      return mlir::success();
    }

    // Derived type does NOT have a compile-time constant layout (it has
    // dynamic size). Call the compiler generated function to determine the
    // byte offset of the field at runtime. This returns a non-constant.
    mlir::FlatSymbolRefAttr symAttr = mlir::SymbolRefAttr::get(
        field.getContext(), getOffsetMethodName(recTy, field.getFieldId()));
    mlir::NamedAttribute callAttr = rewriter.getNamedAttr("callee", symAttr);
    mlir::NamedAttribute fieldAttr = rewriter.getNamedAttr(
        "field", mlir::IntegerAttr::get(lowerTy().indexType(), index));
    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
        field, lowerTy().offsetType(), adaptor.getOperands(),
        addLLVMOpBundleAttrs(rewriter, {callAttr, fieldAttr},
                             adaptor.getOperands().size()));
    return mlir::success();
  }

  // Re-construct the name of the compiler generated method that calculates
  // the byte offset of `field` within an instance of record type `recTy`.
  inline static std::string getOffsetMethodName(fir::RecordType recTy,
                                                llvm::StringRef field) {
    return recTy.getName().str() + "P." + field.str() + ".offset";
  }
};
3000 | |
3001 | /// Convert `fir.end` |
3002 | struct FirEndOpConversion : public fir::FIROpConversion<fir::FirEndOp> { |
3003 | using FIROpConversion::FIROpConversion; |
3004 | |
3005 | llvm::LogicalResult |
3006 | matchAndRewrite(fir::FirEndOp firEnd, OpAdaptor, |
3007 | mlir::ConversionPatternRewriter &rewriter) const override { |
3008 | TODO(firEnd.getLoc(), "fir.end codegen" ); |
3009 | return mlir::failure(); |
3010 | } |
3011 | }; |
3012 | |
3013 | /// Lower `fir.type_desc` to a global addr. |
3014 | struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> { |
3015 | using FIROpConversion::FIROpConversion; |
3016 | |
3017 | llvm::LogicalResult |
3018 | matchAndRewrite(fir::TypeDescOp typeDescOp, OpAdaptor adaptor, |
3019 | mlir::ConversionPatternRewriter &rewriter) const override { |
3020 | mlir::Type inTy = typeDescOp.getInType(); |
3021 | assert(mlir::isa<fir::RecordType>(inTy) && "expecting fir.type" ); |
3022 | auto recordType = mlir::dyn_cast<fir::RecordType>(inTy); |
3023 | auto module = typeDescOp.getOperation()->getParentOfType<mlir::ModuleOp>(); |
3024 | std::string typeDescName = |
3025 | this->options.typeDescriptorsRenamedForAssembly |
3026 | ? fir::NameUniquer::getTypeDescriptorAssemblyName( |
3027 | recordType.getName()) |
3028 | : fir::NameUniquer::getTypeDescriptorName(recordType.getName()); |
3029 | auto llvmPtrTy = ::getLlvmPtrType(typeDescOp.getContext()); |
3030 | if (auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) { |
3031 | rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( |
3032 | typeDescOp, llvmPtrTy, global.getSymName()); |
3033 | return mlir::success(); |
3034 | } else if (auto global = module.lookupSymbol<fir::GlobalOp>(typeDescName)) { |
3035 | rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>( |
3036 | typeDescOp, llvmPtrTy, global.getSymName()); |
3037 | return mlir::success(); |
3038 | } |
3039 | return mlir::failure(); |
3040 | } |
3041 | }; |
3042 | |
3043 | /// Lower `fir.has_value` operation to `llvm.return` operation. |
3044 | struct HasValueOpConversion |
3045 | : public mlir::OpConversionPattern<fir::HasValueOp> { |
3046 | using OpConversionPattern::OpConversionPattern; |
3047 | |
3048 | llvm::LogicalResult |
3049 | matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor, |
3050 | mlir::ConversionPatternRewriter &rewriter) const override { |
3051 | rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, |
3052 | adaptor.getOperands()); |
3053 | return mlir::success(); |
3054 | } |
3055 | }; |
3056 | |
3057 | #ifndef NDEBUG |
3058 | // Check if attr's type is compatible with ty. |
3059 | // |
3060 | // This is done by comparing attr's element type, converted to LLVM type, |
3061 | // with ty's element type. |
3062 | // |
3063 | // Only integer and floating point (including complex) attributes are |
3064 | // supported. Also, attr is expected to have a TensorType and ty is expected |
3065 | // to be of LLVMArrayType. If any of the previous conditions is false, then |
3066 | // the specified attr and ty are not supported by this function and are |
3067 | // assumed to be compatible. |
3068 | static inline bool attributeTypeIsCompatible(mlir::MLIRContext *ctx, |
3069 | mlir::Attribute attr, |
3070 | mlir::Type ty) { |
3071 | // Get attr's LLVM element type. |
3072 | if (!attr) |
3073 | return true; |
3074 | auto intOrFpEleAttr = mlir::dyn_cast<mlir::DenseIntOrFPElementsAttr>(attr); |
3075 | if (!intOrFpEleAttr) |
3076 | return true; |
3077 | auto tensorTy = mlir::dyn_cast<mlir::TensorType>(intOrFpEleAttr.getType()); |
3078 | if (!tensorTy) |
3079 | return true; |
3080 | mlir::Type attrEleTy = |
3081 | mlir::LLVMTypeConverter(ctx).convertType(tensorTy.getElementType()); |
3082 | |
3083 | // Get ty's element type. |
3084 | auto arrTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(ty); |
3085 | if (!arrTy) |
3086 | return true; |
3087 | mlir::Type eleTy = arrTy.getElementType(); |
3088 | while ((arrTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(eleTy))) |
3089 | eleTy = arrTy.getElementType(); |
3090 | |
3091 | return attrEleTy == eleTy; |
3092 | } |
3093 | #endif |
3094 | |
/// Lower `fir.global` operation to `llvm.global` operation.
/// `fir.insert_on_range` operations are replaced with constant dense attribute
/// if they are applied on the full range.
struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    llvm::SmallVector<mlir::Attribute> dbgExprs;

    // Collect any debug-info global variable expressions attached to the
    // fused location's metadata so they carry over to the LLVM global.
    if (auto fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(global.getLoc())) {
      if (auto gvExprAttr = mlir::dyn_cast_if_present<mlir::ArrayAttr>(
              fusedLoc.getMetadata())) {
        for (auto attr : gvExprAttr.getAsRange<mlir::Attribute>())
          if (auto dbgAttr =
                  mlir::dyn_cast<mlir::LLVM::DIGlobalVariableExpressionAttr>(
                      attr))
            dbgExprs.push_back(dbgAttr);
      }
    }

    auto tyAttr = convertType(global.getType());
    // Box (descriptor) globals are lowered in their in-memory struct form.
    if (auto boxType = mlir::dyn_cast<fir::BaseBoxType>(global.getType()))
      tyAttr = this->lowerTy().convertBoxTypeAsStruct(boxType);
    auto loc = global.getLoc();
    mlir::Attribute initAttr = global.getInitVal().value_or(mlir::Attribute());
    assert(attributeTypeIsCompatible(global.getContext(), initAttr, tyAttr));
    auto linkage = convertLinkage(global.getLinkName());
    auto isConst = global.getConstant().has_value();
    mlir::SymbolRefAttr comdat;
    llvm::ArrayRef<mlir::NamedAttribute> attrs;
    auto g = rewriter.create<mlir::LLVM::GlobalOp>(
        loc, tyAttr, isConst, linkage, global.getSymName(), initAttr, 0, 0,
        false, false, comdat, attrs, dbgExprs);

    if (global.getAlignment() && *global.getAlignment() > 0)
      g.setAlignment(*global.getAlignment());

    auto module = global->getParentOfType<mlir::ModuleOp>();
    auto gpuMod = global->getParentOfType<mlir::gpu::GPUModuleOp>();
    // Add comdat if necessary (linkonce globals on COMDAT-capable targets,
    // outside of GPU modules).
    if (fir::getTargetTriple(module).supportsCOMDAT() &&
        (linkage == mlir::LLVM::Linkage::Linkonce ||
         linkage == mlir::LLVM::Linkage::LinkonceODR) &&
        !gpuMod) {
      addComdat(g, rewriter, module);
    }

    // Apply all non-Fir::GlobalOp attributes to the LLVM::GlobalOp, preserving
    // them; whilst taking care not to apply attributes that are lowered in
    // other ways.
    llvm::SmallDenseSet<llvm::StringRef> elidedAttrsSet(
        global.getAttributeNames().begin(), global.getAttributeNames().end());
    for (auto &attr : global->getAttrs())
      if (!elidedAttrsSet.contains(attr.getName().strref()))
        g->setAttr(attr.getName(), attr.getValue());

    auto &gr = g.getInitializerRegion();
    rewriter.inlineRegionBefore(global.getRegion(), gr, gr.end());
    if (!gr.empty()) {
      // Replace insert_on_range with a constant dense attribute if the
      // initialization is on the full range.
      auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
      for (auto insertOp : insertOnRangeOps) {
        if (insertOp.isFullRange()) {
          auto seqTyAttr = convertType(insertOp.getType());
          auto *op = insertOp.getVal().getDefiningOp();
          auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
          if (!constant) {
            // The constant may be hidden behind a fir.convert; look through
            // it to reach the underlying arith.constant.
            auto convertOp = mlir::dyn_cast<fir::ConvertOp>(op);
            if (!convertOp)
              continue;
            constant = mlir::cast<mlir::arith::ConstantOp>(
                convertOp.getValue().getDefiningOp());
          }
          mlir::Type vecType = mlir::VectorType::get(
              insertOp.getType().getShape(), constant.getType());
          auto denseAttr = mlir::DenseElementsAttr::get(
              mlir::cast<mlir::ShapedType>(vecType), constant.getValue());
          rewriter.setInsertionPointAfter(insertOp);
          rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
              insertOp, seqTyAttr, denseAttr);
        }
      }
    }

    // CUDA Fortran shared globals are placed in the NVVM shared memory
    // address space.
    if (global.getDataAttr() &&
        *global.getDataAttr() == cuf::DataAttribute::Shared)
      g.setAddrSpace(mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace);

    rewriter.eraseOp(global);
    return mlir::success();
  }

  // TODO: String comparisons should be avoided. Replace linkName with an
  // enumeration.
  // Map the textual FIR linkage name onto the LLVM dialect linkage enum;
  // absent or unrecognized names default to external linkage.
  mlir::LLVM::Linkage
  convertLinkage(std::optional<llvm::StringRef> optLinkage) const {
    if (optLinkage) {
      auto name = *optLinkage;
      if (name == "internal")
        return mlir::LLVM::Linkage::Internal;
      if (name == "linkonce")
        return mlir::LLVM::Linkage::Linkonce;
      if (name == "linkonce_odr")
        return mlir::LLVM::Linkage::LinkonceODR;
      if (name == "common")
        return mlir::LLVM::Linkage::Common;
      if (name == "weak")
        return mlir::LLVM::Linkage::Weak;
    }
    return mlir::LLVM::Linkage::External;
  }

private:
  // Attach `global` to the module-level "__llvm_comdat" operation (created on
  // first use) via an `any` selector named after the global, so the linker can
  // merge duplicate linkonce definitions.
  static void addComdat(mlir::LLVM::GlobalOp &global,
                        mlir::ConversionPatternRewriter &rewriter,
                        mlir::ModuleOp module) {
    const char *comdatName = "__llvm_comdat";
    mlir::LLVM::ComdatOp comdatOp =
        module.lookupSymbol<mlir::LLVM::ComdatOp>(comdatName);
    if (!comdatOp) {
      comdatOp =
          rewriter.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
    }
    // Nothing to do if a selector for this global already exists.
    if (auto select = comdatOp.lookupSymbol<mlir::LLVM::ComdatSelectorOp>(
            global.getSymName()))
      return;
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(&comdatOp.getBody().back());
    auto selectorOp = rewriter.create<mlir::LLVM::ComdatSelectorOp>(
        comdatOp.getLoc(), global.getSymName(),
        mlir::LLVM::comdat::Comdat::Any);
    global.setComdatAttr(mlir::SymbolRefAttr::get(
        rewriter.getContext(), comdatName,
        mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())));
  }
};
3235 | |
/// `fir.load` --> `llvm.load`
struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {

    mlir::Type llvmLoadTy = convertObjectType(load.getType());
    // Propagate volatility of the memory reference onto the generated
    // load/memcpy.
    const bool isVolatile = fir::isa_volatile_type(load.getMemref().getType());
    if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
      // fir.box is a special case because it is considered an ssa value in
      // fir, but it is lowered as a pointer to a descriptor. So
      // fir.ref<fir.box> and fir.box end up being the same llvm types and
      // loading a fir.ref<fir.box> is implemented as taking a snapshot of the
      // descriptor value into a new descriptor temp.
      auto inputBoxStorage = adaptor.getOperands()[0];
      mlir::Value newBoxStorage;
      mlir::Location loc = load.getLoc();
      if (auto callOp = mlir::dyn_cast_or_null<mlir::LLVM::CallOp>(
              inputBoxStorage.getDefiningOp())) {
        if (callOp.getCallee() &&
            ((*callOp.getCallee())
                 .starts_with(RTNAME_STRING(CUFAllocDescriptor)) ||
             (*callOp.getCallee()).starts_with("__tgt_acc_get_deviceptr"))) {
          // CUDA Fortran local descriptor are allocated in managed memory. So
          // new storage must be allocated the same way.
          auto mod = load->getParentOfType<mlir::ModuleOp>();
          newBoxStorage =
              genCUFAllocDescriptor(loc, rewriter, mod, boxTy, lowerTy());
        }
      }
      // Default: take the snapshot in a stack temporary.
      if (!newBoxStorage)
        newBoxStorage = genAllocaAndAddrCastWithType(loc, llvmLoadTy,
                                                     defaultAlign, rewriter);

      TypePair boxTypePair{boxTy, llvmLoadTy};
      // The descriptor size depends on its rank/addendum; compute it from the
      // input box before copying.
      mlir::Value boxSize =
          computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
      auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
          loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);

      // Preserve alias-analysis (TBAA) metadata on the generated copy.
      if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
        memcpy.setTBAATags(*optionalTag);
      else
        attachTBAATag(memcpy, boxTy, boxTy, nullptr);
      rewriter.replaceOp(load, newBoxStorage);
    } else {
      // Ordinary (non-box) load: lower directly to llvm.load.
      mlir::LLVM::LoadOp loadOp = rewriter.create<mlir::LLVM::LoadOp>(
          load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
      loadOp.setVolatile_(isVolatile);
      if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
        loadOp.setTBAATags(*optionalTag);
      else
        attachTBAATag(loadOp, load.getType(), load.getType(), nullptr);
      rewriter.replaceOp(load, loadOp.getResult());
    }
    return mlir::success();
  }
};
3296 | |
3297 | /// Lower `fir.no_reassoc` to LLVM IR dialect. |
3298 | /// TODO: how do we want to enforce this in LLVM-IR? Can we manipulate the fast |
3299 | /// math flags? |
3300 | struct NoReassocOpConversion : public fir::FIROpConversion<fir::NoReassocOp> { |
3301 | using FIROpConversion::FIROpConversion; |
3302 | |
3303 | llvm::LogicalResult |
3304 | matchAndRewrite(fir::NoReassocOp noreassoc, OpAdaptor adaptor, |
3305 | mlir::ConversionPatternRewriter &rewriter) const override { |
3306 | rewriter.replaceOp(noreassoc, adaptor.getOperands()[0]); |
3307 | return mlir::success(); |
3308 | } |
3309 | }; |
3310 | |
3311 | static void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, |
3312 | std::optional<mlir::ValueRange> destOps, |
3313 | mlir::ConversionPatternRewriter &rewriter, |
3314 | mlir::Block *newBlock) { |
3315 | if (destOps) |
3316 | rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, dest, *destOps, newBlock, |
3317 | mlir::ValueRange()); |
3318 | else |
3319 | rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp, dest, newBlock); |
3320 | } |
3321 | |
3322 | template <typename A, typename B> |
3323 | static void genBrOp(A caseOp, mlir::Block *dest, std::optional<B> destOps, |
3324 | mlir::ConversionPatternRewriter &rewriter) { |
3325 | if (destOps) |
3326 | rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, *destOps, dest); |
3327 | else |
3328 | rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(caseOp, std::nullopt, dest); |
3329 | } |
3330 | |
3331 | static void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, |
3332 | mlir::Block *dest, |
3333 | std::optional<mlir::ValueRange> destOps, |
3334 | mlir::ConversionPatternRewriter &rewriter) { |
3335 | auto *thisBlock = rewriter.getInsertionBlock(); |
3336 | auto *newBlock = createBlock(rewriter, dest); |
3337 | rewriter.setInsertionPointToEnd(thisBlock); |
3338 | genCondBrOp(loc, cmp, dest, destOps, rewriter, newBlock); |
3339 | rewriter.setInsertionPointToEnd(newBlock); |
3340 | } |
3341 | |
3342 | /// Conversion of `fir.select_case` |
3343 | /// |
3344 | /// The `fir.select_case` operation is converted to a if-then-else ladder. |
3345 | /// Depending on the case condition type, one or several comparison and |
3346 | /// conditional branching can be generated. |
3347 | /// |
/// A point value case such as `case(4)`, a lower bound case such as
/// `case(5:)`, or an upper bound case such as `case(:3)` is converted to a
/// simple comparison between the selector value and the constant value in the
/// case. The block associated with the case condition is then executed if
/// the comparison succeeds; otherwise it branches to the next block with the
/// comparison for the next case condition.
///
/// A closed interval case condition such as `case(7:10)` is converted with a
/// first comparison and conditional branching for the lower bound. If
/// successful, it branches to a second block with the comparison for the
/// upper bound in the same case condition.
3359 | /// |
3360 | /// TODO: lowering of CHARACTER type cases is not handled yet. |
struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::SelectCaseOp caseOp, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    unsigned conds = caseOp.getNumConditions();
    llvm::ArrayRef<mlir::Attribute> cases = caseOp.getCases().getValue();
    // Type can be CHARACTER, INTEGER, or LOGICAL (C1145)
    auto ty = caseOp.getSelector().getType();
    if (mlir::isa<fir::CharacterType>(ty)) {
      TODO(caseOp.getLoc(), "fir.select_case codegen with character type");
      return mlir::failure();
    }
    mlir::Value selector = caseOp.getSelector(adaptor.getOperands());
    auto loc = caseOp.getLoc();
    // Emit one ladder step per case condition, in order.
    for (unsigned t = 0; t != conds; ++t) {
      mlir::Block *dest = caseOp.getSuccessor(t);
      std::optional<mlir::ValueRange> destOps =
          caseOp.getSuccessorOperands(adaptor.getOperands(), t);
      std::optional<mlir::ValueRange> cmpOps =
          *caseOp.getCompareOperands(adaptor.getOperands(), t);
      mlir::Attribute attr = cases[t];
      assert(mlir::isa<mlir::UnitAttr>(attr) || cmpOps.has_value());
      // case(v): selector == v
      if (mlir::isa<fir::PointIntervalAttr>(attr)) {
        auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
            loc, mlir::LLVM::ICmpPredicate::eq, selector, cmpOps->front());
        genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
        continue;
      }
      // case(lb:): lb <= selector
      if (mlir::isa<fir::LowerBoundAttr>(attr)) {
        auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
            loc, mlir::LLVM::ICmpPredicate::sle, cmpOps->front(), selector);
        genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
        continue;
      }
      // case(:ub): selector <= ub
      if (mlir::isa<fir::UpperBoundAttr>(attr)) {
        auto cmp = rewriter.create<mlir::LLVM::ICmpOp>(
            loc, mlir::LLVM::ICmpPredicate::sle, selector, cmpOps->front());
        genCaseLadderStep(loc, cmp, dest, destOps, rewriter);
        continue;
      }
      // case(lb:ub): two chained comparisons; the upper-bound check is only
      // reached when the lower-bound check holds.
      if (mlir::isa<fir::ClosedIntervalAttr>(attr)) {
        mlir::Value caseArg0 = *cmpOps->begin();
        auto cmp0 = rewriter.create<mlir::LLVM::ICmpOp>(
            loc, mlir::LLVM::ICmpPredicate::sle, caseArg0, selector);
        auto *thisBlock = rewriter.getInsertionBlock();
        auto *newBlock1 = createBlock(rewriter, dest);
        auto *newBlock2 = createBlock(rewriter, dest);
        rewriter.setInsertionPointToEnd(thisBlock);
        rewriter.create<mlir::LLVM::CondBrOp>(loc, cmp0, newBlock1, newBlock2);
        rewriter.setInsertionPointToEnd(newBlock1);
        mlir::Value caseArg1 = *(cmpOps->begin() + 1);
        auto cmp1 = rewriter.create<mlir::LLVM::ICmpOp>(
            loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg1);
        genCondBrOp(loc, cmp1, dest, destOps, rewriter, newBlock2);
        rewriter.setInsertionPointToEnd(newBlock2);
        continue;
      }
      // case default: unconditional branch, required to be the last condition.
      assert(mlir::isa<mlir::UnitAttr>(attr));
      assert((t + 1 == conds) && "unit must be last");
      genBrOp(caseOp, dest, destOps, rewriter);
    }
    return mlir::success();
  }
};
3427 | |
3428 | /// Helper function for converting select ops. This function converts the |
3429 | /// signature of the given block. If the new block signature is different from |
3430 | /// `expectedTypes`, returns "failure". |
3431 | static llvm::FailureOr<mlir::Block *> |
3432 | getConvertedBlock(mlir::ConversionPatternRewriter &rewriter, |
3433 | const mlir::TypeConverter *converter, |
3434 | mlir::Operation *branchOp, mlir::Block *block, |
3435 | mlir::TypeRange expectedTypes) { |
3436 | assert(converter && "expected non-null type converter" ); |
3437 | assert(!block->isEntryBlock() && "entry blocks have no predecessors" ); |
3438 | |
3439 | // There is nothing to do if the types already match. |
3440 | if (block->getArgumentTypes() == expectedTypes) |
3441 | return block; |
3442 | |
3443 | // Compute the new block argument types and convert the block. |
3444 | std::optional<mlir::TypeConverter::SignatureConversion> conversion = |
3445 | converter->convertBlockSignature(block); |
3446 | if (!conversion) |
3447 | return rewriter.notifyMatchFailure(branchOp, |
3448 | "could not compute block signature" ); |
3449 | if (expectedTypes != conversion->getConvertedTypes()) |
3450 | return rewriter.notifyMatchFailure( |
3451 | branchOp, |
3452 | "mismatch between adaptor operand types and computed block signature" ); |
3453 | return rewriter.applySignatureConversion(block, *conversion, converter); |
3454 | } |
3455 | |
3456 | template <typename OP> |
3457 | static llvm::LogicalResult |
3458 | selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select, |
3459 | typename OP::Adaptor adaptor, |
3460 | mlir::ConversionPatternRewriter &rewriter, |
3461 | const mlir::TypeConverter *converter) { |
3462 | unsigned conds = select.getNumConditions(); |
3463 | auto cases = select.getCases().getValue(); |
3464 | mlir::Value selector = adaptor.getSelector(); |
3465 | auto loc = select.getLoc(); |
3466 | assert(conds > 0 && "select must have cases" ); |
3467 | |
3468 | llvm::SmallVector<mlir::Block *> destinations; |
3469 | llvm::SmallVector<mlir::ValueRange> destinationsOperands; |
3470 | mlir::Block *defaultDestination; |
3471 | mlir::ValueRange defaultOperands; |
3472 | llvm::SmallVector<int32_t> caseValues; |
3473 | |
3474 | for (unsigned t = 0; t != conds; ++t) { |
3475 | mlir::Block *dest = select.getSuccessor(t); |
3476 | auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t); |
3477 | const mlir::Attribute &attr = cases[t]; |
3478 | if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) { |
3479 | destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{}); |
3480 | auto convertedBlock = |
3481 | getConvertedBlock(rewriter, converter, select, dest, |
3482 | mlir::TypeRange(destinationsOperands.back())); |
3483 | if (mlir::failed(convertedBlock)) |
3484 | return mlir::failure(); |
3485 | destinations.push_back(*convertedBlock); |
3486 | caseValues.push_back(Elt: intAttr.getInt()); |
3487 | continue; |
3488 | } |
3489 | assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)); |
3490 | assert((t + 1 == conds) && "unit must be last" ); |
3491 | defaultOperands = destOps ? *destOps : mlir::ValueRange{}; |
3492 | auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest, |
3493 | mlir::TypeRange(defaultOperands)); |
3494 | if (mlir::failed(convertedBlock)) |
3495 | return mlir::failure(); |
3496 | defaultDestination = *convertedBlock; |
3497 | } |
3498 | |
3499 | // LLVM::SwitchOp takes a i32 type for the selector. |
3500 | if (select.getSelector().getType() != rewriter.getI32Type()) |
3501 | selector = rewriter.create<mlir::LLVM::TruncOp>(loc, rewriter.getI32Type(), |
3502 | selector); |
3503 | |
3504 | rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( |
3505 | select, selector, |
3506 | /*defaultDestination=*/defaultDestination, |
3507 | /*defaultOperands=*/defaultOperands, |
3508 | /*caseValues=*/caseValues, |
3509 | /*caseDestinations=*/destinations, |
3510 | /*caseOperands=*/destinationsOperands, |
3511 | /*branchWeights=*/llvm::ArrayRef<std::int32_t>()); |
3512 | return mlir::success(); |
3513 | } |
3514 | |
3515 | /// conversion of fir::SelectOp to an if-then-else ladder |
3516 | struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> { |
3517 | using FIROpConversion::FIROpConversion; |
3518 | |
3519 | llvm::LogicalResult |
3520 | matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor, |
3521 | mlir::ConversionPatternRewriter &rewriter) const override { |
3522 | return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, |
3523 | rewriter, getTypeConverter()); |
3524 | } |
3525 | }; |
3526 | |
3527 | /// conversion of fir::SelectRankOp to an if-then-else ladder |
3528 | struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> { |
3529 | using FIROpConversion::FIROpConversion; |
3530 | |
3531 | llvm::LogicalResult |
3532 | matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor, |
3533 | mlir::ConversionPatternRewriter &rewriter) const override { |
3534 | return selectMatchAndRewrite<fir::SelectRankOp>( |
3535 | lowerTy(), op, adaptor, rewriter, getTypeConverter()); |
3536 | } |
3537 | }; |
3538 | |
3539 | /// Lower `fir.select_type` to LLVM IR dialect. |
3540 | struct SelectTypeOpConversion : public fir::FIROpConversion<fir::SelectTypeOp> { |
3541 | using FIROpConversion::FIROpConversion; |
3542 | |
3543 | llvm::LogicalResult |
3544 | matchAndRewrite(fir::SelectTypeOp select, OpAdaptor adaptor, |
3545 | mlir::ConversionPatternRewriter &rewriter) const override { |
3546 | mlir::emitError(select.getLoc(), |
3547 | "fir.select_type should have already been converted" ); |
3548 | return mlir::failure(); |
3549 | } |
3550 | }; |
3551 | |
/// `fir.store` --> `llvm.store`
struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::StoreOp store, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Location loc = store.getLoc();
    mlir::Type storeTy = store.getValue().getType();
    mlir::Value llvmValue = adaptor.getValue();
    mlir::Value llvmMemref = adaptor.getMemref();
    mlir::LLVM::AliasAnalysisOpInterface newOp;
    // The store is volatile if either the stored value's type or the target
    // memory reference's type is marked volatile in FIR.
    const bool isVolatile =
        fir::isa_volatile_type(store.getMemref().getType()) ||
        fir::isa_volatile_type(store.getValue().getType());
    if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
      // Descriptor (fir.box) stores copy the whole descriptor struct.
      mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
      // Always use memcpy because LLVM is not as effective at optimizing
      // aggregate loads/stores as it is optimizing memcpy.
      TypePair boxTypePair{boxTy, llvmBoxTy};
      mlir::Value boxSize =
          computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
      newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
                                                    boxSize, isVolatile);
    } else {
      // Non-box values lower to a plain llvm.store.
      mlir::LLVM::StoreOp storeOp =
          rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);

      // Forward volatility and non-temporal hints to the LLVM store.
      if (isVolatile)
        storeOp.setVolatile_(true);

      if (store.getNontemporal())
        storeOp.setNontemporal(true);

      newOp = storeOp;
    }
    // Attach TBAA metadata: either the tags already present on the FIR op,
    // or a tag derived from the stored type.
    if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
      newOp.setTBAATags(*optionalTag);
    else
      attachTBAATag(newOp, storeTy, storeTy, nullptr);
    rewriter.eraseOp(store);
    return mlir::success();
  }
};
3596 | |
3597 | /// `fir.copy` --> `llvm.memcpy` or `llvm.memmove` |
3598 | struct CopyOpConversion : public fir::FIROpConversion<fir::CopyOp> { |
3599 | using FIROpConversion::FIROpConversion; |
3600 | |
3601 | llvm::LogicalResult |
3602 | matchAndRewrite(fir::CopyOp copy, OpAdaptor adaptor, |
3603 | mlir::ConversionPatternRewriter &rewriter) const override { |
3604 | mlir::Location loc = copy.getLoc(); |
3605 | const bool isVolatile = |
3606 | fir::isa_volatile_type(copy.getSource().getType()) || |
3607 | fir::isa_volatile_type(copy.getDestination().getType()); |
3608 | mlir::Value llvmSource = adaptor.getSource(); |
3609 | mlir::Value llvmDestination = adaptor.getDestination(); |
3610 | mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64); |
3611 | mlir::Type copyTy = fir::unwrapRefType(copy.getSource().getType()); |
3612 | mlir::Value copySize = genTypeStrideInBytes( |
3613 | loc, i64Ty, rewriter, convertType(copyTy), getDataLayout()); |
3614 | |
3615 | mlir::LLVM::AliasAnalysisOpInterface newOp; |
3616 | if (copy.getNoOverlap()) |
3617 | newOp = rewriter.create<mlir::LLVM::MemcpyOp>( |
3618 | loc, llvmDestination, llvmSource, copySize, isVolatile); |
3619 | else |
3620 | newOp = rewriter.create<mlir::LLVM::MemmoveOp>( |
3621 | loc, llvmDestination, llvmSource, copySize, isVolatile); |
3622 | |
3623 | // TODO: propagate TBAA once FirAliasTagOpInterface added to CopyOp. |
3624 | attachTBAATag(newOp, copyTy, copyTy, nullptr); |
3625 | rewriter.eraseOp(copy); |
3626 | return mlir::success(); |
3627 | } |
3628 | }; |
3629 | |
3630 | namespace { |
3631 | |
3632 | /// Convert `fir.unboxchar` into two `llvm.extractvalue` instructions. One for |
3633 | /// the character buffer and one for the buffer length. |
3634 | struct UnboxCharOpConversion : public fir::FIROpConversion<fir::UnboxCharOp> { |
3635 | using FIROpConversion::FIROpConversion; |
3636 | |
3637 | llvm::LogicalResult |
3638 | matchAndRewrite(fir::UnboxCharOp unboxchar, OpAdaptor adaptor, |
3639 | mlir::ConversionPatternRewriter &rewriter) const override { |
3640 | mlir::Type lenTy = convertType(unboxchar.getType(1)); |
3641 | mlir::Value tuple = adaptor.getOperands()[0]; |
3642 | |
3643 | mlir::Location loc = unboxchar.getLoc(); |
3644 | mlir::Value ptrToBuffer = |
3645 | rewriter.create<mlir::LLVM::ExtractValueOp>(loc, tuple, 0); |
3646 | |
3647 | auto len = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, tuple, 1); |
3648 | mlir::Value lenAfterCast = integerCast(loc, rewriter, lenTy, len); |
3649 | |
3650 | rewriter.replaceOp(unboxchar, |
3651 | llvm::ArrayRef<mlir::Value>{ptrToBuffer, lenAfterCast}); |
3652 | return mlir::success(); |
3653 | } |
3654 | }; |
3655 | |
/// Lower `fir.unboxproc` operation. Unbox a procedure box value, yielding its
/// components.
/// TODO: Part of supporting Fortran 2003 procedure pointers.
struct UnboxProcOpConversion : public fir::FIROpConversion<fir::UnboxProcOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::UnboxProcOp unboxproc, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Codegen not implemented yet; report via the TODO diagnostic macro.
    TODO(unboxproc.getLoc(), "fir.unboxproc codegen" );
    return mlir::failure();
  }
};
3669 | |
3670 | /// convert to LLVM IR dialect `undef` |
3671 | struct UndefOpConversion : public fir::FIROpConversion<fir::UndefOp> { |
3672 | using FIROpConversion::FIROpConversion; |
3673 | |
3674 | llvm::LogicalResult |
3675 | matchAndRewrite(fir::UndefOp undef, OpAdaptor, |
3676 | mlir::ConversionPatternRewriter &rewriter) const override { |
3677 | if (mlir::isa<fir::DummyScopeType>(undef.getType())) { |
3678 | // Dummy scoping is used for Fortran analyses like AA. Once it gets to |
3679 | // pre-codegen rewrite it is erased and a fir.undef is created to |
3680 | // feed to the fir declare operation. Thus, during codegen, we can |
3681 | // simply erase is as it is no longer used. |
3682 | rewriter.eraseOp(undef); |
3683 | return mlir::success(); |
3684 | } |
3685 | rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>( |
3686 | undef, convertType(undef.getType())); |
3687 | return mlir::success(); |
3688 | } |
3689 | }; |
3690 | |
3691 | struct ZeroOpConversion : public fir::FIROpConversion<fir::ZeroOp> { |
3692 | using FIROpConversion::FIROpConversion; |
3693 | |
3694 | llvm::LogicalResult |
3695 | matchAndRewrite(fir::ZeroOp zero, OpAdaptor, |
3696 | mlir::ConversionPatternRewriter &rewriter) const override { |
3697 | mlir::Type ty = convertType(zero.getType()); |
3698 | rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>(zero, ty); |
3699 | return mlir::success(); |
3700 | } |
3701 | }; |
3702 | |
3703 | /// `fir.unreachable` --> `llvm.unreachable` |
3704 | struct UnreachableOpConversion |
3705 | : public fir::FIROpConversion<fir::UnreachableOp> { |
3706 | using FIROpConversion::FIROpConversion; |
3707 | |
3708 | llvm::LogicalResult |
3709 | matchAndRewrite(fir::UnreachableOp unreach, OpAdaptor adaptor, |
3710 | mlir::ConversionPatternRewriter &rewriter) const override { |
3711 | rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(unreach); |
3712 | return mlir::success(); |
3713 | } |
3714 | }; |
3715 | |
/// `fir.is_present` -->
/// ```
/// %0 = llvm.mlir.constant(0 : i64)
/// %1 = llvm.ptrtoint %arg : !llvm.ptr to i64
/// %2 = llvm.icmp "ne" %1, %0 : i64
/// ```
struct IsPresentOpConversion : public fir::FIROpConversion<fir::IsPresentOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::IsPresentOp isPresent, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type idxTy = lowerTy().indexType();
    mlir::Location loc = isPresent.getLoc();
    auto ptr = adaptor.getOperands()[0];

    // A boxchar lowers to a struct; extract field 0 (the buffer pointer)
    // and test that instead of the aggregate.
    if (mlir::isa<fir::BoxCharType>(isPresent.getVal().getType())) {
      [[maybe_unused]] auto structTy =
          mlir::cast<mlir::LLVM::LLVMStructType>(ptr.getType());
      assert(!structTy.isOpaque() && !structTy.getBody().empty());

      ptr = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ptr, 0);
    }
    // The argument is present iff its address is non-null.
    mlir::LLVM::ConstantOp c0 =
        genConstantIndex(isPresent.getLoc(), idxTy, rewriter, 0);
    auto addr = rewriter.create<mlir::LLVM::PtrToIntOp>(loc, idxTy, ptr);
    rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
        isPresent, mlir::LLVM::ICmpPredicate::ne, addr, c0);

    return mlir::success();
  }
};
3748 | |
3749 | /// Create value signaling an absent optional argument in a call, e.g. |
3750 | /// `fir.absent !fir.ref<i64>` --> `llvm.mlir.zero : !llvm.ptr<i64>` |
3751 | struct AbsentOpConversion : public fir::FIROpConversion<fir::AbsentOp> { |
3752 | using FIROpConversion::FIROpConversion; |
3753 | |
3754 | llvm::LogicalResult |
3755 | matchAndRewrite(fir::AbsentOp absent, OpAdaptor, |
3756 | mlir::ConversionPatternRewriter &rewriter) const override { |
3757 | mlir::Type ty = convertType(absent.getType()); |
3758 | rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>(absent, ty); |
3759 | return mlir::success(); |
3760 | } |
3761 | }; |
3762 | |
3763 | // |
3764 | // Primitive operations on Complex types |
3765 | // |
3766 | |
3767 | template <typename OPTY> |
3768 | static inline mlir::LLVM::FastmathFlagsAttr getLLVMFMFAttr(OPTY op) { |
3769 | return mlir::LLVM::FastmathFlagsAttr::get( |
3770 | op.getContext(), |
3771 | mlir::arith::convertArithFastMathFlagsToLLVM(op.getFastmath())); |
3772 | } |
3773 | |
3774 | /// Generate inline code for complex addition/subtraction |
3775 | template <typename LLVMOP, typename OPTY> |
3776 | static mlir::LLVM::InsertValueOp |
3777 | complexSum(OPTY sumop, mlir::ValueRange opnds, |
3778 | mlir::ConversionPatternRewriter &rewriter, |
3779 | const fir::LLVMTypeConverter &lowering) { |
3780 | mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(sumop); |
3781 | mlir::Value a = opnds[0]; |
3782 | mlir::Value b = opnds[1]; |
3783 | auto loc = sumop.getLoc(); |
3784 | mlir::Type eleTy = lowering.convertType(getComplexEleTy(sumop.getType())); |
3785 | mlir::Type ty = lowering.convertType(sumop.getType()); |
3786 | auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0); |
3787 | auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1); |
3788 | auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0); |
3789 | auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1); |
3790 | auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1, fmf); |
3791 | auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1, fmf); |
3792 | auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty); |
3793 | auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0); |
3794 | return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1); |
3795 | } |
3796 | } // namespace |
3797 | |
3798 | namespace { |
3799 | struct AddcOpConversion : public fir::FIROpConversion<fir::AddcOp> { |
3800 | using FIROpConversion::FIROpConversion; |
3801 | |
3802 | llvm::LogicalResult |
3803 | matchAndRewrite(fir::AddcOp addc, OpAdaptor adaptor, |
3804 | mlir::ConversionPatternRewriter &rewriter) const override { |
3805 | // given: (x + iy) + (x' + iy') |
3806 | // result: (x + x') + i(y + y') |
3807 | auto r = complexSum<mlir::LLVM::FAddOp>(addc, adaptor.getOperands(), |
3808 | rewriter, lowerTy()); |
3809 | rewriter.replaceOp(addc, r.getResult()); |
3810 | return mlir::success(); |
3811 | } |
3812 | }; |
3813 | |
3814 | struct SubcOpConversion : public fir::FIROpConversion<fir::SubcOp> { |
3815 | using FIROpConversion::FIROpConversion; |
3816 | |
3817 | llvm::LogicalResult |
3818 | matchAndRewrite(fir::SubcOp subc, OpAdaptor adaptor, |
3819 | mlir::ConversionPatternRewriter &rewriter) const override { |
3820 | // given: (x + iy) - (x' + iy') |
3821 | // result: (x - x') + i(y - y') |
3822 | auto r = complexSum<mlir::LLVM::FSubOp>(subc, adaptor.getOperands(), |
3823 | rewriter, lowerTy()); |
3824 | rewriter.replaceOp(subc, r.getResult()); |
3825 | return mlir::success(); |
3826 | } |
3827 | }; |
3828 | |
/// Inlined complex multiply
struct MulcOpConversion : public fir::FIROpConversion<fir::MulcOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::MulcOp mulc, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // TODO: Can we use a call to __muldc3 ?
    // given: (x + iy) * (x' + iy')
    // result: (xx'-yy')+i(xy'+yx')
    mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(mulc);
    mlir::Value a = adaptor.getOperands()[0];
    mlir::Value b = adaptor.getOperands()[1];
    auto loc = mulc.getLoc();
    mlir::Type eleTy = convertType(getComplexEleTy(mulc.getType()));
    mlir::Type ty = convertType(mulc.getType());
    // Unpack real (x) and imaginary (y) parts of both operands.
    auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
    auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
    auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
    auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1, fmf);
    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1, fmf);
    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1, fmf);
    // Imaginary part: xy' + yx'.
    auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx, fmf);
    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1, fmf);
    // Real part: xx' - yy'.
    auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy, fmf);
    // Pack the two parts into the result aggregate.
    auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
    auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
    auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
    rewriter.replaceOp(mulc, r0.getResult());
    return mlir::success();
  }
};
3862 | |
/// Inlined complex division
struct DivcOpConversion : public fir::FIROpConversion<fir::DivcOp> {
  using FIROpConversion::FIROpConversion;

  llvm::LogicalResult
  matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // TODO: Can we use a call to __divdc3 instead?
    // Just generate inline code for now.
    // given: (x + iy) / (x' + iy')
    // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
    mlir::LLVM::FastmathFlagsAttr fmf = getLLVMFMFAttr(divc);
    mlir::Value a = adaptor.getOperands()[0];
    mlir::Value b = adaptor.getOperands()[1];
    auto loc = divc.getLoc();
    mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
    mlir::Type ty = convertType(divc.getType());
    // Unpack real (x) and imaginary (y) parts of both operands.
    auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
    auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
    auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
    auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1, fmf);
    auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1, fmf);
    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1, fmf);
    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1, fmf);
    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1, fmf);
    auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1, fmf);
    // Denominator d = x'x' + y'y'.
    auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1, fmf);
    // Numerators: real = xx' + yy', imaginary = yx' - xy'.
    auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy, fmf);
    auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy, fmf);
    auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d, fmf);
    auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d, fmf);
    // Pack the two parts into the result aggregate.
    auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
    auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
    auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
    rewriter.replaceOp(divc, r0.getResult());
    return mlir::success();
  }
};
3902 | |
3903 | /// Inlined complex negation |
3904 | struct NegcOpConversion : public fir::FIROpConversion<fir::NegcOp> { |
3905 | using FIROpConversion::FIROpConversion; |
3906 | |
3907 | llvm::LogicalResult |
3908 | matchAndRewrite(fir::NegcOp neg, OpAdaptor adaptor, |
3909 | mlir::ConversionPatternRewriter &rewriter) const override { |
3910 | // given: -(x + iy) |
3911 | // result: -x - iy |
3912 | auto eleTy = convertType(getComplexEleTy(neg.getType())); |
3913 | auto loc = neg.getLoc(); |
3914 | mlir::Value o0 = adaptor.getOperands()[0]; |
3915 | auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, o0, 0); |
3916 | auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, o0, 1); |
3917 | auto nrp = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, rp); |
3918 | auto nip = rewriter.create<mlir::LLVM::FNegOp>(loc, eleTy, ip); |
3919 | auto r = rewriter.create<mlir::LLVM::InsertValueOp>(loc, o0, nrp, 0); |
3920 | rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(neg, r, nip, 1); |
3921 | return mlir::success(); |
3922 | } |
3923 | }; |
3924 | |
3925 | struct BoxOffsetOpConversion : public fir::FIROpConversion<fir::BoxOffsetOp> { |
3926 | using FIROpConversion::FIROpConversion; |
3927 | |
3928 | llvm::LogicalResult |
3929 | matchAndRewrite(fir::BoxOffsetOp boxOffset, OpAdaptor adaptor, |
3930 | mlir::ConversionPatternRewriter &rewriter) const override { |
3931 | |
3932 | mlir::Type pty = ::getLlvmPtrType(boxOffset.getContext()); |
3933 | mlir::Type boxRefType = fir::unwrapRefType(boxOffset.getBoxRef().getType()); |
3934 | |
3935 | assert((mlir::isa<fir::BaseBoxType>(boxRefType) || |
3936 | mlir::isa<fir::BoxCharType>(boxRefType)) && |
3937 | "boxRef should be a reference to either fir.box or fir.boxchar" ); |
3938 | |
3939 | mlir::Type llvmBoxTy; |
3940 | int fieldId; |
3941 | if (auto boxType = mlir::dyn_cast_or_null<fir::BaseBoxType>(boxRefType)) { |
3942 | llvmBoxTy = lowerTy().convertBoxTypeAsStruct( |
3943 | mlir::cast<fir::BaseBoxType>(boxType)); |
3944 | fieldId = boxOffset.getField() == fir::BoxFieldAttr::derived_type |
3945 | ? getTypeDescFieldId(boxType) |
3946 | : kAddrPosInBox; |
3947 | } else { |
3948 | auto boxCharType = mlir::cast<fir::BoxCharType>(boxRefType); |
3949 | llvmBoxTy = lowerTy().convertType(boxCharType); |
3950 | fieldId = kAddrPosInBox; |
3951 | } |
3952 | rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
3953 | boxOffset, pty, llvmBoxTy, adaptor.getBoxRef(), |
3954 | llvm::ArrayRef<mlir::LLVM::GEPArg>{0, fieldId}); |
3955 | return mlir::success(); |
3956 | } |
3957 | }; |
3958 | |
/// Conversion pattern for operations that must be dead. The information in
/// these operations is consumed by other operations, so by this point they
/// should no longer have any uses.
/// These operations are normally dead after the pre-codegen pass.
3963 | template <typename FromOp> |
3964 | struct MustBeDeadConversion : public fir::FIROpConversion<FromOp> { |
3965 | explicit MustBeDeadConversion(const fir::LLVMTypeConverter &lowering, |
3966 | const fir::FIRToLLVMPassOptions &options) |
3967 | : fir::FIROpConversion<FromOp>(lowering, options) {} |
3968 | using OpAdaptor = typename FromOp::Adaptor; |
3969 | |
3970 | llvm::LogicalResult |
3971 | matchAndRewrite(FromOp op, OpAdaptor adaptor, |
3972 | mlir::ConversionPatternRewriter &rewriter) const final { |
3973 | if (!op->getUses().empty()) |
3974 | return rewriter.notifyMatchFailure(op, "op must be dead" ); |
3975 | rewriter.eraseOp(op); |
3976 | return mlir::success(); |
3977 | } |
3978 | }; |
3979 | |
// The following four ops carry shape/slice metadata only; they must be dead
// by the time codegen runs (see MustBeDeadConversion above).
struct ShapeOpConversion : public MustBeDeadConversion<fir::ShapeOp> {
  using MustBeDeadConversion::MustBeDeadConversion;
};

struct ShapeShiftOpConversion : public MustBeDeadConversion<fir::ShapeShiftOp> {
  using MustBeDeadConversion::MustBeDeadConversion;
};

struct ShiftOpConversion : public MustBeDeadConversion<fir::ShiftOp> {
  using MustBeDeadConversion::MustBeDeadConversion;
};

struct SliceOpConversion : public MustBeDeadConversion<fir::SliceOp> {
  using MustBeDeadConversion::MustBeDeadConversion;
};
3995 | |
3996 | } // namespace |
3997 | |
3998 | namespace { |
3999 | class RenameMSVCLibmCallees |
4000 | : public mlir::OpRewritePattern<mlir::LLVM::CallOp> { |
4001 | public: |
4002 | using OpRewritePattern::OpRewritePattern; |
4003 | |
4004 | llvm::LogicalResult |
4005 | matchAndRewrite(mlir::LLVM::CallOp op, |
4006 | mlir::PatternRewriter &rewriter) const override { |
4007 | rewriter.startOpModification(op); |
4008 | auto callee = op.getCallee(); |
4009 | if (callee) |
4010 | if (*callee == "hypotf" ) |
4011 | op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf" )); |
4012 | |
4013 | rewriter.finalizeOpModification(op); |
4014 | return mlir::success(); |
4015 | } |
4016 | }; |
4017 | |
4018 | class RenameMSVCLibmFuncs |
4019 | : public mlir::OpRewritePattern<mlir::LLVM::LLVMFuncOp> { |
4020 | public: |
4021 | using OpRewritePattern::OpRewritePattern; |
4022 | |
4023 | llvm::LogicalResult |
4024 | matchAndRewrite(mlir::LLVM::LLVMFuncOp op, |
4025 | mlir::PatternRewriter &rewriter) const override { |
4026 | rewriter.startOpModification(op); |
4027 | if (op.getSymName() == "hypotf" ) |
4028 | op.setSymNameAttr(rewriter.getStringAttr("_hypotf" )); |
4029 | rewriter.finalizeOpModification(op); |
4030 | return mlir::success(); |
4031 | } |
4032 | }; |
4033 | } // namespace |
4034 | |
4035 | namespace { |
4036 | /// Convert FIR dialect to LLVM dialect |
4037 | /// |
4038 | /// This pass lowers all FIR dialect operations to LLVM IR dialect. An |
4039 | /// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. |
4040 | class FIRToLLVMLowering |
4041 | : public fir::impl::FIRToLLVMLoweringBase<FIRToLLVMLowering> { |
4042 | public: |
  FIRToLLVMLowering() = default;
  // Construct the pass with explicit lowering options.
  FIRToLLVMLowering(fir::FIRToLLVMPassOptions options) : options{options} {}
  // The pass runs on the whole module.
  mlir::ModuleOp getModule() { return getOperation(); }
4046 | |
4047 | void runOnOperation() override final { |
4048 | auto mod = getModule(); |
4049 | if (!forcedTargetTriple.empty()) |
4050 | fir::setTargetTriple(mod, forcedTargetTriple); |
4051 | |
4052 | if (!forcedDataLayout.empty()) { |
4053 | llvm::DataLayout dl(forcedDataLayout); |
4054 | fir::support::setMLIRDataLayout(mod, dl); |
4055 | } |
4056 | |
4057 | if (!forcedTargetCPU.empty()) |
4058 | fir::setTargetCPU(mod, forcedTargetCPU); |
4059 | |
4060 | if (!forcedTuneCPU.empty()) |
4061 | fir::setTuneCPU(mod, forcedTuneCPU); |
4062 | |
4063 | if (!forcedTargetFeatures.empty()) |
4064 | fir::setTargetFeatures(mod, forcedTargetFeatures); |
4065 | |
4066 | if (typeDescriptorsRenamedForAssembly) |
4067 | options.typeDescriptorsRenamedForAssembly = |
4068 | typeDescriptorsRenamedForAssembly; |
4069 | |
4070 | // Run dynamic pass pipeline for converting Math dialect |
4071 | // operations into other dialects (llvm, func, etc.). |
4072 | // Some conversions of Math operations cannot be done |
4073 | // by just using conversion patterns. This is true for |
4074 | // conversions that affect the ModuleOp, e.g. create new |
4075 | // function operations in it. We have to run such conversions |
4076 | // as passes here. |
4077 | mlir::OpPassManager mathConvertionPM("builtin.module" ); |
4078 | |
4079 | bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN(); |
4080 | // If compiling for AMD target some math operations must be lowered to AMD |
4081 | // GPU library calls, the rest can be converted to LLVM intrinsics, which |
4082 | // is handled in the mathToLLVM conversion. The lowering to libm calls is |
4083 | // not needed since all math operations are handled this way. |
4084 | if (isAMDGCN) |
4085 | mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); |
4086 | |
4087 | // Convert math::FPowI operations to inline implementation |
4088 | // only if the exponent's width is greater than 32, otherwise, |
4089 | // it will be lowered to LLVM intrinsic operation by a later conversion. |
4090 | mlir::ConvertMathToFuncsOptions mathToFuncsOptions{}; |
4091 | mathToFuncsOptions.minWidthOfFPowIExponent = 33; |
4092 | mathConvertionPM.addPass( |
4093 | mlir::createConvertMathToFuncs(mathToFuncsOptions)); |
4094 | mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass()); |
4095 | // Convert Math dialect operations into LLVM dialect operations. |
4096 | // There is no way to prefer MathToLLVM patterns over MathToLibm |
4097 | // patterns (applied below), so we have to run MathToLLVM conversion here. |
4098 | mathConvertionPM.addNestedPass<mlir::func::FuncOp>( |
4099 | mlir::createConvertMathToLLVMPass()); |
4100 | if (mlir::failed(runPipeline(mathConvertionPM, mod))) |
4101 | return signalPassFailure(); |
4102 | |
4103 | std::optional<mlir::DataLayout> dl = |
4104 | fir::support::getOrSetMLIRDataLayout(mod, /*allowDefaultLayout=*/true); |
4105 | if (!dl) { |
4106 | mlir::emitError(mod.getLoc(), |
4107 | "module operation must carry a data layout attribute " |
4108 | "to generate llvm IR from FIR" ); |
4109 | signalPassFailure(); |
4110 | return; |
4111 | } |
4112 | |
4113 | auto *context = getModule().getContext(); |
4114 | fir::LLVMTypeConverter typeConverter{getModule(), |
4115 | options.applyTBAA || applyTBAA, |
4116 | options.forceUnifiedTBAATree, *dl}; |
4117 | mlir::RewritePatternSet pattern(context); |
4118 | fir::populateFIRToLLVMConversionPatterns(typeConverter, pattern, options); |
4119 | mlir::populateFuncToLLVMConversionPatterns(typeConverter, pattern); |
4120 | mlir::populateOpenMPToLLVMConversionPatterns(typeConverter, pattern); |
4121 | mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern); |
4122 | mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, |
4123 | pattern); |
4124 | mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern); |
4125 | // Math operations that have not been converted yet must be converted |
4126 | // to Libm. |
4127 | if (!isAMDGCN) |
4128 | mlir::populateMathToLibmConversionPatterns(pattern); |
4129 | mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern); |
4130 | mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern); |
4131 | |
4132 | // Flang specific overloads for OpenMP operations, to allow for special |
4133 | // handling of things like Box types. |
4134 | fir::populateOpenMPFIRToLLVMConversionPatterns(typeConverter, pattern); |
4135 | |
4136 | mlir::ConversionTarget target{*context}; |
4137 | target.addLegalDialect<mlir::LLVM::LLVMDialect>(); |
4138 | // The OpenMP dialect is legal for Operations without regions, for those |
4139 | // which contains regions it is legal if the region contains only the |
4140 | // LLVM dialect. Add OpenMP dialect as a legal dialect for conversion and |
4141 | // legalize conversion of OpenMP operations without regions. |
4142 | mlir::configureOpenMPToLLVMConversionLegality(target, typeConverter); |
4143 | target.addLegalDialect<mlir::omp::OpenMPDialect>(); |
4144 | target.addLegalDialect<mlir::acc::OpenACCDialect>(); |
4145 | target.addLegalDialect<mlir::gpu::GPUDialect>(); |
4146 | |
4147 | // required NOPs for applying a full conversion |
4148 | target.addLegalOp<mlir::ModuleOp>(); |
4149 | |
4150 | // If we're on Windows, we might need to rename some libm calls. |
4151 | bool isMSVC = fir::getTargetTriple(mod).isOSMSVCRT(); |
4152 | if (isMSVC) { |
4153 | pattern.insert<RenameMSVCLibmCallees, RenameMSVCLibmFuncs>(context); |
4154 | |
4155 | target.addDynamicallyLegalOp<mlir::LLVM::CallOp>( |
4156 | [](mlir::LLVM::CallOp op) { |
4157 | auto callee = op.getCallee(); |
4158 | if (!callee) |
4159 | return true; |
4160 | return *callee != "hypotf" ; |
4161 | }); |
4162 | target.addDynamicallyLegalOp<mlir::LLVM::LLVMFuncOp>( |
4163 | [](mlir::LLVM::LLVMFuncOp op) { |
4164 | return op.getSymName() != "hypotf" ; |
4165 | }); |
4166 | } |
4167 | |
4168 | // apply the patterns |
4169 | if (mlir::failed(mlir::applyFullConversion(getModule(), target, |
4170 | std::move(pattern)))) { |
4171 | signalPassFailure(); |
4172 | } |
4173 | |
4174 | // Run pass to add comdats to functions that have weak linkage on relevant |
4175 | // platforms |
4176 | if (fir::getTargetTriple(mod).supportsCOMDAT()) { |
4177 | mlir::OpPassManager comdatPM("builtin.module" ); |
4178 | comdatPM.addPass(mlir::LLVM::createLLVMAddComdats()); |
4179 | if (mlir::failed(runPipeline(comdatPM, mod))) |
4180 | return signalPassFailure(); |
4181 | } |
4182 | } |
4183 | |
4184 | private: |
4185 | fir::FIRToLLVMPassOptions options; |
4186 | }; |
4187 | |
4188 | /// Lower from LLVM IR dialect to proper LLVM-IR and dump the module |
4189 | struct LLVMIRLoweringPass |
4190 | : public mlir::PassWrapper<LLVMIRLoweringPass, |
4191 | mlir::OperationPass<mlir::ModuleOp>> { |
4192 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LLVMIRLoweringPass) |
4193 | |
4194 | LLVMIRLoweringPass(llvm::raw_ostream &output, fir::LLVMIRLoweringPrinter p) |
4195 | : output{output}, printer{p} {} |
4196 | |
4197 | mlir::ModuleOp getModule() { return getOperation(); } |
4198 | |
4199 | void runOnOperation() override final { |
4200 | auto *ctx = getModule().getContext(); |
4201 | auto optName = getModule().getName(); |
4202 | llvm::LLVMContext llvmCtx; |
4203 | if (auto llvmModule = mlir::translateModuleToLLVMIR( |
4204 | getModule(), llvmCtx, optName ? *optName : "FIRModule" )) { |
4205 | printer(*llvmModule, output); |
4206 | return; |
4207 | } |
4208 | |
4209 | mlir::emitError(mlir::UnknownLoc::get(ctx), "could not emit LLVM-IR\n" ); |
4210 | signalPassFailure(); |
4211 | } |
4212 | |
4213 | private: |
4214 | llvm::raw_ostream &output; |
4215 | fir::LLVMIRLoweringPrinter printer; |
4216 | }; |
4217 | |
4218 | } // namespace |
4219 | |
4220 | std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() { |
4221 | return std::make_unique<FIRToLLVMLowering>(); |
4222 | } |
4223 | |
4224 | std::unique_ptr<mlir::Pass> |
4225 | fir::createFIRToLLVMPass(fir::FIRToLLVMPassOptions options) { |
4226 | return std::make_unique<FIRToLLVMLowering>(options); |
4227 | } |
4228 | |
4229 | std::unique_ptr<mlir::Pass> |
4230 | fir::createLLVMDialectToLLVMPass(llvm::raw_ostream &output, |
4231 | fir::LLVMIRLoweringPrinter printer) { |
4232 | return std::make_unique<LLVMIRLoweringPass>(output, printer); |
4233 | } |
4234 | |
/// Register every FIR (and FIRCG) operation conversion pattern into
/// \p patterns. Each pattern is constructed with the shared type
/// \p converter and the pass \p options.
void fir::populateFIRToLLVMConversionPatterns(
    const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
    fir::FIRToLLVMPassOptions &options) {
  // One conversion pattern per FIR op, listed alphabetically (with a few
  // exceptions); all receive (converter, options) at construction.
  patterns.insert<
      AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
      AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
      BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
      BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
      BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion,
      BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
      CmpcOpConversion, VolatileCastOpConversion, ConvertOpConversion,
      CoordinateOpConversion, CopyOpConversion, DTEntryOpConversion,
      DeclareOpConversion, DivcOpConversion, EmboxOpConversion,
      EmboxCharOpConversion, EmboxProcOpConversion, ExtractValueOpConversion,
      FieldIndexOpConversion, FirEndOpConversion, FreeMemOpConversion,
      GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
      IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
      MulcOpConversion, NegcOpConversion, NoReassocOpConversion,
      SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
      SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
      ShiftOpConversion, SliceOpConversion, StoreOpConversion,
      StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
      TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
      UndefOpConversion, UnreachableOpConversion, XArrayCoorOpConversion,
      XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(converter,
                                                                options);

  // Patterns that are populated without a type converter do not trigger
  // target materializations for the operands of the root op.
  patterns.insert<HasValueOpConversion, InsertValueOpConversion>(
      patterns.getContext());
}
4267 | |