1 | //===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/CodeGen/FIROpPatterns.h" |
14 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
15 | #include "llvm/Support/Debug.h" |
16 | |
17 | static inline mlir::Type getLlvmPtrType(mlir::MLIRContext *context, |
18 | unsigned addressSpace = 0) { |
19 | return mlir::LLVM::LLVMPointerType::get(context, addressSpace); |
20 | } |
21 | |
22 | static unsigned getTypeDescFieldId(mlir::Type ty) { |
23 | auto isArray = mlir::isa<fir::SequenceType>(fir::dyn_cast_ptrOrBoxEleTy(ty)); |
24 | return isArray ? kOptTypePtrPosInBox : kDimsPosInBox; |
25 | } |
26 | |
27 | namespace fir { |
28 | |
29 | ConvertFIRToLLVMPattern::ConvertFIRToLLVMPattern( |
30 | llvm::StringRef rootOpName, mlir::MLIRContext *context, |
31 | const fir::LLVMTypeConverter &typeConverter, |
32 | const fir::FIRToLLVMPassOptions &options, mlir::PatternBenefit benefit) |
33 | : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit), |
34 | options(options) {} |
35 | |
36 | // Convert FIR type to LLVM without turning fir.box<T> into memory |
37 | // reference. |
38 | mlir::Type |
39 | ConvertFIRToLLVMPattern::convertObjectType(mlir::Type firType) const { |
40 | if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType)) |
41 | return lowerTy().convertBoxTypeAsStruct(boxTy); |
42 | return lowerTy().convertType(firType); |
43 | } |
44 | |
45 | mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genI32Constant( |
46 | mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, |
47 | int value) const { |
48 | mlir::Type i32Ty = rewriter.getI32Type(); |
49 | mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); |
50 | return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr); |
51 | } |
52 | |
53 | mlir::LLVM::ConstantOp ConvertFIRToLLVMPattern::genConstantOffset( |
54 | mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, |
55 | int offset) const { |
56 | mlir::Type ity = lowerTy().offsetType(); |
57 | mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset); |
58 | return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); |
59 | } |
60 | |
61 | /// Perform an extension or truncation as needed on an integer value. Lowering |
62 | /// to the specific target may involve some sign-extending or truncation of |
63 | /// values, particularly to fit them from abstract box types to the |
64 | /// appropriate reified structures. |
65 | mlir::Value ConvertFIRToLLVMPattern::integerCast( |
66 | mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, |
67 | mlir::Type ty, mlir::Value val, bool fold) const { |
68 | auto valTy = val.getType(); |
69 | // If the value was not yet lowered, lower its type so that it can |
70 | // be used in getPrimitiveTypeSizeInBits. |
71 | if (!mlir::isa<mlir::IntegerType>(valTy)) |
72 | valTy = convertType(valTy); |
73 | auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); |
74 | auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits(valTy); |
75 | if (fold) { |
76 | if (toSize < fromSize) |
77 | return rewriter.createOrFold<mlir::LLVM::TruncOp>(loc, ty, val); |
78 | if (toSize > fromSize) |
79 | return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val); |
80 | } else { |
81 | if (toSize < fromSize) |
82 | return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val); |
83 | if (toSize > fromSize) |
84 | return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); |
85 | } |
86 | return val; |
87 | } |
88 | |
89 | fir::ConvertFIRToLLVMPattern::TypePair |
90 | ConvertFIRToLLVMPattern::getBoxTypePair(mlir::Type firBoxTy) const { |
91 | mlir::Type llvmBoxTy = |
92 | lowerTy().convertBoxTypeAsStruct(mlir::cast<fir::BaseBoxType>(firBoxTy)); |
93 | return TypePair{firBoxTy, llvmBoxTy}; |
94 | } |
95 | |
96 | /// Construct code sequence to extract the specific value from a `fir.box`. |
97 | mlir::Value ConvertFIRToLLVMPattern::getValueFromBox( |
98 | mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Type resultTy, |
99 | mlir::ConversionPatternRewriter &rewriter, int boxValue) const { |
100 | if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) { |
101 | auto pty = getLlvmPtrType(resultTy.getContext()); |
102 | auto p = rewriter.create<mlir::LLVM::GEPOp>( |
103 | loc, pty, boxTy.llvm, box, |
104 | llvm::ArrayRef<mlir::LLVM::GEPArg>{0, boxValue}); |
105 | auto fldTy = getBoxEleTy(boxTy.llvm, {boxValue}); |
106 | auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, fldTy, p); |
107 | auto castOp = integerCast(loc, rewriter, resultTy, loadOp); |
108 | attachTBAATag(loadOp, boxTy.fir, nullptr, p); |
109 | return castOp; |
110 | } |
111 | return rewriter.create<mlir::LLVM::ExtractValueOp>(loc, box, boxValue); |
112 | } |
113 | |
114 | /// Method to construct code sequence to get the triple for dimension `dim` |
115 | /// from a box. |
116 | llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox( |
117 | mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy, |
118 | mlir::Value box, mlir::Value dim, |
119 | mlir::ConversionPatternRewriter &rewriter) const { |
120 | mlir::Value l0 = |
121 | loadDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter); |
122 | mlir::Value l1 = |
123 | loadDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter); |
124 | mlir::Value l2 = |
125 | loadDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter); |
126 | return {l0, l1, l2}; |
127 | } |
128 | |
129 | llvm::SmallVector<mlir::Value, 3> ConvertFIRToLLVMPattern::getDimsFromBox( |
130 | mlir::Location loc, llvm::ArrayRef<mlir::Type> retTys, TypePair boxTy, |
131 | mlir::Value box, int dim, mlir::ConversionPatternRewriter &rewriter) const { |
132 | mlir::Value l0 = |
133 | getDimFieldFromBox(loc, boxTy, box, dim, 0, retTys[0], rewriter); |
134 | mlir::Value l1 = |
135 | getDimFieldFromBox(loc, boxTy, box, dim, 1, retTys[1], rewriter); |
136 | mlir::Value l2 = |
137 | getDimFieldFromBox(loc, boxTy, box, dim, 2, retTys[2], rewriter); |
138 | return {l0, l1, l2}; |
139 | } |
140 | |
141 | mlir::Value ConvertFIRToLLVMPattern::loadDimFieldFromBox( |
142 | mlir::Location loc, TypePair boxTy, mlir::Value box, mlir::Value dim, |
143 | int off, mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const { |
144 | assert(mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType()) && |
145 | "descriptor inquiry with runtime dim can only be done on descriptor " |
146 | "in memory" ); |
147 | mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0, |
148 | static_cast<int>(kDimsPosInBox), dim, off); |
149 | auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p); |
150 | attachTBAATag(loadOp, boxTy.fir, nullptr, p); |
151 | return loadOp; |
152 | } |
153 | |
154 | mlir::Value ConvertFIRToLLVMPattern::getDimFieldFromBox( |
155 | mlir::Location loc, TypePair boxTy, mlir::Value box, int dim, int off, |
156 | mlir::Type ty, mlir::ConversionPatternRewriter &rewriter) const { |
157 | if (mlir::isa<mlir::LLVM::LLVMPointerType>(box.getType())) { |
158 | mlir::LLVM::GEPOp p = genGEP(loc, boxTy.llvm, rewriter, box, 0, |
159 | static_cast<int>(kDimsPosInBox), dim, off); |
160 | auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p); |
161 | attachTBAATag(loadOp, boxTy.fir, nullptr, p); |
162 | return loadOp; |
163 | } |
164 | return rewriter.create<mlir::LLVM::ExtractValueOp>( |
165 | loc, box, llvm::ArrayRef<std::int64_t>{kDimsPosInBox, dim, off}); |
166 | } |
167 | |
168 | mlir::Value ConvertFIRToLLVMPattern::getStrideFromBox( |
169 | mlir::Location loc, TypePair boxTy, mlir::Value box, unsigned dim, |
170 | mlir::ConversionPatternRewriter &rewriter) const { |
171 | auto idxTy = lowerTy().indexType(); |
172 | return getDimFieldFromBox(loc, boxTy, box, dim, kDimStridePos, idxTy, |
173 | rewriter); |
174 | } |
175 | |
176 | /// Read base address from a fir.box. Returned address has type ty. |
177 | mlir::Value ConvertFIRToLLVMPattern::getBaseAddrFromBox( |
178 | mlir::Location loc, TypePair boxTy, mlir::Value box, |
179 | mlir::ConversionPatternRewriter &rewriter) const { |
180 | mlir::Type resultTy = ::getLlvmPtrType(boxTy.llvm.getContext()); |
181 | return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kAddrPosInBox); |
182 | } |
183 | |
184 | mlir::Value ConvertFIRToLLVMPattern::getElementSizeFromBox( |
185 | mlir::Location loc, mlir::Type resultTy, TypePair boxTy, mlir::Value box, |
186 | mlir::ConversionPatternRewriter &rewriter) const { |
187 | return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kElemLenPosInBox); |
188 | } |
189 | |
190 | /// Read base address from a fir.box. Returned address has type ty. |
191 | mlir::Value ConvertFIRToLLVMPattern::getRankFromBox( |
192 | mlir::Location loc, TypePair boxTy, mlir::Value box, |
193 | mlir::ConversionPatternRewriter &rewriter) const { |
194 | mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kRankPosInBox}); |
195 | return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kRankPosInBox); |
196 | } |
197 | |
198 | /// Read the extra field from a fir.box. |
199 | mlir::Value ConvertFIRToLLVMPattern::getExtraFromBox( |
200 | mlir::Location loc, TypePair boxTy, mlir::Value box, |
201 | mlir::ConversionPatternRewriter &rewriter) const { |
202 | mlir::Type resultTy = getBoxEleTy(boxTy.llvm, {kExtraPosInBox}); |
203 | return getValueFromBox(loc, boxTy, box, resultTy, rewriter, kExtraPosInBox); |
204 | } |
205 | |
206 | // Get the element type given an LLVM type that is of the form |
207 | // (array|struct|vector)+ and the provided indexes. |
208 | mlir::Type ConvertFIRToLLVMPattern::getBoxEleTy( |
209 | mlir::Type type, llvm::ArrayRef<std::int64_t> indexes) const { |
210 | for (unsigned i : indexes) { |
211 | if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMStructType>(type)) { |
212 | assert(!t.isOpaque() && i < t.getBody().size()); |
213 | type = t.getBody()[i]; |
214 | } else if (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) { |
215 | type = t.getElementType(); |
216 | } else if (auto t = mlir::dyn_cast<mlir::VectorType>(type)) { |
217 | type = t.getElementType(); |
218 | } else { |
219 | fir::emitFatalError(mlir::UnknownLoc::get(type.getContext()), |
220 | "request for invalid box element type" ); |
221 | } |
222 | } |
223 | return type; |
224 | } |
225 | |
226 | // Return LLVM type of the object described by a fir.box of \p boxType. |
227 | mlir::Type ConvertFIRToLLVMPattern::getLlvmObjectTypeFromBoxType( |
228 | mlir::Type boxType) const { |
229 | mlir::Type objectType = fir::dyn_cast_ptrOrBoxEleTy(boxType); |
230 | assert(objectType && "boxType must be a box type" ); |
231 | return this->convertType(objectType); |
232 | } |
233 | |
234 | /// Read the address of the type descriptor from a box. |
235 | mlir::Value ConvertFIRToLLVMPattern::loadTypeDescAddress( |
236 | mlir::Location loc, TypePair boxTy, mlir::Value box, |
237 | mlir::ConversionPatternRewriter &rewriter) const { |
238 | unsigned typeDescFieldId = getTypeDescFieldId(boxTy.fir); |
239 | mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext()); |
240 | return getValueFromBox(loc, boxTy, box, tdescType, rewriter, typeDescFieldId); |
241 | } |
242 | |
243 | // Load the attribute from the \p box and perform a check against \p maskValue |
244 | // The final comparison is implemented as `(attribute & maskValue) != 0`. |
245 | mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck( |
246 | mlir::Location loc, TypePair boxTy, mlir::Value box, |
247 | mlir::ConversionPatternRewriter &rewriter, unsigned maskValue) const { |
248 | mlir::Type attrTy = rewriter.getI32Type(); |
249 | mlir::Value attribute = |
250 | getValueFromBox(loc, boxTy, box, attrTy, rewriter, kAttributePosInBox); |
251 | mlir::LLVM::ConstantOp attrMask = genConstantOffset(loc, rewriter, maskValue); |
252 | auto maskRes = |
253 | rewriter.create<mlir::LLVM::AndOp>(loc, attrTy, attribute, attrMask); |
254 | mlir::LLVM::ConstantOp c0 = genConstantOffset(loc, rewriter, 0); |
255 | return rewriter.create<mlir::LLVM::ICmpOp>(loc, mlir::LLVM::ICmpPredicate::ne, |
256 | maskRes, c0); |
257 | } |
258 | |
259 | mlir::Value ConvertFIRToLLVMPattern::computeBoxSize( |
260 | mlir::Location loc, TypePair boxTy, mlir::Value box, |
261 | mlir::ConversionPatternRewriter &rewriter) const { |
262 | auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir); |
263 | assert(firBoxType && "must be a BaseBoxType" ); |
264 | const mlir::DataLayout &dl = lowerTy().getDataLayout(); |
265 | if (!firBoxType.isAssumedRank()) |
266 | return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm)); |
267 | fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0); |
268 | mlir::Type llvmScalarBoxType = |
269 | lowerTy().convertBoxTypeAsStruct(firScalarBoxType); |
270 | llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType); |
271 | mlir::Value scalarBoxSize = |
272 | genConstantOffset(loc, rewriter, scalarBoxSizeCst); |
273 | mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter); |
274 | mlir::Value rank = |
275 | integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank); |
276 | mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1}); |
277 | llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType); |
278 | assert((scalarBoxSizeCst + sizePerDimCst == |
279 | dl.getTypeSize(lowerTy().convertBoxTypeAsStruct( |
280 | firBoxType.getBoxTypeWithNewShape(1)))) && |
281 | "descriptor layout requires adding padding for dim field" ); |
282 | mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst); |
283 | mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>( |
284 | loc, sizePerDim.getType(), sizePerDim, rank); |
285 | mlir::Value size = rewriter.create<mlir::LLVM::AddOp>( |
286 | loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize); |
287 | return size; |
288 | } |
289 | |
290 | // Find the Block in which the alloca should be inserted. |
291 | // The order to recursively find the proper block: |
292 | // 1. An OpenMP Op that will be outlined. |
293 | // 2. An OpenMP or OpenACC Op with one or more regions holding executable code. |
294 | // 3. A LLVMFuncOp |
295 | // 4. The first ancestor that is one of the above. |
296 | mlir::Block *ConvertFIRToLLVMPattern::getBlockForAllocaInsert( |
297 | mlir::Operation *op, mlir::Region *parentRegion) const { |
298 | if (auto iface = mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op)) |
299 | return iface.getAllocaBlock(); |
300 | if (auto recipeIface = mlir::dyn_cast<mlir::accomp::RecipeInterface>(op)) |
301 | return recipeIface.getAllocaBlock(*parentRegion); |
302 | if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op)) |
303 | return &llvmFuncOp.front(); |
304 | |
305 | return getBlockForAllocaInsert(op->getParentOp(), parentRegion); |
306 | } |
307 | |
308 | // Generate an alloca of size 1 for an object of type \p llvmObjectTy in the |
309 | // allocation address space provided for the architecture in the DataLayout |
310 | // specification. If the address space is different from the devices |
311 | // program address space we perform a cast. In the case of most architectures |
312 | // the program and allocation address space will be the default of 0 and no |
313 | // cast will be emitted. |
314 | mlir::Value ConvertFIRToLLVMPattern::genAllocaAndAddrCastWithType( |
315 | mlir::Location loc, mlir::Type llvmObjectTy, unsigned alignment, |
316 | mlir::ConversionPatternRewriter &rewriter) const { |
317 | auto thisPt = rewriter.saveInsertionPoint(); |
318 | mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); |
319 | mlir::Region *parentRegion = rewriter.getInsertionBlock()->getParent(); |
320 | mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp, parentRegion); |
321 | rewriter.setInsertionPointToStart(insertBlock); |
322 | auto size = genI32Constant(loc, rewriter, 1); |
323 | unsigned allocaAs = getAllocaAddressSpace(rewriter); |
324 | unsigned programAs = getProgramAddressSpace(rewriter); |
325 | |
326 | mlir::Value al = rewriter.create<mlir::LLVM::AllocaOp>( |
327 | loc, ::getLlvmPtrType(llvmObjectTy.getContext(), allocaAs), llvmObjectTy, |
328 | size, alignment); |
329 | |
330 | // if our allocation address space, is not the same as the program address |
331 | // space, then we must emit a cast to the program address space before use. |
332 | // An example case would be on AMDGPU, where the allocation address space is |
333 | // the numeric value 5 (private), and the program address space is 0 |
334 | // (generic). |
335 | if (allocaAs != programAs) { |
336 | al = rewriter.create<mlir::LLVM::AddrSpaceCastOp>( |
337 | loc, ::getLlvmPtrType(llvmObjectTy.getContext(), programAs), al); |
338 | } |
339 | |
340 | rewriter.restoreInsertionPoint(thisPt); |
341 | return al; |
342 | } |
343 | |
344 | unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace( |
345 | mlir::ConversionPatternRewriter &rewriter) const { |
346 | mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); |
347 | assert(parentOp != nullptr && |
348 | "expected insertion block to have parent operation" ); |
349 | if (auto module = parentOp->getParentOfType<mlir::ModuleOp>()) |
350 | if (mlir::Attribute addrSpace = |
351 | mlir::DataLayout(module).getAllocaMemorySpace()) |
352 | return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt(); |
353 | return defaultAddressSpace; |
354 | } |
355 | |
356 | unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace( |
357 | mlir::ConversionPatternRewriter &rewriter) const { |
358 | mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); |
359 | assert(parentOp != nullptr && |
360 | "expected insertion block to have parent operation" ); |
361 | if (auto module = parentOp->getParentOfType<mlir::ModuleOp>()) |
362 | if (mlir::Attribute addrSpace = |
363 | mlir::DataLayout(module).getProgramMemorySpace()) |
364 | return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt(); |
365 | return defaultAddressSpace; |
366 | } |
367 | |
368 | } // namespace fir |
369 | |