1 | //===-- FIROps.cpp --------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/Dialect/FIROps.h" |
14 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
15 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
16 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
17 | #include "flang/Optimizer/Dialect/FIRType.h" |
18 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
19 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
20 | #include "flang/Optimizer/Support/Utils.h" |
21 | #include "mlir/Dialect/CommonFolders.h" |
22 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
23 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
24 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
25 | #include "mlir/IR/Attributes.h" |
26 | #include "mlir/IR/BuiltinAttributes.h" |
27 | #include "mlir/IR/BuiltinOps.h" |
28 | #include "mlir/IR/Diagnostics.h" |
29 | #include "mlir/IR/Matchers.h" |
30 | #include "mlir/IR/OpDefinition.h" |
31 | #include "mlir/IR/PatternMatch.h" |
32 | #include "mlir/IR/TypeRange.h" |
33 | #include "llvm/ADT/STLExtras.h" |
34 | #include "llvm/ADT/SmallVector.h" |
35 | #include "llvm/ADT/TypeSwitch.h" |
36 | #include "llvm/Support/CommandLine.h" |
37 | |
38 | namespace { |
39 | #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc" |
40 | } // namespace |
41 | |
42 | static llvm::cl::opt<bool> clUseStrictVolatileVerification( |
43 | "strict-fir-volatile-verifier" , llvm::cl::init(false), |
44 | llvm::cl::desc( |
45 | "use stricter verifier for FIR operations with volatile types" )); |
46 | |
47 | bool fir::useStrictVolatileVerification() { |
48 | return clUseStrictVolatileVerification; |
49 | } |
50 | |
51 | static void propagateAttributes(mlir::Operation *fromOp, |
52 | mlir::Operation *toOp) { |
53 | if (!fromOp || !toOp) |
54 | return; |
55 | |
56 | for (mlir::NamedAttribute attr : fromOp->getAttrs()) { |
57 | if (attr.getName().getValue().starts_with( |
58 | mlir::acc::OpenACCDialect::getDialectNamespace())) |
59 | toOp->setAttr(attr.getName(), attr.getValue()); |
60 | } |
61 | } |
62 | |
63 | /// Return true if a sequence type is of some incomplete size or a record type |
64 | /// is malformed or contains an incomplete sequence type. An incomplete sequence |
65 | /// type is one with more unknown extents in the type than have been provided |
66 | /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by |
67 | /// definition. |
68 | static bool verifyInType(mlir::Type inType, |
69 | llvm::SmallVectorImpl<llvm::StringRef> &visited, |
70 | unsigned dynamicExtents = 0) { |
71 | if (auto st = mlir::dyn_cast<fir::SequenceType>(inType)) { |
72 | auto shape = st.getShape(); |
73 | if (shape.size() == 0) |
74 | return true; |
75 | for (std::size_t i = 0, end = shape.size(); i < end; ++i) { |
76 | if (shape[i] != fir::SequenceType::getUnknownExtent()) |
77 | continue; |
78 | if (dynamicExtents-- == 0) |
79 | return true; |
80 | } |
81 | } else if (auto rt = mlir::dyn_cast<fir::RecordType>(inType)) { |
82 | // don't recurse if we're already visiting this one |
83 | if (llvm::is_contained(visited, rt.getName())) |
84 | return false; |
85 | // keep track of record types currently being visited |
86 | visited.push_back(Elt: rt.getName()); |
87 | for (auto &field : rt.getTypeList()) |
88 | if (verifyInType(field.second, visited)) |
89 | return true; |
90 | visited.pop_back(); |
91 | } |
92 | return false; |
93 | } |
94 | |
95 | static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { |
96 | auto ty = fir::unwrapSequenceType(inType); |
97 | if (numParams > 0) { |
98 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) |
99 | return numParams != recTy.getNumLenParams(); |
100 | if (auto chrTy = mlir::dyn_cast<fir::CharacterType>(ty)) |
101 | return !(numParams == 1 && chrTy.hasDynamicLen()); |
102 | return true; |
103 | } |
104 | if (auto chrTy = mlir::dyn_cast<fir::CharacterType>(ty)) |
105 | return !chrTy.hasConstantLen(); |
106 | return false; |
107 | } |
108 | |
109 | /// Parser shared by Alloca and Allocmem |
110 | /// |
111 | /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type |
112 | /// ( `(` $typeparams `)` )? ( `,` $shape )? |
113 | /// attr-dict-without-keyword |
114 | template <typename FN> |
115 | static mlir::ParseResult parseAllocatableOp(FN wrapResultType, |
116 | mlir::OpAsmParser &parser, |
117 | mlir::OperationState &result) { |
118 | mlir::Type intype; |
119 | if (parser.parseType(result&: intype)) |
120 | return mlir::failure(); |
121 | auto &builder = parser.getBuilder(); |
122 | result.addAttribute("in_type" , mlir::TypeAttr::get(intype)); |
123 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; |
124 | llvm::SmallVector<mlir::Type> typeVec; |
125 | bool hasOperands = false; |
126 | std::int32_t typeparamsSize = 0; |
127 | if (!parser.parseOptionalLParen()) { |
128 | // parse the LEN params of the derived type. (<params> : <types>) |
129 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None) || |
130 | parser.parseColonTypeList(result&: typeVec) || parser.parseRParen()) |
131 | return mlir::failure(); |
132 | typeparamsSize = operands.size(); |
133 | hasOperands = true; |
134 | } |
135 | std::int32_t shapeSize = 0; |
136 | if (!parser.parseOptionalComma()) { |
137 | // parse size to scale by, vector of n dimensions of type index |
138 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None)) |
139 | return mlir::failure(); |
140 | shapeSize = operands.size() - typeparamsSize; |
141 | auto idxTy = builder.getIndexType(); |
142 | for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) |
143 | typeVec.push_back(Elt: idxTy); |
144 | hasOperands = true; |
145 | } |
146 | if (hasOperands && |
147 | parser.resolveOperands(operands, types&: typeVec, loc: parser.getNameLoc(), |
148 | result&: result.operands)) |
149 | return mlir::failure(); |
150 | mlir::Type restype = wrapResultType(intype); |
151 | if (!restype) { |
152 | parser.emitError(loc: parser.getNameLoc(), message: "invalid allocate type: " ) << intype; |
153 | return mlir::failure(); |
154 | } |
155 | result.addAttribute("operandSegmentSizes" , builder.getDenseI32ArrayAttr( |
156 | {typeparamsSize, shapeSize})); |
157 | if (parser.parseOptionalAttrDict(result&: result.attributes) || |
158 | parser.addTypeToList(type: restype, result&: result.types)) |
159 | return mlir::failure(); |
160 | return mlir::success(); |
161 | } |
162 | |
163 | template <typename OP> |
164 | static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) { |
165 | p << ' ' << op.getInType(); |
166 | if (!op.getTypeparams().empty()) { |
167 | p << '(' << op.getTypeparams() << " : " << op.getTypeparams().getTypes() |
168 | << ')'; |
169 | } |
170 | // print the shape of the allocation (if any); all must be index type |
171 | for (auto sh : op.getShape()) { |
172 | p << ", " ; |
173 | p.printOperand(sh); |
174 | } |
175 | p.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs: {"in_type" , "operandSegmentSizes" }); |
176 | } |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // AllocaOp |
180 | //===----------------------------------------------------------------------===// |
181 | |
182 | /// Create a legal memory reference as return type |
183 | static mlir::Type wrapAllocaResultType(mlir::Type intype) { |
184 | // FIR semantics: memory references to memory references are disallowed |
185 | if (mlir::isa<fir::ReferenceType>(intype)) |
186 | return {}; |
187 | return fir::ReferenceType::get(intype); |
188 | } |
189 | |
190 | mlir::Type fir::AllocaOp::getAllocatedType() { |
191 | return mlir::cast<fir::ReferenceType>(getType()).getEleTy(); |
192 | } |
193 | |
194 | mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { |
195 | return fir::ReferenceType::get(ty); |
196 | } |
197 | |
198 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
199 | mlir::OperationState &result, mlir::Type inType, |
200 | llvm::StringRef uniqName, mlir::ValueRange typeparams, |
201 | mlir::ValueRange shape, |
202 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
203 | auto nameAttr = builder.getStringAttr(uniqName); |
204 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, |
205 | /*pinned=*/false, typeparams, shape); |
206 | result.addAttributes(attributes); |
207 | } |
208 | |
209 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
210 | mlir::OperationState &result, mlir::Type inType, |
211 | llvm::StringRef uniqName, bool pinned, |
212 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
213 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
214 | auto nameAttr = builder.getStringAttr(uniqName); |
215 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, |
216 | pinned, typeparams, shape); |
217 | result.addAttributes(attributes); |
218 | } |
219 | |
220 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
221 | mlir::OperationState &result, mlir::Type inType, |
222 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
223 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
224 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
225 | auto nameAttr = |
226 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
227 | auto bindcAttr = |
228 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
229 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, |
230 | bindcAttr, /*pinned=*/false, typeparams, shape); |
231 | result.addAttributes(attributes); |
232 | } |
233 | |
234 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
235 | mlir::OperationState &result, mlir::Type inType, |
236 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
237 | bool pinned, mlir::ValueRange typeparams, |
238 | mlir::ValueRange shape, |
239 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
240 | auto nameAttr = |
241 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
242 | auto bindcAttr = |
243 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
244 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, |
245 | bindcAttr, pinned, typeparams, shape); |
246 | result.addAttributes(attributes); |
247 | } |
248 | |
249 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
250 | mlir::OperationState &result, mlir::Type inType, |
251 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
252 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
253 | build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, |
254 | /*pinned=*/false, typeparams, shape); |
255 | result.addAttributes(attributes); |
256 | } |
257 | |
258 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
259 | mlir::OperationState &result, mlir::Type inType, |
260 | bool pinned, mlir::ValueRange typeparams, |
261 | mlir::ValueRange shape, |
262 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
263 | build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, pinned, |
264 | typeparams, shape); |
265 | result.addAttributes(attributes); |
266 | } |
267 | |
268 | mlir::ParseResult fir::AllocaOp::parse(mlir::OpAsmParser &parser, |
269 | mlir::OperationState &result) { |
270 | return parseAllocatableOp(wrapAllocaResultType, parser, result); |
271 | } |
272 | |
273 | void fir::AllocaOp::print(mlir::OpAsmPrinter &p) { |
274 | printAllocatableOp(p, *this); |
275 | } |
276 | |
277 | llvm::LogicalResult fir::AllocaOp::verify() { |
278 | llvm::SmallVector<llvm::StringRef> visited; |
279 | if (verifyInType(getInType(), visited, numShapeOperands())) |
280 | return emitOpError("invalid type for allocation" ); |
281 | if (verifyTypeParamCount(getInType(), numLenParams())) |
282 | return emitOpError("LEN params do not correspond to type" ); |
283 | mlir::Type outType = getType(); |
284 | if (!mlir::isa<fir::ReferenceType>(outType)) |
285 | return emitOpError("must be a !fir.ref type" ); |
286 | return mlir::success(); |
287 | } |
288 | |
289 | bool fir::AllocaOp::ownsNestedAlloca(mlir::Operation *op) { |
290 | return op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>() || |
291 | op->hasTrait<mlir::OpTrait::AutomaticAllocationScope>() || |
292 | mlir::isa<mlir::LoopLikeOpInterface>(*op); |
293 | } |
294 | |
295 | mlir::Region *fir::AllocaOp::getOwnerRegion() { |
296 | mlir::Operation *currentOp = getOperation(); |
297 | while (mlir::Operation *parentOp = currentOp->getParentOp()) { |
298 | // If the operation was not registered, inquiries about its traits will be |
299 | // incorrect and it is not possible to reason about the operation. This |
300 | // should not happen in a normal Fortran compilation flow, but be foolproof. |
301 | if (!parentOp->isRegistered()) |
302 | return nullptr; |
303 | if (fir::AllocaOp::ownsNestedAlloca(parentOp)) |
304 | return currentOp->getParentRegion(); |
305 | currentOp = parentOp; |
306 | } |
307 | return nullptr; |
308 | } |
309 | |
310 | //===----------------------------------------------------------------------===// |
311 | // AllocMemOp |
312 | //===----------------------------------------------------------------------===// |
313 | |
314 | /// Create a legal heap reference as return type |
315 | static mlir::Type wrapAllocMemResultType(mlir::Type intype) { |
316 | // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER |
317 | // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well |
318 | // FIR semantics: one may not allocate a memory reference value |
319 | if (mlir::isa<fir::ReferenceType, fir::HeapType, fir::PointerType, |
320 | mlir::FunctionType>(intype)) |
321 | return {}; |
322 | return fir::HeapType::get(intype); |
323 | } |
324 | |
325 | mlir::Type fir::AllocMemOp::getAllocatedType() { |
326 | return mlir::cast<fir::HeapType>(getType()).getEleTy(); |
327 | } |
328 | |
329 | mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { |
330 | return fir::HeapType::get(ty); |
331 | } |
332 | |
333 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
334 | mlir::OperationState &result, mlir::Type inType, |
335 | llvm::StringRef uniqName, |
336 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
337 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
338 | auto nameAttr = builder.getStringAttr(uniqName); |
339 | build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, {}, |
340 | typeparams, shape); |
341 | result.addAttributes(attributes); |
342 | } |
343 | |
344 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
345 | mlir::OperationState &result, mlir::Type inType, |
346 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
347 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
348 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
349 | auto nameAttr = builder.getStringAttr(uniqName); |
350 | auto bindcAttr = builder.getStringAttr(bindcName); |
351 | build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, |
352 | bindcAttr, typeparams, shape); |
353 | result.addAttributes(attributes); |
354 | } |
355 | |
356 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
357 | mlir::OperationState &result, mlir::Type inType, |
358 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
359 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
360 | build(builder, result, wrapAllocMemResultType(inType), inType, {}, {}, |
361 | typeparams, shape); |
362 | result.addAttributes(attributes); |
363 | } |
364 | |
365 | mlir::ParseResult fir::AllocMemOp::parse(mlir::OpAsmParser &parser, |
366 | mlir::OperationState &result) { |
367 | return parseAllocatableOp(wrapAllocMemResultType, parser, result); |
368 | } |
369 | |
370 | void fir::AllocMemOp::print(mlir::OpAsmPrinter &p) { |
371 | printAllocatableOp(p, *this); |
372 | } |
373 | |
374 | llvm::LogicalResult fir::AllocMemOp::verify() { |
375 | llvm::SmallVector<llvm::StringRef> visited; |
376 | if (verifyInType(getInType(), visited, numShapeOperands())) |
377 | return emitOpError("invalid type for allocation" ); |
378 | if (verifyTypeParamCount(getInType(), numLenParams())) |
379 | return emitOpError("LEN params do not correspond to type" ); |
380 | mlir::Type outType = getType(); |
381 | if (!mlir::dyn_cast<fir::HeapType>(outType)) |
382 | return emitOpError("must be a !fir.heap type" ); |
383 | if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) |
384 | return emitOpError("cannot allocate !fir.box of unknown rank or type" ); |
385 | return mlir::success(); |
386 | } |
387 | |
388 | //===----------------------------------------------------------------------===// |
389 | // ArrayCoorOp |
390 | //===----------------------------------------------------------------------===// |
391 | |
392 | // CHARACTERs and derived types with LEN PARAMETERs are dependent types that |
393 | // require runtime values to fully define the type of an object. |
394 | static bool validTypeParams(mlir::Type dynTy, mlir::ValueRange typeParams, |
395 | bool allowParamsForBox = false) { |
396 | dynTy = fir::unwrapAllRefAndSeqType(dynTy); |
397 | if (mlir::isa<fir::BaseBoxType>(dynTy)) { |
398 | // A box value will contain type parameter values itself. |
399 | if (!allowParamsForBox) |
400 | return typeParams.size() == 0; |
401 | |
402 | // A boxed value may have no length parameters, when the lengths |
403 | // are assumed. If dynamic lengths are used, then proceed |
404 | // to the verification below. |
405 | if (typeParams.size() == 0) |
406 | return true; |
407 | |
408 | dynTy = fir::getFortranElementType(dynTy); |
409 | } |
410 | // Derived type must have all type parameters satisfied. |
411 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(dynTy)) |
412 | return typeParams.size() == recTy.getNumLenParams(); |
413 | // Characters with non-constant LEN must have a type parameter value. |
414 | if (auto charTy = mlir::dyn_cast<fir::CharacterType>(dynTy)) |
415 | if (charTy.hasDynamicLen()) |
416 | return typeParams.size() == 1; |
417 | // Otherwise, any type parameters are invalid. |
418 | return typeParams.size() == 0; |
419 | } |
420 | |
421 | llvm::LogicalResult fir::ArrayCoorOp::verify() { |
422 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
423 | auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy); |
424 | if (!arrTy) |
425 | return emitOpError("must be a reference to an array" ); |
426 | auto arrDim = arrTy.getDimension(); |
427 | |
428 | if (auto shapeOp = getShape()) { |
429 | auto shapeTy = shapeOp.getType(); |
430 | unsigned shapeTyRank = 0; |
431 | if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) { |
432 | shapeTyRank = s.getRank(); |
433 | } else if (auto ss = mlir::dyn_cast<fir::ShapeShiftType>(shapeTy)) { |
434 | shapeTyRank = ss.getRank(); |
435 | } else { |
436 | auto s = mlir::cast<fir::ShiftType>(shapeTy); |
437 | shapeTyRank = s.getRank(); |
438 | // TODO: it looks like PreCGRewrite and CodeGen can support |
439 | // fir.shift with plain array reference, so we may consider |
440 | // removing this check. |
441 | if (!mlir::isa<fir::BaseBoxType>(getMemref().getType())) |
442 | return emitOpError("shift can only be provided with fir.box memref" ); |
443 | } |
444 | if (arrDim && arrDim != shapeTyRank) |
445 | return emitOpError("rank of dimension mismatched" ); |
446 | // TODO: support slicing with changing the number of dimensions, |
447 | // e.g. when array_coor represents an element access to array(:,1,:) |
448 | // slice: the shape is 3D and the number of indices is 2 in this case. |
449 | if (shapeTyRank != getIndices().size()) |
450 | return emitOpError("number of indices do not match dim rank" ); |
451 | } |
452 | |
453 | if (auto sliceOp = getSlice()) { |
454 | if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) |
455 | if (!sl.getSubstr().empty()) |
456 | return emitOpError("array_coor cannot take a slice with substring" ); |
457 | if (auto sliceTy = mlir::dyn_cast<fir::SliceType>(sliceOp.getType())) |
458 | if (sliceTy.getRank() != arrDim) |
459 | return emitOpError("rank of dimension in slice mismatched" ); |
460 | } |
461 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
462 | return emitOpError("invalid type parameters" ); |
463 | |
464 | return mlir::success(); |
465 | } |
466 | |
467 | // Pull in fir.embox and fir.rebox into fir.array_coor when possible. |
468 | struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> { |
469 | using mlir::OpRewritePattern<fir::ArrayCoorOp>::OpRewritePattern; |
470 | llvm::LogicalResult |
471 | matchAndRewrite(fir::ArrayCoorOp op, |
472 | mlir::PatternRewriter &rewriter) const override { |
473 | mlir::Value memref = op.getMemref(); |
474 | if (!mlir::isa<fir::BaseBoxType>(memref.getType())) |
475 | return mlir::failure(); |
476 | |
477 | mlir::Value boxedMemref, boxedShape, boxedSlice; |
478 | if (auto emboxOp = |
479 | mlir::dyn_cast_or_null<fir::EmboxOp>(memref.getDefiningOp())) { |
480 | boxedMemref = emboxOp.getMemref(); |
481 | boxedShape = emboxOp.getShape(); |
482 | boxedSlice = emboxOp.getSlice(); |
483 | // If any of operands, that are not currently supported for migration |
484 | // to ArrayCoorOp, is present, don't rewrite. |
485 | if (!emboxOp.getTypeparams().empty() || emboxOp.getSourceBox() || |
486 | emboxOp.getAccessMap()) |
487 | return mlir::failure(); |
488 | } else if (auto reboxOp = mlir::dyn_cast_or_null<fir::ReboxOp>( |
489 | memref.getDefiningOp())) { |
490 | boxedMemref = reboxOp.getBox(); |
491 | boxedShape = reboxOp.getShape(); |
492 | // Avoid pulling in rebox that performs reshaping. |
493 | // There is no way to represent box reshaping with array_coor. |
494 | if (boxedShape && !mlir::isa<fir::ShiftType>(boxedShape.getType())) |
495 | return mlir::failure(); |
496 | boxedSlice = reboxOp.getSlice(); |
497 | } else { |
498 | return mlir::failure(); |
499 | } |
500 | |
501 | bool boxedShapeIsShift = |
502 | boxedShape && mlir::isa<fir::ShiftType>(boxedShape.getType()); |
503 | bool boxedShapeIsShape = |
504 | boxedShape && mlir::isa<fir::ShapeType>(boxedShape.getType()); |
505 | bool boxedShapeIsShapeShift = |
506 | boxedShape && mlir::isa<fir::ShapeShiftType>(boxedShape.getType()); |
507 | |
508 | // Slices changing the number of dimensions are not supported |
509 | // for array_coor yet. |
510 | unsigned origBoxRank; |
511 | if (mlir::isa<fir::BaseBoxType>(boxedMemref.getType())) |
512 | origBoxRank = fir::getBoxRank(boxedMemref.getType()); |
513 | else if (auto arrTy = mlir::dyn_cast<fir::SequenceType>( |
514 | fir::unwrapRefType(boxedMemref.getType()))) |
515 | origBoxRank = arrTy.getDimension(); |
516 | else |
517 | return mlir::failure(); |
518 | |
519 | if (fir::getBoxRank(memref.getType()) != origBoxRank) |
520 | return mlir::failure(); |
521 | |
522 | // Slices with substring are not supported by array_coor. |
523 | if (boxedSlice) |
524 | if (auto sliceOp = |
525 | mlir::dyn_cast_or_null<fir::SliceOp>(boxedSlice.getDefiningOp())) |
526 | if (!sliceOp.getSubstr().empty()) |
527 | return mlir::failure(); |
528 | |
529 | // If embox/rebox and array_coor have conflicting shapes or slices, |
530 | // do nothing. |
531 | if (op.getShape() && boxedShape && boxedShape != op.getShape()) |
532 | return mlir::failure(); |
533 | if (op.getSlice() && boxedSlice && boxedSlice != op.getSlice()) |
534 | return mlir::failure(); |
535 | |
536 | std::optional<IndicesVectorTy> shiftedIndices; |
537 | // The embox/rebox and array_coor either have compatible |
538 | // shape/slice at this point or shape/slice is null |
539 | // in one of them but not in the other. |
540 | // The compatibility means they are equal or both null. |
541 | if (!op.getShape()) { |
542 | if (boxedShape) { |
543 | if (op.getSlice()) { |
544 | if (!boxedSlice) { |
545 | if (boxedShapeIsShift) { |
546 | // %0 = fir.rebox %arg(%shift) |
547 | // %1 = fir.array_coor %0 [%slice] %idx |
548 | // Both the slice indices and %idx are 1-based, so the rebox |
549 | // may be pulled in as: |
550 | // %1 = fir.array_coor %arg [%slice] %idx |
551 | boxedShape = nullptr; |
552 | } else if (boxedShapeIsShape) { |
553 | // %0 = fir.embox %arg(%shape) |
554 | // %1 = fir.array_coor %0 [%slice] %idx |
555 | // Pull in as: |
556 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
557 | } else if (boxedShapeIsShapeShift) { |
558 | // %0 = fir.embox %arg(%shapeshift) |
559 | // %1 = fir.array_coor %0 [%slice] %idx |
560 | // Pull in as: |
561 | // %shape = fir.shape <extents from the %shapeshift> |
562 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
563 | boxedShape = getShapeFromShapeShift(v: boxedShape, rewriter); |
564 | if (!boxedShape) |
565 | return mlir::failure(); |
566 | } else { |
567 | return mlir::failure(); |
568 | } |
569 | } else { |
570 | if (boxedShapeIsShift) { |
571 | // %0 = fir.rebox %arg(%shift) [%slice] |
572 | // %1 = fir.array_coor %0 [%slice] %idx |
573 | // This FIR may only be valid if the shape specifies |
574 | // that all lower bounds are 1s and the slice's start indices |
575 | // and strides are all 1s. |
576 | // We could pull in the rebox as: |
577 | // %1 = fir.array_coor %arg [%slice] %idx |
578 | // Do not do anything for the time being. |
579 | return mlir::failure(); |
580 | } else if (boxedShapeIsShape) { |
581 | // %0 = fir.embox %arg(%shape) [%slice] |
582 | // %1 = fir.array_coor %0 [%slice] %idx |
583 | // This FIR may only be valid if the slice's start indices |
584 | // and strides are all 1s. |
585 | // We could pull in the embox as: |
586 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
587 | return mlir::failure(); |
588 | } else if (boxedShapeIsShapeShift) { |
589 | // %0 = fir.embox %arg(%shapeshift) [%slice] |
590 | // %1 = fir.array_coor %0 [%slice] %idx |
591 | // This FIR may only be valid if the shape specifies |
592 | // that all lower bounds are 1s and the slice's start indices |
593 | // and strides are all 1s. |
594 | // We could pull in the embox as: |
595 | // %shape = fir.shape <extents from the %shapeshift> |
596 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
597 | return mlir::failure(); |
598 | } else { |
599 | return mlir::failure(); |
600 | } |
601 | } |
602 | } else { // !op.getSlice() |
603 | if (!boxedSlice) { |
604 | if (boxedShapeIsShift) { |
605 | // %0 = fir.rebox %arg(%shift) |
606 | // %1 = fir.array_coor %0 %idx |
607 | // Pull in as: |
608 | // %1 = fir.array_coor %arg %idx |
609 | boxedShape = nullptr; |
610 | } else if (boxedShapeIsShape) { |
611 | // %0 = fir.embox %arg(%shape) |
612 | // %1 = fir.array_coor %0 %idx |
613 | // Pull in as: |
614 | // %1 = fir.array_coor %arg(%shape) %idx |
615 | } else if (boxedShapeIsShapeShift) { |
616 | // %0 = fir.embox %arg(%shapeshift) |
617 | // %1 = fir.array_coor %0 %idx |
618 | // Pull in as: |
619 | // %shape = fir.shape <extents from the %shapeshift> |
620 | // %1 = fir.array_coor %arg(%shape) %idx |
621 | boxedShape = getShapeFromShapeShift(v: boxedShape, rewriter); |
622 | if (!boxedShape) |
623 | return mlir::failure(); |
624 | } else { |
625 | return mlir::failure(); |
626 | } |
627 | } else { |
628 | if (boxedShapeIsShift) { |
629 | // %0 = fir.embox %arg(%shift) [%slice] |
630 | // %1 = fir.array_coor %0 %idx |
631 | // Pull in as: |
632 | // %tmp = arith.addi %idx, %shift.origin |
633 | // %idx_shifted = arith.subi %tmp, 1 |
634 | // %1 = fir.array_coor %arg(%shift) %[slice] %idx_shifted |
635 | shiftedIndices = |
636 | getShiftedIndices(v: boxedShape, indices: op.getIndices(), rewriter); |
637 | if (!shiftedIndices) |
638 | return mlir::failure(); |
639 | } else if (boxedShapeIsShape) { |
640 | // %0 = fir.embox %arg(%shape) [%slice] |
641 | // %1 = fir.array_coor %0 %idx |
642 | // Pull in as: |
643 | // %1 = fir.array_coor %arg(%shape) %[slice] %idx |
644 | } else if (boxedShapeIsShapeShift) { |
645 | // %0 = fir.embox %arg(%shapeshift) [%slice] |
646 | // %1 = fir.array_coor %0 %idx |
647 | // Pull in as: |
648 | // %tmp = arith.addi %idx, %shapeshift.lb |
649 | // %idx_shifted = arith.subi %tmp, 1 |
650 | // %1 = fir.array_coor %arg(%shapeshift) %[slice] %idx_shifted |
651 | shiftedIndices = |
652 | getShiftedIndices(v: boxedShape, indices: op.getIndices(), rewriter); |
653 | if (!shiftedIndices) |
654 | return mlir::failure(); |
655 | } else { |
656 | return mlir::failure(); |
657 | } |
658 | } |
659 | } |
660 | } else { // !boxedShape |
661 | if (op.getSlice()) { |
662 | if (!boxedSlice) { |
663 | // %0 = fir.rebox %arg |
664 | // %1 = fir.array_coor %0 [%slice] %idx |
665 | // Pull in as: |
666 | // %1 = fir.array_coor %arg [%slice] %idx |
667 | } else { |
668 | // %0 = fir.rebox %arg [%slice] |
669 | // %1 = fir.array_coor %0 [%slice] %idx |
670 | // This is a valid FIR iff the slice's lower bounds |
671 | // and strides are all 1s. |
672 | // Pull in as: |
673 | // %1 = fir.array_coor %arg [%slice] %idx |
674 | } |
675 | } else { // !op.getSlice() |
676 | if (!boxedSlice) { |
677 | // %0 = fir.rebox %arg |
678 | // %1 = fir.array_coor %0 %idx |
679 | // Pull in as: |
680 | // %1 = fir.array_coor %arg %idx |
681 | } else { |
682 | // %0 = fir.rebox %arg [%slice] |
683 | // %1 = fir.array_coor %0 %idx |
684 | // Pull in as: |
685 | // %1 = fir.array_coor %arg [%slice] %idx |
686 | } |
687 | } |
688 | } |
689 | } else { // op.getShape() |
690 | if (boxedShape) { |
691 | // Check if pulling in non-default shape is correct. |
692 | if (op.getSlice()) { |
693 | if (!boxedSlice) { |
694 | // %0 = fir.embox %arg(%shape) |
695 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
696 | // Pull in as: |
697 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
698 | } else { |
699 | // %0 = fir.embox %arg(%shape) [%slice] |
700 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
701 | // Pull in as: |
702 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
703 | } |
704 | } else { // !op.getSlice() |
705 | if (!boxedSlice) { |
706 | // %0 = fir.embox %arg(%shape) |
707 | // %1 = fir.array_coor %0(%shape) %idx |
708 | // Pull in as: |
709 | // %1 = fir.array_coor %arg(%shape) %idx |
710 | } else { |
711 | // %0 = fir.embox %arg(%shape) [%slice] |
712 | // %1 = fir.array_coor %0(%shape) %idx |
713 | // Pull in as: |
714 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
715 | } |
716 | } |
717 | } else { // !boxedShape |
718 | if (op.getSlice()) { |
719 | if (!boxedSlice) { |
720 | // %0 = fir.rebox %arg |
721 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
722 | // Pull in as: |
723 | // %1 = fir.array_coor %arg(%shape) [%slice] %idx |
724 | } else { |
725 | // %0 = fir.rebox %arg [%slice] |
726 | // %1 = fir.array_coor %0(%shape) [%slice] %idx |
727 | return mlir::failure(); |
728 | } |
729 | } else { // !op.getSlice() |
730 | if (!boxedSlice) { |
731 | // %0 = fir.rebox %arg |
732 | // %1 = fir.array_coor %0(%shape) %idx |
733 | // Pull in as: |
734 | // %1 = fir.array_coor %arg(%shape) %idx |
735 | } else { |
736 | // %0 = fir.rebox %arg [%slice] |
737 | // %1 = fir.array_coor %0(%shape) %idx |
738 | // Cannot pull in without adjusting the slice indices. |
739 | return mlir::failure(); |
740 | } |
741 | } |
742 | } |
743 | } |
744 | |
745 | // TODO: temporarily avoid producing array_coor with the shape shift |
746 | // and plain array reference (it seems to be a limitation of |
747 | // ArrayCoorOp verifier). |
748 | if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType())) { |
749 | if (boxedShape) { |
750 | if (mlir::isa<fir::ShiftType>(boxedShape.getType())) |
751 | return mlir::failure(); |
752 | } else if (op.getShape() && |
753 | mlir::isa<fir::ShiftType>(op.getShape().getType())) { |
754 | return mlir::failure(); |
755 | } |
756 | } |
757 | |
758 | rewriter.modifyOpInPlace(op, [&]() { |
759 | op.getMemrefMutable().assign(boxedMemref); |
760 | if (boxedShape) |
761 | op.getShapeMutable().assign(boxedShape); |
762 | if (boxedSlice) |
763 | op.getSliceMutable().assign(boxedSlice); |
764 | if (shiftedIndices) |
765 | op.getIndicesMutable().assign(*shiftedIndices); |
766 | }); |
767 | return mlir::success(); |
768 | } |
769 | |
770 | private: |
771 | using IndicesVectorTy = std::vector<mlir::Value>; |
772 | |
773 | // If v is a shape_shift operation: |
774 | // fir.shape_shift %l1, %e1, %l2, %e2, ... |
775 | // create: |
776 | // fir.shape %e1, %e2, ... |
777 | static mlir::Value getShapeFromShapeShift(mlir::Value v, |
778 | mlir::PatternRewriter &rewriter) { |
779 | auto shapeShiftOp = |
780 | mlir::dyn_cast_or_null<fir::ShapeShiftOp>(v.getDefiningOp()); |
781 | if (!shapeShiftOp) |
782 | return nullptr; |
783 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
784 | rewriter.setInsertionPoint(shapeShiftOp); |
785 | return rewriter.create<fir::ShapeOp>(shapeShiftOp.getLoc(), |
786 | shapeShiftOp.getExtents()); |
787 | } |
788 | |
789 | static std::optional<IndicesVectorTy> |
790 | getShiftedIndices(mlir::Value v, mlir::ValueRange indices, |
791 | mlir::PatternRewriter &rewriter) { |
792 | auto insertAdjustments = [&](mlir::Operation *op, mlir::ValueRange lbs) { |
793 | // Compute the shifted indices using the extended type. |
794 | // Note that this can probably result in less efficient |
795 | // MLIR and further LLVM IR due to the extra conversions. |
796 | mlir::OpBuilder::InsertPoint savedIP = rewriter.saveInsertionPoint(); |
797 | rewriter.setInsertionPoint(op); |
798 | mlir::Location loc = op->getLoc(); |
799 | mlir::Type idxTy = rewriter.getIndexType(); |
800 | mlir::Value one = rewriter.create<mlir::arith::ConstantOp>( |
801 | loc, idxTy, rewriter.getIndexAttr(1)); |
802 | rewriter.restoreInsertionPoint(ip: savedIP); |
803 | auto nsw = mlir::arith::IntegerOverflowFlags::nsw; |
804 | |
805 | IndicesVectorTy shiftedIndices; |
806 | for (auto [lb, idx] : llvm::zip(t&: lbs, u&: indices)) { |
807 | mlir::Value extLb = rewriter.create<fir::ConvertOp>(loc, idxTy, lb); |
808 | mlir::Value extIdx = rewriter.create<fir::ConvertOp>(loc, idxTy, idx); |
809 | mlir::Value add = |
810 | rewriter.create<mlir::arith::AddIOp>(loc, extIdx, extLb, nsw); |
811 | mlir::Value sub = |
812 | rewriter.create<mlir::arith::SubIOp>(loc, add, one, nsw); |
813 | shiftedIndices.push_back(x: sub); |
814 | } |
815 | |
816 | return shiftedIndices; |
817 | }; |
818 | |
819 | if (auto shiftOp = |
820 | mlir::dyn_cast_or_null<fir::ShiftOp>(v.getDefiningOp())) { |
821 | return insertAdjustments(shiftOp.getOperation(), shiftOp.getOrigins()); |
822 | } else if (auto shapeShiftOp = mlir::dyn_cast_or_null<fir::ShapeShiftOp>( |
823 | v.getDefiningOp())) { |
824 | return insertAdjustments(shapeShiftOp.getOperation(), |
825 | shapeShiftOp.getOrigins()); |
826 | } |
827 | |
828 | return std::nullopt; |
829 | } |
830 | }; |
831 | |
832 | void fir::ArrayCoorOp::getCanonicalizationPatterns( |
833 | mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { |
834 | // TODO: !fir.shape<1> operand may be removed from array_coor always. |
835 | patterns.add<SimplifyArrayCoorOp>(context); |
836 | } |
837 | |
838 | //===----------------------------------------------------------------------===// |
839 | // ArrayLoadOp |
840 | //===----------------------------------------------------------------------===// |
841 | |
842 | static mlir::Type adjustedElementType(mlir::Type t) { |
843 | if (auto ty = mlir::dyn_cast<fir::ReferenceType>(t)) { |
844 | auto eleTy = ty.getEleTy(); |
845 | if (fir::isa_char(eleTy)) |
846 | return eleTy; |
847 | if (fir::isa_derived(eleTy)) |
848 | return eleTy; |
849 | if (mlir::isa<fir::SequenceType>(eleTy)) |
850 | return eleTy; |
851 | } |
852 | return t; |
853 | } |
854 | |
855 | std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() { |
856 | if (auto sh = getShape()) |
857 | if (auto *op = sh.getDefiningOp()) { |
858 | if (auto shOp = mlir::dyn_cast<fir::ShapeOp>(op)) { |
859 | auto extents = shOp.getExtents(); |
860 | return {extents.begin(), extents.end()}; |
861 | } |
862 | return mlir::cast<fir::ShapeShiftOp>(op).getExtents(); |
863 | } |
864 | return {}; |
865 | } |
866 | |
867 | void fir::ArrayLoadOp::getEffects( |
868 | llvm::SmallVectorImpl< |
869 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
870 | &effects) { |
871 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &getMemrefMutable(), |
872 | mlir::SideEffects::DefaultResource::get()); |
873 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
874 | } |
875 | |
876 | llvm::LogicalResult fir::ArrayLoadOp::verify() { |
877 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
878 | auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy); |
879 | if (!arrTy) |
880 | return emitOpError("must be a reference to an array" ); |
881 | auto arrDim = arrTy.getDimension(); |
882 | |
883 | if (auto shapeOp = getShape()) { |
884 | auto shapeTy = shapeOp.getType(); |
885 | unsigned shapeTyRank = 0u; |
886 | if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) { |
887 | shapeTyRank = s.getRank(); |
888 | } else if (auto ss = mlir::dyn_cast<fir::ShapeShiftType>(shapeTy)) { |
889 | shapeTyRank = ss.getRank(); |
890 | } else { |
891 | auto s = mlir::cast<fir::ShiftType>(shapeTy); |
892 | shapeTyRank = s.getRank(); |
893 | if (!mlir::isa<fir::BaseBoxType>(getMemref().getType())) |
894 | return emitOpError("shift can only be provided with fir.box memref" ); |
895 | } |
896 | if (arrDim && arrDim != shapeTyRank) |
897 | return emitOpError("rank of dimension mismatched" ); |
898 | } |
899 | |
900 | if (auto sliceOp = getSlice()) { |
901 | if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) |
902 | if (!sl.getSubstr().empty()) |
903 | return emitOpError("array_load cannot take a slice with substring" ); |
904 | if (auto sliceTy = mlir::dyn_cast<fir::SliceType>(sliceOp.getType())) |
905 | if (sliceTy.getRank() != arrDim) |
906 | return emitOpError("rank of dimension in slice mismatched" ); |
907 | } |
908 | |
909 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
910 | return emitOpError("invalid type parameters" ); |
911 | |
912 | return mlir::success(); |
913 | } |
914 | |
915 | //===----------------------------------------------------------------------===// |
916 | // ArrayMergeStoreOp |
917 | //===----------------------------------------------------------------------===// |
918 | |
919 | llvm::LogicalResult fir::ArrayMergeStoreOp::verify() { |
920 | if (!mlir::isa<fir::ArrayLoadOp>(getOriginal().getDefiningOp())) |
921 | return emitOpError("operand #0 must be result of a fir.array_load op" ); |
922 | if (auto sl = getSlice()) { |
923 | if (auto sliceOp = |
924 | mlir::dyn_cast_or_null<fir::SliceOp>(sl.getDefiningOp())) { |
925 | if (!sliceOp.getSubstr().empty()) |
926 | return emitOpError( |
927 | "array_merge_store cannot take a slice with substring" ); |
928 | if (!sliceOp.getFields().empty()) { |
929 | // This is an intra-object merge, where the slice is projecting the |
930 | // subfields that are to be overwritten by the merge operation. |
931 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
932 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { |
933 | auto projTy = |
934 | fir::applyPathToType(seqTy.getEleTy(), sliceOp.getFields()); |
935 | if (fir::unwrapSequenceType(getOriginal().getType()) != projTy) |
936 | return emitOpError( |
937 | "type of origin does not match sliced memref type" ); |
938 | if (fir::unwrapSequenceType(getSequence().getType()) != projTy) |
939 | return emitOpError( |
940 | "type of sequence does not match sliced memref type" ); |
941 | return mlir::success(); |
942 | } |
943 | return emitOpError("referenced type is not an array" ); |
944 | } |
945 | } |
946 | return mlir::success(); |
947 | } |
948 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
949 | if (getOriginal().getType() != eleTy) |
950 | return emitOpError("type of origin does not match memref element type" ); |
951 | if (getSequence().getType() != eleTy) |
952 | return emitOpError("type of sequence does not match memref element type" ); |
953 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
954 | return emitOpError("invalid type parameters" ); |
955 | return mlir::success(); |
956 | } |
957 | |
958 | void fir::ArrayMergeStoreOp::getEffects( |
959 | llvm::SmallVectorImpl< |
960 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
961 | &effects) { |
962 | effects.emplace_back(mlir::MemoryEffects::Write::get(), &getMemrefMutable(), |
963 | mlir::SideEffects::DefaultResource::get()); |
964 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
965 | } |
966 | |
967 | //===----------------------------------------------------------------------===// |
968 | // ArrayFetchOp |
969 | //===----------------------------------------------------------------------===// |
970 | |
971 | // Template function used for both array_fetch and array_update verification. |
972 | template <typename A> |
973 | mlir::Type validArraySubobject(A op) { |
974 | auto ty = op.getSequence().getType(); |
975 | return fir::applyPathToType(ty, op.getIndices()); |
976 | } |
977 | |
978 | llvm::LogicalResult fir::ArrayFetchOp::verify() { |
979 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
980 | auto indSize = getIndices().size(); |
981 | if (indSize < arrTy.getDimension()) |
982 | return emitOpError("number of indices != dimension of array" ); |
983 | if (indSize == arrTy.getDimension() && |
984 | ::adjustedElementType(getElement().getType()) != arrTy.getEleTy()) |
985 | return emitOpError("return type does not match array" ); |
986 | auto ty = validArraySubobject(*this); |
987 | if (!ty || ty != ::adjustedElementType(getType())) |
988 | return emitOpError("return type and/or indices do not type check" ); |
989 | if (!mlir::isa<fir::ArrayLoadOp>(getSequence().getDefiningOp())) |
990 | return emitOpError("argument #0 must be result of fir.array_load" ); |
991 | if (!validTypeParams(arrTy, getTypeparams())) |
992 | return emitOpError("invalid type parameters" ); |
993 | return mlir::success(); |
994 | } |
995 | |
996 | //===----------------------------------------------------------------------===// |
997 | // ArrayAccessOp |
998 | //===----------------------------------------------------------------------===// |
999 | |
1000 | llvm::LogicalResult fir::ArrayAccessOp::verify() { |
1001 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
1002 | std::size_t indSize = getIndices().size(); |
1003 | if (indSize < arrTy.getDimension()) |
1004 | return emitOpError("number of indices != dimension of array" ); |
1005 | if (indSize == arrTy.getDimension() && |
1006 | getElement().getType() != fir::ReferenceType::get(arrTy.getEleTy())) |
1007 | return emitOpError("return type does not match array" ); |
1008 | mlir::Type ty = validArraySubobject(*this); |
1009 | if (!ty || fir::ReferenceType::get(ty) != getType()) |
1010 | return emitOpError("return type and/or indices do not type check" ); |
1011 | if (!validTypeParams(arrTy, getTypeparams())) |
1012 | return emitOpError("invalid type parameters" ); |
1013 | return mlir::success(); |
1014 | } |
1015 | |
1016 | //===----------------------------------------------------------------------===// |
1017 | // ArrayUpdateOp |
1018 | //===----------------------------------------------------------------------===// |
1019 | |
1020 | llvm::LogicalResult fir::ArrayUpdateOp::verify() { |
1021 | if (fir::isa_ref_type(getMerge().getType())) |
1022 | return emitOpError("does not support reference type for merge" ); |
1023 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
1024 | auto indSize = getIndices().size(); |
1025 | if (indSize < arrTy.getDimension()) |
1026 | return emitOpError("number of indices != dimension of array" ); |
1027 | if (indSize == arrTy.getDimension() && |
1028 | ::adjustedElementType(getMerge().getType()) != arrTy.getEleTy()) |
1029 | return emitOpError("merged value does not have element type" ); |
1030 | auto ty = validArraySubobject(*this); |
1031 | if (!ty || ty != ::adjustedElementType(getMerge().getType())) |
1032 | return emitOpError("merged value and/or indices do not type check" ); |
1033 | if (!validTypeParams(arrTy, getTypeparams())) |
1034 | return emitOpError("invalid type parameters" ); |
1035 | return mlir::success(); |
1036 | } |
1037 | |
1038 | //===----------------------------------------------------------------------===// |
1039 | // ArrayModifyOp |
1040 | //===----------------------------------------------------------------------===// |
1041 | |
1042 | llvm::LogicalResult fir::ArrayModifyOp::verify() { |
1043 | auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); |
1044 | auto indSize = getIndices().size(); |
1045 | if (indSize < arrTy.getDimension()) |
1046 | return emitOpError("number of indices must match array dimension" ); |
1047 | return mlir::success(); |
1048 | } |
1049 | |
1050 | //===----------------------------------------------------------------------===// |
1051 | // BoxAddrOp |
1052 | //===----------------------------------------------------------------------===// |
1053 | |
1054 | void fir::BoxAddrOp::build(mlir::OpBuilder &builder, |
1055 | mlir::OperationState &result, mlir::Value val) { |
1056 | mlir::Type type = |
1057 | llvm::TypeSwitch<mlir::Type, mlir::Type>(val.getType()) |
1058 | .Case<fir::BaseBoxType>([&](fir::BaseBoxType ty) -> mlir::Type { |
1059 | mlir::Type eleTy = ty.getEleTy(); |
1060 | if (fir::isa_ref_type(eleTy)) |
1061 | return eleTy; |
1062 | return fir::ReferenceType::get(eleTy); |
1063 | }) |
1064 | .Case<fir::BoxCharType>([&](fir::BoxCharType ty) -> mlir::Type { |
1065 | return fir::ReferenceType::get(ty.getEleTy()); |
1066 | }) |
1067 | .Case<fir::BoxProcType>( |
1068 | [&](fir::BoxProcType ty) { return ty.getEleTy(); }) |
1069 | .Default([&](const auto &) { return mlir::Type{}; }); |
1070 | assert(type && "bad val type" ); |
1071 | build(builder, result, type, val); |
1072 | } |
1073 | |
1074 | mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) { |
1075 | if (auto *v = getVal().getDefiningOp()) { |
1076 | if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) { |
1077 | // Fold only if not sliced |
1078 | if (!box.getSlice() && box.getMemref().getType() == getType()) { |
1079 | propagateAttributes(getOperation(), box.getMemref().getDefiningOp()); |
1080 | return box.getMemref(); |
1081 | } |
1082 | } |
1083 | if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) |
1084 | if (box.getMemref().getType() == getType()) |
1085 | return box.getMemref(); |
1086 | } |
1087 | return {}; |
1088 | } |
1089 | |
1090 | //===----------------------------------------------------------------------===// |
1091 | // BoxCharLenOp |
1092 | //===----------------------------------------------------------------------===// |
1093 | |
1094 | mlir::OpFoldResult fir::BoxCharLenOp::fold(FoldAdaptor adaptor) { |
1095 | if (auto v = getVal().getDefiningOp()) { |
1096 | if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) |
1097 | return box.getLen(); |
1098 | } |
1099 | return {}; |
1100 | } |
1101 | |
1102 | //===----------------------------------------------------------------------===// |
1103 | // BoxDimsOp |
1104 | //===----------------------------------------------------------------------===// |
1105 | |
1106 | /// Get the result types packed in a tuple tuple |
1107 | mlir::Type fir::BoxDimsOp::getTupleType() { |
1108 | // note: triple, but 4 is nearest power of 2 |
1109 | llvm::SmallVector<mlir::Type> triple{ |
1110 | getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; |
1111 | return mlir::TupleType::get(getContext(), triple); |
1112 | } |
1113 | |
1114 | //===----------------------------------------------------------------------===// |
1115 | // BoxRankOp |
1116 | //===----------------------------------------------------------------------===// |
1117 | |
1118 | void fir::BoxRankOp::getEffects( |
1119 | llvm::SmallVectorImpl< |
1120 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
1121 | &effects) { |
1122 | mlir::OpOperand &inputBox = getBoxMutable(); |
1123 | if (fir::isBoxAddress(inputBox.get().getType())) |
1124 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, |
1125 | mlir::SideEffects::DefaultResource::get()); |
1126 | } |
1127 | |
1128 | //===----------------------------------------------------------------------===// |
1129 | // CallOp |
1130 | //===----------------------------------------------------------------------===// |
1131 | |
1132 | mlir::FunctionType fir::CallOp::getFunctionType() { |
1133 | return mlir::FunctionType::get(getContext(), getOperandTypes(), |
1134 | getResultTypes()); |
1135 | } |
1136 | |
/// Print a fir.call in its custom form:
///   @callee(%args) [proc_attrs<...>] [fastmath<...>] {attrs} : signature
///   %fn(%args)     [proc_attrs<...>] [fastmath<...>] {attrs} : signature
/// The second form is an indirect call: operand #0 is the callee value. That
/// operand is printed in place of the symbol and dropped from both the
/// parenthesized argument list and the printed signature.
void fir::CallOp::print(mlir::OpAsmPrinter &p) {
  bool isDirect = getCallee().has_value();
  p << ' ';
  if (isDirect)
    p << *getCallee();
  else
    p << getOperand(0);
  // Arguments only: skip the callee value of an indirect call.
  p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')';

  // Print `proc_attrs<...>`, if present.
  fir::FortranProcedureFlagsEnumAttr procAttrs = getProcedureAttrsAttr();
  if (procAttrs &&
      procAttrs.getValue() != fir::FortranProcedureFlagsEnum::none) {
    p << ' ' << fir::FortranProcedureFlagsEnumAttr::getMnemonic();
    p.printStrippedAttrOrType(procAttrs);
  }

  // Print 'fastmath<...>' (if it has non-default value) before
  // any other attributes.
  mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr();
  if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) {
    p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic();
    p.printStrippedAttrOrType(fmfAttr);
  }

  // Elide attributes that are printed in custom syntax above or folded into
  // the function signature below.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {fir::CallOp::getCalleeAttrNameStr(),
                           getFastmathAttrName(), getProcedureAttrsAttrName(),
                           getArgAttrsAttrName(), getResAttrsAttrName()});
  p << " : ";
  // The signature lists argument types (without the callee of an indirect
  // call) together with per-argument and per-result attribute dictionaries.
  mlir::call_interface_impl::printFunctionSignature(
      p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
1171 | |
/// Parse the custom form of fir.call:
///   @callee(%args) [proc_attrs<...>] [fastmath<...>] {attrs} : signature
///   %fn(%args)     [proc_attrs<...>] [fastmath<...>] {attrs} : signature
/// In the indirect (second) form, the leading operand is resolved with the
/// parsed function type as its type and becomes operand #0, ahead of the
/// arguments.
mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                                     mlir::OperationState &result) {
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
  // Picks up the callee value of an indirect call. A direct call starts with
  // a symbol reference instead, so the list stays empty.
  if (parser.parseOperandList(operands))
    return mlir::failure();

  mlir::NamedAttrList attrs;
  mlir::SymbolRefAttr funcAttr;
  bool isDirect = operands.empty();
  if (isDirect)
    if (parser.parseAttribute(funcAttr, fir::CallOp::getCalleeAttrNameStr(),
                              attrs))
      return mlir::failure();

  // The parenthesized arguments are appended after any callee operand.
  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
    return mlir::failure();

  // Parse `proc_attrs<...>`, if present.
  fir::FortranProcedureFlagsEnumAttr procAttr;
  if (mlir::succeeded(parser.parseOptionalKeyword(
          fir::FortranProcedureFlagsEnumAttr::getMnemonic())))
    if (parser.parseCustomAttributeWithFallback(
            procAttr, mlir::Type{}, getProcedureAttrsAttrName(result.name),
            attrs))
      return mlir::failure();

  // Parse 'fastmath<...>', if present.
  mlir::arith::FastMathFlagsAttr fmfAttr;
  llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
  if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName)))
    if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{},
                                                fmfAttrName, attrs))
      return mlir::failure();

  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon())
    return mlir::failure();
  // The signature carries argument/result types plus optional per-argument
  // and per-result attribute dictionaries.
  llvm::SmallVector<mlir::Type> argTypes;
  llvm::SmallVector<mlir::Type> resTypes;
  llvm::SmallVector<mlir::DictionaryAttr> argAttrs;
  llvm::SmallVector<mlir::DictionaryAttr> resultAttrs;
  if (mlir::call_interface_impl::parseFunctionSignature(
          parser, argTypes, argAttrs, resTypes, resultAttrs))
    return parser.emitError(parser.getNameLoc(), "expected function type");
  mlir::FunctionType funcType =
      mlir::FunctionType::get(parser.getContext(), argTypes, resTypes);
  if (isDirect) {
    if (parser.resolveOperands(operands, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return mlir::failure();
  } else {
    // Operand #0 is the callee value and has the function type itself. The
    // remaining operands are the arguments.
    auto funcArgs =
        llvm::ArrayRef<mlir::OpAsmParser::UnresolvedOperand>(operands)
            .drop_front();
    if (parser.resolveOperand(operands[0], funcType, result.operands) ||
        parser.resolveOperands(funcArgs, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return mlir::failure();
  }
  result.attributes = attrs;
  mlir::call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, argAttrs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
  result.addTypes(funcType.getResults());
  return mlir::success();
}
1237 | |
1238 | void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
1239 | mlir::func::FuncOp callee, mlir::ValueRange operands) { |
1240 | result.addOperands(operands); |
1241 | result.addAttribute(getCalleeAttrNameStr(), mlir::SymbolRefAttr::get(callee)); |
1242 | result.addTypes(callee.getFunctionType().getResults()); |
1243 | } |
1244 | |
1245 | void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
1246 | mlir::SymbolRefAttr callee, |
1247 | llvm::ArrayRef<mlir::Type> results, |
1248 | mlir::ValueRange operands) { |
1249 | result.addOperands(operands); |
1250 | if (callee) |
1251 | result.addAttribute(getCalleeAttrNameStr(), callee); |
1252 | result.addTypes(results); |
1253 | } |
1254 | |
1255 | //===----------------------------------------------------------------------===// |
1256 | // CharConvertOp |
1257 | //===----------------------------------------------------------------------===// |
1258 | |
1259 | llvm::LogicalResult fir::CharConvertOp::verify() { |
1260 | auto unwrap = [&](mlir::Type t) { |
1261 | t = fir::unwrapSequenceType(fir::dyn_cast_ptrEleTy(t)); |
1262 | return mlir::dyn_cast<fir::CharacterType>(t); |
1263 | }; |
1264 | auto inTy = unwrap(getFrom().getType()); |
1265 | auto outTy = unwrap(getTo().getType()); |
1266 | if (!(inTy && outTy)) |
1267 | return emitOpError("not a reference to a character" ); |
1268 | if (inTy.getFKind() == outTy.getFKind()) |
1269 | return emitOpError("buffers must have different KIND values" ); |
1270 | return mlir::success(); |
1271 | } |
1272 | |
1273 | //===----------------------------------------------------------------------===// |
1274 | // CmpOp |
1275 | //===----------------------------------------------------------------------===// |
1276 | |
1277 | template <typename OPTY> |
1278 | static void printCmpOp(mlir::OpAsmPrinter &p, OPTY op) { |
1279 | p << ' '; |
1280 | auto predSym = mlir::arith::symbolizeCmpFPredicate( |
1281 | op->template getAttrOfType<mlir::IntegerAttr>( |
1282 | OPTY::getPredicateAttrName()) |
1283 | .getInt()); |
1284 | assert(predSym.has_value() && "invalid symbol value for predicate" ); |
1285 | p << '"' << mlir::arith::stringifyCmpFPredicate(predSym.value()) << '"' |
1286 | << ", " ; |
1287 | p.printOperand(op.getLhs()); |
1288 | p << ", " ; |
1289 | p.printOperand(op.getRhs()); |
1290 | p.printOptionalAttrDict(attrs: op->getAttrs(), |
1291 | /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); |
1292 | p << " : " << op.getLhs().getType(); |
1293 | } |
1294 | |
1295 | template <typename OPTY> |
1296 | static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, |
1297 | mlir::OperationState &result) { |
1298 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> ops; |
1299 | mlir::NamedAttrList attrs; |
1300 | mlir::Attribute predicateNameAttr; |
1301 | mlir::Type type; |
1302 | if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), |
1303 | attrs) || |
1304 | parser.parseComma() || parser.parseOperandList(result&: ops, requiredOperandCount: 2) || |
1305 | parser.parseOptionalAttrDict(result&: attrs) || parser.parseColonType(result&: type) || |
1306 | parser.resolveOperands(operands&: ops, type, result&: result.operands)) |
1307 | return mlir::failure(); |
1308 | |
1309 | if (!mlir::isa<mlir::StringAttr>(Val: predicateNameAttr)) |
1310 | return parser.emitError(loc: parser.getNameLoc(), |
1311 | message: "expected string comparison predicate attribute" ); |
1312 | |
1313 | // Rewrite string attribute to an enum value. |
1314 | llvm::StringRef predicateName = |
1315 | mlir::cast<mlir::StringAttr>(predicateNameAttr).getValue(); |
1316 | auto predicate = fir::CmpcOp::getPredicateByName(predicateName); |
1317 | auto builder = parser.getBuilder(); |
1318 | mlir::Type i1Type = builder.getI1Type(); |
1319 | attrs.set(OPTY::getPredicateAttrName(), |
1320 | builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); |
1321 | result.attributes = attrs; |
1322 | result.addTypes(newTypes: {i1Type}); |
1323 | return mlir::success(); |
1324 | } |
1325 | |
1326 | //===----------------------------------------------------------------------===// |
1327 | // CmpcOp |
1328 | //===----------------------------------------------------------------------===// |
1329 | |
1330 | void fir::buildCmpCOp(mlir::OpBuilder &builder, mlir::OperationState &result, |
1331 | mlir::arith::CmpFPredicate predicate, mlir::Value lhs, |
1332 | mlir::Value rhs) { |
1333 | result.addOperands({lhs, rhs}); |
1334 | result.types.push_back(builder.getI1Type()); |
1335 | result.addAttribute( |
1336 | fir::CmpcOp::getPredicateAttrName(), |
1337 | builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); |
1338 | } |
1339 | |
1340 | mlir::arith::CmpFPredicate |
1341 | fir::CmpcOp::getPredicateByName(llvm::StringRef name) { |
1342 | auto pred = mlir::arith::symbolizeCmpFPredicate(name); |
1343 | assert(pred.has_value() && "invalid predicate name" ); |
1344 | return pred.value(); |
1345 | } |
1346 | |
1347 | void fir::CmpcOp::print(mlir::OpAsmPrinter &p) { printCmpOp(p, *this); } |
1348 | |
1349 | mlir::ParseResult fir::CmpcOp::parse(mlir::OpAsmParser &parser, |
1350 | mlir::OperationState &result) { |
1351 | return parseCmpOp<fir::CmpcOp>(parser, result); |
1352 | } |
1353 | |
1354 | //===----------------------------------------------------------------------===// |
1355 | // VolatileCastOp |
1356 | //===----------------------------------------------------------------------===// |
1357 | |
1358 | static bool typesMatchExceptForVolatility(mlir::Type fromType, |
1359 | mlir::Type toType) { |
1360 | // If we can change only the volatility and get identical types, then we |
1361 | // match. |
1362 | if (fir::updateTypeWithVolatility(fromType, fir::isa_volatile_type(toType)) == |
1363 | toType) |
1364 | return true; |
1365 | |
1366 | // Otherwise, recurse on the element types if the base classes are the same. |
1367 | const bool match = |
1368 | llvm::TypeSwitch<mlir::Type, bool>(fromType) |
1369 | .Case<fir::BoxType, fir::ReferenceType, fir::ClassType>( |
1370 | [&](auto type) { |
1371 | using TYPE = decltype(type); |
1372 | // If we are not the same base class, then we don't match. |
1373 | auto castedToType = mlir::dyn_cast<TYPE>(toType); |
1374 | if (!castedToType) |
1375 | return false; |
1376 | // If we are the same base class, we match if the element types |
1377 | // match. |
1378 | return typesMatchExceptForVolatility(type.getEleTy(), |
1379 | castedToType.getEleTy()); |
1380 | }) |
1381 | .Default([](mlir::Type) { return false; }); |
1382 | |
1383 | return match; |
1384 | } |
1385 | |
1386 | llvm::LogicalResult fir::VolatileCastOp::verify() { |
1387 | mlir::Type fromType = getValue().getType(); |
1388 | mlir::Type toType = getType(); |
1389 | if (!typesMatchExceptForVolatility(fromType, toType)) |
1390 | return emitOpError("types must be identical except for volatility " ) |
1391 | << fromType << " / " << toType; |
1392 | return mlir::success(); |
1393 | } |
1394 | |
1395 | mlir::OpFoldResult fir::VolatileCastOp::fold(FoldAdaptor adaptor) { |
1396 | if (getValue().getType() == getType()) |
1397 | return getValue(); |
1398 | return {}; |
1399 | } |
1400 | |
1401 | //===----------------------------------------------------------------------===// |
1402 | // ConvertOp |
1403 | //===----------------------------------------------------------------------===// |
1404 | |
1405 | void fir::ConvertOp::getCanonicalizationPatterns( |
1406 | mlir::RewritePatternSet &results, mlir::MLIRContext *context) { |
1407 | results.insert<ConvertConvertOptPattern, ConvertAscendingIndexOptPattern, |
1408 | ConvertDescendingIndexOptPattern, RedundantConvertOptPattern, |
1409 | CombineConvertOptPattern, CombineConvertTruncOptPattern, |
1410 | ForwardConstantConvertPattern, ChainedPointerConvertsPattern>( |
1411 | context); |
1412 | } |
1413 | |
1414 | mlir::OpFoldResult fir::ConvertOp::fold(FoldAdaptor adaptor) { |
1415 | if (getValue().getType() == getType()) |
1416 | return getValue(); |
1417 | if (matchPattern(getValue(), mlir::m_Op<fir::ConvertOp>())) { |
1418 | auto inner = mlir::cast<fir::ConvertOp>(getValue().getDefiningOp()); |
1419 | // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a |
1420 | if (auto toTy = mlir::dyn_cast<fir::LogicalType>(getType())) |
1421 | if (auto fromTy = |
1422 | mlir::dyn_cast<fir::LogicalType>(inner.getValue().getType())) |
1423 | if (mlir::isa<mlir::IntegerType>(inner.getType()) && (toTy == fromTy)) |
1424 | return inner.getValue(); |
1425 | // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a |
1426 | if (auto toTy = mlir::dyn_cast<mlir::IntegerType>(getType())) |
1427 | if (auto fromTy = |
1428 | mlir::dyn_cast<mlir::IntegerType>(inner.getValue().getType())) |
1429 | if (mlir::isa<fir::LogicalType>(inner.getType()) && (toTy == fromTy) && |
1430 | (fromTy.getWidth() == 1)) |
1431 | return inner.getValue(); |
1432 | } |
1433 | return {}; |
1434 | } |
1435 | |
1436 | bool fir::ConvertOp::isInteger(mlir::Type ty) { |
1437 | return mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType>(ty); |
1438 | } |
1439 | |
1440 | bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { |
1441 | return isInteger(ty) || mlir::isa<fir::LogicalType>(ty); |
1442 | } |
1443 | |
1444 | bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { |
1445 | return mlir::isa<mlir::FloatType>(ty); |
1446 | } |
1447 | |
1448 | bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { |
1449 | return mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType, |
1450 | fir::LLVMPointerType, mlir::MemRefType, mlir::FunctionType, |
1451 | fir::TypeDescType, mlir::LLVM::LLVMPointerType>(ty); |
1452 | } |
1453 | |
1454 | static std::optional<mlir::Type> getVectorElementType(mlir::Type ty) { |
1455 | mlir::Type elemTy; |
1456 | if (mlir::isa<fir::VectorType>(ty)) |
1457 | elemTy = mlir::dyn_cast<fir::VectorType>(ty).getElementType(); |
1458 | else if (mlir::isa<mlir::VectorType>(Val: ty)) |
1459 | elemTy = mlir::dyn_cast<mlir::VectorType>(ty).getElementType(); |
1460 | else |
1461 | return std::nullopt; |
1462 | |
1463 | // e.g. fir.vector<4:ui32> => mlir.vector<4xi32> |
1464 | // e.g. mlir.vector<4xui32> => mlir.vector<4xi32> |
1465 | if (elemTy.isUnsignedInteger()) { |
1466 | elemTy = mlir::IntegerType::get( |
1467 | ty.getContext(), mlir::dyn_cast<mlir::IntegerType>(elemTy).getWidth()); |
1468 | } |
1469 | return elemTy; |
1470 | } |
1471 | |
1472 | static std::optional<uint64_t> getVectorLen(mlir::Type ty) { |
1473 | if (mlir::isa<fir::VectorType>(ty)) |
1474 | return mlir::dyn_cast<fir::VectorType>(ty).getLen(); |
1475 | else if (mlir::isa<mlir::VectorType>(Val: ty)) { |
1476 | // fir.vector only supports 1-D vector |
1477 | if (!(mlir::dyn_cast<mlir::VectorType>(ty).isScalable())) |
1478 | return mlir::dyn_cast<mlir::VectorType>(ty).getShape()[0]; |
1479 | } |
1480 | |
1481 | return std::nullopt; |
1482 | } |
1483 | |
1484 | bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) { |
1485 | if (!(mlir::isa<fir::VectorType>(inTy) && |
1486 | mlir::isa<mlir::VectorType>(outTy)) && |
1487 | !(mlir::isa<mlir::VectorType>(inTy) && mlir::isa<fir::VectorType>(outTy))) |
1488 | return false; |
1489 | |
1490 | // Only support integer, unsigned and real vector |
1491 | // Both vectors must have the same element type |
1492 | std::optional<mlir::Type> inElemTy = getVectorElementType(inTy); |
1493 | std::optional<mlir::Type> outElemTy = getVectorElementType(outTy); |
1494 | if (!inElemTy.has_value() || !outElemTy.has_value() || |
1495 | inElemTy.value() != outElemTy.value()) |
1496 | return false; |
1497 | |
1498 | // Both vectors must have the same number of elements |
1499 | std::optional<uint64_t> inLen = getVectorLen(inTy); |
1500 | std::optional<uint64_t> outLen = getVectorLen(outTy); |
1501 | if (!inLen.has_value() || !outLen.has_value() || |
1502 | inLen.value() != outLen.value()) |
1503 | return false; |
1504 | |
1505 | return true; |
1506 | } |
1507 | |
1508 | static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) { |
1509 | // Both records must have the same field types. |
1510 | // Trust frontend semantics for in-depth checks, such as if both records |
1511 | // have the BIND(C) attribute. |
1512 | auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy); |
1513 | auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy); |
1514 | return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList(); |
1515 | } |
1516 | |
1517 | bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) { |
1518 | if (inType == outType) |
1519 | return true; |
1520 | return (isPointerCompatible(inType) && isPointerCompatible(outType)) || |
1521 | (isIntegerCompatible(inType) && isIntegerCompatible(outType)) || |
1522 | (isInteger(inType) && isFloatCompatible(outType)) || |
1523 | (isFloatCompatible(inType) && isInteger(outType)) || |
1524 | (isFloatCompatible(inType) && isFloatCompatible(outType)) || |
1525 | (isIntegerCompatible(inType) && isPointerCompatible(outType)) || |
1526 | (isPointerCompatible(inType) && isIntegerCompatible(outType)) || |
1527 | (mlir::isa<fir::BoxType>(inType) && |
1528 | mlir::isa<fir::BoxType>(outType)) || |
1529 | (mlir::isa<fir::BoxProcType>(inType) && |
1530 | mlir::isa<fir::BoxProcType>(outType)) || |
1531 | (fir::isa_complex(inType) && fir::isa_complex(outType)) || |
1532 | (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) || |
1533 | (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) || |
1534 | (fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) || |
1535 | areVectorsCompatible(inType, outType) || |
1536 | areRecordsCompatible(inType, outType); |
1537 | } |
1538 | |
1539 | // In general, ptrtoint-like conversions are allowed to lose volatility |
1540 | // information because they are either: |
1541 | // |
1542 | // 1. passing an entity to an external function and there's nothing we can do |
1543 | // about volatility after that happens, or |
1544 | // 2. for code generation, at which point we represent volatility with |
1545 | // attributes on the LLVM instructions and intrinsics. |
1546 | // |
1547 | // For all other cases, volatility ought to match exactly. |
1548 | static mlir::LogicalResult verifyVolatility(mlir::Type inType, |
1549 | mlir::Type outType) { |
1550 | const bool toLLVMPointer = mlir::isa<mlir::LLVM::LLVMPointerType>(outType); |
1551 | const bool toInteger = fir::isa_integer(outType); |
1552 | |
1553 | // When converting references to classes or allocatables into boxes for |
1554 | // runtime arguments, we cast away all the volatility information and pass a |
1555 | // box<none>. This is allowed. |
1556 | const bool isBoxNoneLike = [&]() { |
1557 | if (fir::isBoxNone(outType)) |
1558 | return true; |
1559 | if (auto referenceType = mlir::dyn_cast<fir::ReferenceType>(outType)) { |
1560 | if (fir::isBoxNone(referenceType.getElementType())) { |
1561 | return true; |
1562 | } |
1563 | } |
1564 | return false; |
1565 | }(); |
1566 | |
1567 | const bool isPtrToIntLike = toLLVMPointer || toInteger || isBoxNoneLike; |
1568 | if (isPtrToIntLike) { |
1569 | return mlir::success(); |
1570 | } |
1571 | |
1572 | // In all other cases, we need to check for an exact volatility match. |
1573 | return mlir::success(fir::isa_volatile_type(inType) == |
1574 | fir::isa_volatile_type(outType)); |
1575 | } |
1576 | |
1577 | llvm::LogicalResult fir::ConvertOp::verify() { |
1578 | mlir::Type inType = getValue().getType(); |
1579 | mlir::Type outType = getType(); |
1580 | if (fir::useStrictVolatileVerification()) { |
1581 | if (failed(verifyVolatility(inType, outType))) { |
1582 | return emitOpError("this conversion does not preserve volatility: " ) |
1583 | << inType << " / " << outType; |
1584 | } |
1585 | } |
1586 | if (canBeConverted(inType, outType)) |
1587 | return mlir::success(); |
1588 | return emitOpError("invalid type conversion" ) |
1589 | << getValue().getType() << " / " << getType(); |
1590 | } |
1591 | |
1592 | //===----------------------------------------------------------------------===// |
1593 | // CoordinateOp |
1594 | //===----------------------------------------------------------------------===// |
1595 | |
1596 | void fir::CoordinateOp::build(mlir::OpBuilder &builder, |
1597 | mlir::OperationState &result, |
1598 | mlir::Type resultType, mlir::Value ref, |
1599 | mlir::ValueRange coor) { |
1600 | llvm::SmallVector<int32_t> fieldIndices; |
1601 | llvm::SmallVector<mlir::Value> dynamicIndices; |
1602 | bool anyField = false; |
1603 | for (mlir::Value index : coor) { |
1604 | if (auto field = index.getDefiningOp<fir::FieldIndexOp>()) { |
1605 | auto recTy = mlir::cast<fir::RecordType>(field.getOnType()); |
1606 | fieldIndices.push_back(recTy.getFieldIndex(field.getFieldId())); |
1607 | anyField = true; |
1608 | } else { |
1609 | fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); |
1610 | dynamicIndices.push_back(index); |
1611 | } |
1612 | } |
1613 | auto typeAttr = mlir::TypeAttr::get(ref.getType()); |
1614 | if (anyField) { |
1615 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, |
1616 | builder.getDenseI32ArrayAttr(fieldIndices)); |
1617 | } else { |
1618 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, nullptr); |
1619 | } |
1620 | } |
1621 | |
1622 | void fir::CoordinateOp::build(mlir::OpBuilder &builder, |
1623 | mlir::OperationState &result, |
1624 | mlir::Type resultType, mlir::Value ref, |
1625 | llvm::ArrayRef<fir::IntOrValue> coor) { |
1626 | llvm::SmallVector<int32_t> fieldIndices; |
1627 | llvm::SmallVector<mlir::Value> dynamicIndices; |
1628 | bool anyField = false; |
1629 | for (fir::IntOrValue index : coor) { |
1630 | llvm::TypeSwitch<fir::IntOrValue>(index) |
1631 | .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) { |
1632 | fieldIndices.push_back(intAttr.getInt()); |
1633 | anyField = true; |
1634 | }) |
1635 | .Case<mlir::Value>([&](mlir::Value value) { |
1636 | dynamicIndices.push_back(value); |
1637 | fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); |
1638 | }); |
1639 | } |
1640 | auto typeAttr = mlir::TypeAttr::get(ref.getType()); |
1641 | if (anyField) { |
1642 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, |
1643 | builder.getDenseI32ArrayAttr(fieldIndices)); |
1644 | } else { |
1645 | build(builder, result, resultType, ref, dynamicIndices, typeAttr, nullptr); |
1646 | } |
1647 | } |
1648 | |
1649 | void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) { |
1650 | p << ' ' << getRef(); |
1651 | if (!getFieldIndicesAttr()) { |
1652 | p << ", " << getCoor(); |
1653 | } else { |
1654 | mlir::Type eleTy = fir::getFortranElementType(getRef().getType()); |
1655 | for (auto index : getIndices()) { |
1656 | p << ", " ; |
1657 | llvm::TypeSwitch<fir::IntOrValue>(index) |
1658 | .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) { |
1659 | if (auto recordType = llvm::dyn_cast<fir::RecordType>(eleTy)) { |
1660 | int fieldId = intAttr.getInt(); |
1661 | if (fieldId < static_cast<int>(recordType.getNumFields())) { |
1662 | auto nameAndType = recordType.getTypeList()[fieldId]; |
1663 | p << std::get<std::string>(nameAndType); |
1664 | eleTy = fir::getFortranElementType( |
1665 | std::get<mlir::Type>(nameAndType)); |
1666 | return; |
1667 | } |
1668 | } |
1669 | // Invalid index, still print it so that invalid IR can be |
1670 | // investigated. |
1671 | p << intAttr; |
1672 | }) |
1673 | .Case<mlir::Value>([&](mlir::Value value) { p << value; }); |
1674 | } |
1675 | } |
1676 | p.printOptionalAttrDict( |
1677 | (*this)->getAttrs(), |
1678 | /*elideAttrs=*/{getBaseTypeAttrName(), getFieldIndicesAttrName()}); |
1679 | p << " : " ; |
1680 | p.printFunctionalType(getOperandTypes(), (*this)->getResultTypes()); |
1681 | } |
1682 | |
1683 | mlir::ParseResult fir::CoordinateOp::parse(mlir::OpAsmParser &parser, |
1684 | mlir::OperationState &result) { |
1685 | mlir::OpAsmParser::UnresolvedOperand memref; |
1686 | if (parser.parseOperand(memref) || parser.parseComma()) |
1687 | return mlir::failure(); |
1688 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> coorOperands; |
1689 | llvm::SmallVector<std::pair<llvm::StringRef, int>> fieldNames; |
1690 | llvm::SmallVector<int32_t> fieldIndices; |
1691 | while (true) { |
1692 | llvm::StringRef fieldName; |
1693 | if (mlir::succeeded(parser.parseOptionalKeyword(&fieldName))) { |
1694 | fieldNames.push_back({fieldName, static_cast<int>(fieldIndices.size())}); |
1695 | // Actual value will be computed later when base type has been parsed. |
1696 | fieldIndices.push_back(0); |
1697 | } else { |
1698 | mlir::OpAsmParser::UnresolvedOperand index; |
1699 | if (parser.parseOperand(index)) |
1700 | return mlir::failure(); |
1701 | fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); |
1702 | coorOperands.push_back(index); |
1703 | } |
1704 | if (mlir::failed(parser.parseOptionalComma())) |
1705 | break; |
1706 | } |
1707 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> allOperands; |
1708 | allOperands.push_back(memref); |
1709 | allOperands.append(coorOperands.begin(), coorOperands.end()); |
1710 | mlir::FunctionType funcTy; |
1711 | auto loc = parser.getCurrentLocation(); |
1712 | if (parser.parseOptionalAttrDict(result.attributes) || |
1713 | parser.parseColonType(funcTy) || |
1714 | parser.resolveOperands(allOperands, funcTy.getInputs(), loc, |
1715 | result.operands) || |
1716 | parser.addTypesToList(funcTy.getResults(), result.types)) |
1717 | return mlir::failure(); |
1718 | result.addAttribute(getBaseTypeAttrName(result.name), |
1719 | mlir::TypeAttr::get(funcTy.getInput(0))); |
1720 | if (!fieldNames.empty()) { |
1721 | mlir::Type eleTy = fir::getFortranElementType(funcTy.getInput(0)); |
1722 | for (auto [fieldName, operandPosition] : fieldNames) { |
1723 | auto recTy = llvm::dyn_cast<fir::RecordType>(eleTy); |
1724 | if (!recTy) |
1725 | return parser.emitError( |
1726 | loc, "base must be a derived type when field name appears" ); |
1727 | unsigned fieldNum = recTy.getFieldIndex(fieldName); |
1728 | if (fieldNum > recTy.getNumFields()) |
1729 | return parser.emitError(loc) |
1730 | << "field '" << fieldName |
1731 | << "' is not a component or subcomponent of the base type" ; |
1732 | fieldIndices[operandPosition] = fieldNum; |
1733 | eleTy = fir::getFortranElementType( |
1734 | std::get<mlir::Type>(recTy.getTypeList()[fieldNum])); |
1735 | } |
1736 | result.addAttribute(getFieldIndicesAttrName(result.name), |
1737 | parser.getBuilder().getDenseI32ArrayAttr(fieldIndices)); |
1738 | } |
1739 | return mlir::success(); |
1740 | } |
1741 | |
1742 | llvm::LogicalResult fir::CoordinateOp::verify() { |
1743 | const mlir::Type refTy = getRef().getType(); |
1744 | if (fir::isa_ref_type(refTy)) { |
1745 | auto eleTy = fir::dyn_cast_ptrEleTy(refTy); |
1746 | if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { |
1747 | if (arrTy.hasUnknownShape()) |
1748 | return emitOpError("cannot find coordinate in unknown shape" ); |
1749 | if (arrTy.getConstantRows() < arrTy.getDimension() - 1) |
1750 | return emitOpError("cannot find coordinate with unknown extents" ); |
1751 | } |
1752 | if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || |
1753 | fir::isa_char_string(eleTy))) |
1754 | return emitOpError("cannot apply to this element type" ); |
1755 | } |
1756 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(refTy); |
1757 | unsigned dimension = 0; |
1758 | const unsigned numCoors = getCoor().size(); |
1759 | for (auto coorOperand : llvm::enumerate(getCoor())) { |
1760 | auto co = coorOperand.value(); |
1761 | if (dimension == 0 && mlir::isa<fir::SequenceType>(eleTy)) { |
1762 | dimension = mlir::cast<fir::SequenceType>(eleTy).getDimension(); |
1763 | if (dimension == 0) |
1764 | return emitOpError("cannot apply to array of unknown rank" ); |
1765 | } |
1766 | if (auto *defOp = co.getDefiningOp()) { |
1767 | if (auto index = mlir::dyn_cast<fir::LenParamIndexOp>(defOp)) { |
1768 | // Recovering a LEN type parameter only makes sense from a boxed |
1769 | // value. For a bare reference, the LEN type parameters must be |
1770 | // passed as additional arguments to `index`. |
1771 | if (mlir::isa<fir::BoxType>(refTy)) { |
1772 | if (coorOperand.index() != numCoors - 1) |
1773 | return emitOpError("len_param_index must be last argument" ); |
1774 | if (getNumOperands() != 2) |
1775 | return emitOpError("too many operands for len_param_index case" ); |
1776 | } |
1777 | if (eleTy != index.getOnType()) |
1778 | emitOpError( |
1779 | "len_param_index type not compatible with reference type" ); |
1780 | return mlir::success(); |
1781 | } else if (auto index = mlir::dyn_cast<fir::FieldIndexOp>(defOp)) { |
1782 | if (eleTy != index.getOnType()) |
1783 | emitOpError("field_index type not compatible with reference type" ); |
1784 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
1785 | eleTy = recTy.getType(index.getFieldName()); |
1786 | continue; |
1787 | } |
1788 | return emitOpError("field_index not applied to !fir.type" ); |
1789 | } |
1790 | } |
1791 | if (dimension) { |
1792 | if (--dimension == 0) |
1793 | eleTy = mlir::cast<fir::SequenceType>(eleTy).getElementType(); |
1794 | } else { |
1795 | if (auto t = mlir::dyn_cast<mlir::TupleType>(eleTy)) { |
1796 | // FIXME: Generally, we don't know which field of the tuple is being |
1797 | // referred to unless the operand is a constant. Just assume everything |
1798 | // is good in the tuple case for now. |
1799 | return mlir::success(); |
1800 | } else if (auto t = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
1801 | // FIXME: This is the same as the tuple case. |
1802 | return mlir::success(); |
1803 | } else if (auto t = mlir::dyn_cast<mlir::ComplexType>(eleTy)) { |
1804 | eleTy = t.getElementType(); |
1805 | } else if (auto t = mlir::dyn_cast<fir::CharacterType>(eleTy)) { |
1806 | if (t.getLen() == fir::CharacterType::singleton()) |
1807 | return emitOpError("cannot apply to character singleton" ); |
1808 | eleTy = fir::CharacterType::getSingleton(t.getContext(), t.getFKind()); |
1809 | if (fir::unwrapRefType(getType()) != eleTy) |
1810 | return emitOpError("character type mismatch" ); |
1811 | } else { |
1812 | return emitOpError("invalid parameters (too many)" ); |
1813 | } |
1814 | } |
1815 | } |
1816 | return mlir::success(); |
1817 | } |
1818 | |
1819 | fir::CoordinateIndicesAdaptor fir::CoordinateOp::getIndices() { |
1820 | return CoordinateIndicesAdaptor(getFieldIndicesAttr(), getCoor()); |
1821 | } |
1822 | |
1823 | //===----------------------------------------------------------------------===// |
1824 | // DispatchOp |
1825 | //===----------------------------------------------------------------------===// |
1826 | |
1827 | llvm::LogicalResult fir::DispatchOp::verify() { |
1828 | // Check that pass_arg_pos is in range of actual operands. pass_arg_pos is |
1829 | // unsigned so check for less than zero is not needed. |
1830 | if (getPassArgPos() && *getPassArgPos() > (getArgOperands().size() - 1)) |
1831 | return emitOpError( |
1832 | "pass_arg_pos must be smaller than the number of operands" ); |
1833 | |
1834 | // Operand pointed by pass_arg_pos must have polymorphic type. |
1835 | if (getPassArgPos() && |
1836 | !fir::isPolymorphicType(getArgOperands()[*getPassArgPos()].getType())) |
1837 | return emitOpError("pass_arg_pos must be a polymorphic operand" ); |
1838 | return mlir::success(); |
1839 | } |
1840 | |
1841 | mlir::FunctionType fir::DispatchOp::getFunctionType() { |
1842 | return mlir::FunctionType::get(getContext(), getOperandTypes(), |
1843 | getResultTypes()); |
1844 | } |
1845 | |
1846 | //===----------------------------------------------------------------------===// |
1847 | // TypeInfoOp |
1848 | //===----------------------------------------------------------------------===// |
1849 | |
1850 | void fir::TypeInfoOp::build(mlir::OpBuilder &builder, |
1851 | mlir::OperationState &result, fir::RecordType type, |
1852 | fir::RecordType parentType, |
1853 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1854 | result.addRegion(); |
1855 | result.addRegion(); |
1856 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
1857 | builder.getStringAttr(type.getName())); |
1858 | result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); |
1859 | if (parentType) |
1860 | result.addAttribute(getParentTypeAttrName(result.name), |
1861 | mlir::TypeAttr::get(parentType)); |
1862 | result.addAttributes(attrs); |
1863 | } |
1864 | |
1865 | llvm::LogicalResult fir::TypeInfoOp::verify() { |
1866 | if (!getDispatchTable().empty()) |
1867 | for (auto &op : getDispatchTable().front().without_terminator()) |
1868 | if (!mlir::isa<fir::DTEntryOp>(op)) |
1869 | return op.emitOpError("dispatch table must contain dt_entry" ); |
1870 | |
1871 | if (!mlir::isa<fir::RecordType>(getType())) |
1872 | return emitOpError("type must be a fir.type" ); |
1873 | |
1874 | if (getParentType() && !mlir::isa<fir::RecordType>(*getParentType())) |
1875 | return emitOpError("parent_type must be a fir.type" ); |
1876 | return mlir::success(); |
1877 | } |
1878 | |
1879 | //===----------------------------------------------------------------------===// |
1880 | // EmboxOp |
1881 | //===----------------------------------------------------------------------===// |
1882 | |
1883 | // Conversions from reference types to box types must preserve volatility. |
1884 | static llvm::LogicalResult |
1885 | verifyEmboxOpVolatilityInvariants(mlir::Type memrefType, |
1886 | mlir::Type resultType) { |
1887 | |
1888 | if (!fir::useStrictVolatileVerification()) |
1889 | return mlir::success(); |
1890 | |
1891 | mlir::Type boxElementType = |
1892 | llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) |
1893 | .Case<fir::BoxType, fir::ClassType>( |
1894 | [&](auto type) { return type.getEleTy(); }) |
1895 | .Default([&](mlir::Type type) { return type; }); |
1896 | |
1897 | // If the embox is simply wrapping a non-volatile type into a volatile box, |
1898 | // we're not losing any volatility information. |
1899 | if (boxElementType == memrefType) { |
1900 | return mlir::success(); |
1901 | } |
1902 | |
1903 | // Otherwise, the volatility of the input and result must match. |
1904 | const bool volatilityMatches = |
1905 | fir::isa_volatile_type(memrefType) == fir::isa_volatile_type(resultType); |
1906 | |
1907 | return mlir::success(IsSuccess: volatilityMatches); |
1908 | } |
1909 | |
1910 | llvm::LogicalResult fir::EmboxOp::verify() { |
1911 | auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); |
1912 | bool isArray = false; |
1913 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { |
1914 | eleTy = seqTy.getEleTy(); |
1915 | isArray = true; |
1916 | } |
1917 | if (hasLenParams()) { |
1918 | auto lenPs = numLenParams(); |
1919 | if (auto rt = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
1920 | if (lenPs != rt.getNumLenParams()) |
1921 | return emitOpError("number of LEN params does not correspond" |
1922 | " to the !fir.type type" ); |
1923 | } else if (auto strTy = mlir::dyn_cast<fir::CharacterType>(eleTy)) { |
1924 | if (strTy.getLen() != fir::CharacterType::unknownLen()) |
1925 | return emitOpError("CHARACTER already has static LEN" ); |
1926 | } else { |
1927 | return emitOpError("LEN parameters require CHARACTER or derived type" ); |
1928 | } |
1929 | for (auto lp : getTypeparams()) |
1930 | if (!fir::isa_integer(lp.getType())) |
1931 | return emitOpError("LEN parameters must be integral type" ); |
1932 | } |
1933 | if (getShape() && !isArray) |
1934 | return emitOpError("shape must not be provided for a scalar" ); |
1935 | if (getSlice() && !isArray) |
1936 | return emitOpError("slice must not be provided for a scalar" ); |
1937 | if (getSourceBox() && !mlir::isa<fir::ClassType>(getResult().getType())) |
1938 | return emitOpError("source_box must be used with fir.class result type" ); |
1939 | if (failed(verifyEmboxOpVolatilityInvariants(getMemref().getType(), |
1940 | getResult().getType()))) |
1941 | return emitOpError( |
1942 | "cannot convert between volatile and non-volatile types:" ) |
1943 | << " " << getMemref().getType() << " " << getResult().getType(); |
1944 | return mlir::success(); |
1945 | } |
1946 | |
1947 | //===----------------------------------------------------------------------===// |
1948 | // EmboxCharOp |
1949 | //===----------------------------------------------------------------------===// |
1950 | |
1951 | llvm::LogicalResult fir::EmboxCharOp::verify() { |
1952 | auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); |
1953 | if (!mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)) |
1954 | return mlir::failure(); |
1955 | return mlir::success(); |
1956 | } |
1957 | |
1958 | //===----------------------------------------------------------------------===// |
1959 | // EmboxProcOp |
1960 | //===----------------------------------------------------------------------===// |
1961 | |
1962 | llvm::LogicalResult fir::EmboxProcOp::verify() { |
1963 | // host bindings (optional) must be a reference to a tuple |
1964 | if (auto h = getHost()) { |
1965 | if (auto r = mlir::dyn_cast<fir::ReferenceType>(h.getType())) |
1966 | if (mlir::isa<mlir::TupleType>(r.getEleTy())) |
1967 | return mlir::success(); |
1968 | return mlir::failure(); |
1969 | } |
1970 | return mlir::success(); |
1971 | } |
1972 | |
1973 | //===----------------------------------------------------------------------===// |
1974 | // TypeDescOp |
1975 | //===----------------------------------------------------------------------===// |
1976 | |
1977 | void fir::TypeDescOp::build(mlir::OpBuilder &, mlir::OperationState &result, |
1978 | mlir::TypeAttr inty) { |
1979 | result.addAttribute("in_type" , inty); |
1980 | result.addTypes(TypeDescType::get(inty.getValue())); |
1981 | } |
1982 | |
1983 | mlir::ParseResult fir::TypeDescOp::parse(mlir::OpAsmParser &parser, |
1984 | mlir::OperationState &result) { |
1985 | mlir::Type intype; |
1986 | if (parser.parseType(intype)) |
1987 | return mlir::failure(); |
1988 | result.addAttribute("in_type" , mlir::TypeAttr::get(intype)); |
1989 | mlir::Type restype = fir::TypeDescType::get(intype); |
1990 | if (parser.addTypeToList(restype, result.types)) |
1991 | return mlir::failure(); |
1992 | return mlir::success(); |
1993 | } |
1994 | |
1995 | void fir::TypeDescOp::print(mlir::OpAsmPrinter &p) { |
1996 | p << ' ' << getOperation()->getAttr("in_type" ); |
1997 | p.printOptionalAttrDict(getOperation()->getAttrs(), {"in_type" }); |
1998 | } |
1999 | |
2000 | llvm::LogicalResult fir::TypeDescOp::verify() { |
2001 | mlir::Type resultTy = getType(); |
2002 | if (auto tdesc = mlir::dyn_cast<fir::TypeDescType>(resultTy)) { |
2003 | if (tdesc.getOfTy() != getInType()) |
2004 | return emitOpError("wrapped type mismatched" ); |
2005 | return mlir::success(); |
2006 | } |
2007 | return emitOpError("must be !fir.tdesc type" ); |
2008 | } |
2009 | |
2010 | //===----------------------------------------------------------------------===// |
2011 | // GlobalOp |
2012 | //===----------------------------------------------------------------------===// |
2013 | |
2014 | mlir::Type fir::GlobalOp::resultType() { |
2015 | return wrapAllocaResultType(getType()); |
2016 | } |
2017 | |
2018 | mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser, |
2019 | mlir::OperationState &result) { |
2020 | // Parse the optional linkage |
2021 | llvm::StringRef linkage; |
2022 | auto &builder = parser.getBuilder(); |
2023 | if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { |
2024 | if (fir::GlobalOp::verifyValidLinkage(linkage)) |
2025 | return mlir::failure(); |
2026 | mlir::StringAttr linkAttr = builder.getStringAttr(linkage); |
2027 | result.addAttribute(fir::GlobalOp::getLinkNameAttrName(result.name), |
2028 | linkAttr); |
2029 | } |
2030 | |
2031 | // Parse the name as a symbol reference attribute. |
2032 | mlir::SymbolRefAttr nameAttr; |
2033 | if (parser.parseAttribute(nameAttr, |
2034 | fir::GlobalOp::getSymrefAttrName(result.name), |
2035 | result.attributes)) |
2036 | return mlir::failure(); |
2037 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
2038 | nameAttr.getRootReference()); |
2039 | |
2040 | bool simpleInitializer = false; |
2041 | if (mlir::succeeded(parser.parseOptionalLParen())) { |
2042 | mlir::Attribute attr; |
2043 | if (parser.parseAttribute(attr, getInitValAttrName(result.name), |
2044 | result.attributes) || |
2045 | parser.parseRParen()) |
2046 | return mlir::failure(); |
2047 | simpleInitializer = true; |
2048 | } |
2049 | |
2050 | if (parser.parseOptionalAttrDict(result.attributes)) |
2051 | return mlir::failure(); |
2052 | |
2053 | if (succeeded( |
2054 | parser.parseOptionalKeyword(getConstantAttrName(result.name)))) { |
2055 | // if "constant" keyword then mark this as a constant, not a variable |
2056 | result.addAttribute(getConstantAttrName(result.name), |
2057 | builder.getUnitAttr()); |
2058 | } |
2059 | |
2060 | if (succeeded(parser.parseOptionalKeyword(getTargetAttrName(result.name)))) |
2061 | result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); |
2062 | |
2063 | mlir::Type globalType; |
2064 | if (parser.parseColonType(globalType)) |
2065 | return mlir::failure(); |
2066 | |
2067 | result.addAttribute(fir::GlobalOp::getTypeAttrName(result.name), |
2068 | mlir::TypeAttr::get(globalType)); |
2069 | |
2070 | if (simpleInitializer) { |
2071 | result.addRegion(); |
2072 | } else { |
2073 | // Parse the optional initializer body. |
2074 | auto parseResult = |
2075 | parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{}); |
2076 | if (parseResult.has_value() && mlir::failed(*parseResult)) |
2077 | return mlir::failure(); |
2078 | } |
2079 | return mlir::success(); |
2080 | } |
2081 | |
2082 | void fir::GlobalOp::print(mlir::OpAsmPrinter &p) { |
2083 | if (getLinkName()) |
2084 | p << ' ' << *getLinkName(); |
2085 | p << ' '; |
2086 | p.printAttributeWithoutType(getSymrefAttr()); |
2087 | if (auto val = getValueOrNull()) |
2088 | p << '(' << val << ')'; |
2089 | // Print all other attributes that are not pretty printed here. |
2090 | p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{ |
2091 | getSymNameAttrName(), getSymrefAttrName(), |
2092 | getTypeAttrName(), getConstantAttrName(), |
2093 | getTargetAttrName(), getLinkNameAttrName(), |
2094 | getInitValAttrName()}); |
2095 | if (getOperation()->getAttr(getConstantAttrName())) |
2096 | p << " " << getConstantAttrName().strref(); |
2097 | if (getOperation()->getAttr(getTargetAttrName())) |
2098 | p << " " << getTargetAttrName().strref(); |
2099 | p << " : " ; |
2100 | p.printType(getType()); |
2101 | if (hasInitializationBody()) { |
2102 | p << ' '; |
2103 | p.printRegion(getOperation()->getRegion(0), |
2104 | /*printEntryBlockArgs=*/false, |
2105 | /*printBlockTerminators=*/true); |
2106 | } |
2107 | } |
2108 | |
2109 | void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { |
2110 | getBlock().getOperations().push_back(op); |
2111 | } |
2112 | |
2113 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
2114 | mlir::OperationState &result, llvm::StringRef name, |
2115 | bool isConstant, bool isTarget, mlir::Type type, |
2116 | mlir::Attribute initialVal, mlir::StringAttr linkage, |
2117 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
2118 | result.addRegion(); |
2119 | result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); |
2120 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
2121 | builder.getStringAttr(name)); |
2122 | result.addAttribute(getSymrefAttrName(result.name), |
2123 | mlir::SymbolRefAttr::get(builder.getContext(), name)); |
2124 | if (isConstant) |
2125 | result.addAttribute(getConstantAttrName(result.name), |
2126 | builder.getUnitAttr()); |
2127 | if (isTarget) |
2128 | result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); |
2129 | if (initialVal) |
2130 | result.addAttribute(getInitValAttrName(result.name), initialVal); |
2131 | if (linkage) |
2132 | result.addAttribute(getLinkNameAttrName(result.name), linkage); |
2133 | result.attributes.append(attrs.begin(), attrs.end()); |
2134 | } |
2135 | |
2136 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
2137 | mlir::OperationState &result, llvm::StringRef name, |
2138 | mlir::Type type, mlir::Attribute initialVal, |
2139 | mlir::StringAttr linkage, |
2140 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
2141 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
2142 | {}, linkage, attrs); |
2143 | } |
2144 | |
2145 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
2146 | mlir::OperationState &result, llvm::StringRef name, |
2147 | bool isConstant, bool isTarget, mlir::Type type, |
2148 | mlir::StringAttr linkage, |
2149 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
2150 | build(builder, result, name, isConstant, isTarget, type, {}, linkage, attrs); |
2151 | } |
2152 | |
2153 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
2154 | mlir::OperationState &result, llvm::StringRef name, |
2155 | mlir::Type type, mlir::StringAttr linkage, |
2156 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
2157 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
2158 | {}, linkage, attrs); |
2159 | } |
2160 | |
2161 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
2162 | mlir::OperationState &result, llvm::StringRef name, |
2163 | bool isConstant, bool isTarget, mlir::Type type, |
2164 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
2165 | build(builder, result, name, isConstant, isTarget, type, mlir::StringAttr{}, |
2166 | attrs); |
2167 | } |
2168 | |
2169 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
2170 | mlir::OperationState &result, llvm::StringRef name, |
2171 | mlir::Type type, |
2172 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
2173 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
2174 | attrs); |
2175 | } |
2176 | |
2177 | mlir::ParseResult fir::GlobalOp::verifyValidLinkage(llvm::StringRef linkage) { |
2178 | // Supporting only a subset of the LLVM linkage types for now |
2179 | static const char *validNames[] = {"common" , "internal" , "linkonce" , |
2180 | "linkonce_odr" , "weak" }; |
2181 | return mlir::success(llvm::is_contained(validNames, linkage)); |
2182 | } |
2183 | |
2184 | //===----------------------------------------------------------------------===// |
2185 | // GlobalLenOp |
2186 | //===----------------------------------------------------------------------===// |
2187 | |
2188 | mlir::ParseResult fir::GlobalLenOp::parse(mlir::OpAsmParser &parser, |
2189 | mlir::OperationState &result) { |
2190 | llvm::StringRef fieldName; |
2191 | if (failed(parser.parseOptionalKeyword(&fieldName))) { |
2192 | mlir::StringAttr fieldAttr; |
2193 | if (parser.parseAttribute(fieldAttr, |
2194 | fir::GlobalLenOp::getLenParamAttrName(), |
2195 | result.attributes)) |
2196 | return mlir::failure(); |
2197 | } else { |
2198 | result.addAttribute(fir::GlobalLenOp::getLenParamAttrName(), |
2199 | parser.getBuilder().getStringAttr(fieldName)); |
2200 | } |
2201 | mlir::IntegerAttr constant; |
2202 | if (parser.parseComma() || |
2203 | parser.parseAttribute(constant, fir::GlobalLenOp::getIntAttrName(), |
2204 | result.attributes)) |
2205 | return mlir::failure(); |
2206 | return mlir::success(); |
2207 | } |
2208 | |
2209 | void fir::GlobalLenOp::print(mlir::OpAsmPrinter &p) { |
2210 | p << ' ' << getOperation()->getAttr(fir::GlobalLenOp::getLenParamAttrName()) |
2211 | << ", " << getOperation()->getAttr(fir::GlobalLenOp::getIntAttrName()); |
2212 | } |
2213 | |
2214 | //===----------------------------------------------------------------------===// |
2215 | // FieldIndexOp |
2216 | //===----------------------------------------------------------------------===// |
2217 | |
2218 | template <typename TY> |
2219 | mlir::ParseResult parseFieldLikeOp(mlir::OpAsmParser &parser, |
2220 | mlir::OperationState &result) { |
2221 | llvm::StringRef fieldName; |
2222 | auto &builder = parser.getBuilder(); |
2223 | mlir::Type recty; |
2224 | if (parser.parseOptionalKeyword(keyword: &fieldName) || parser.parseComma() || |
2225 | parser.parseType(result&: recty)) |
2226 | return mlir::failure(); |
2227 | result.addAttribute(fir::FieldIndexOp::getFieldAttrName(), |
2228 | builder.getStringAttr(fieldName)); |
2229 | if (!mlir::dyn_cast<fir::RecordType>(recty)) |
2230 | return mlir::failure(); |
2231 | result.addAttribute(fir::FieldIndexOp::getTypeAttrName(), |
2232 | mlir::TypeAttr::get(recty)); |
2233 | if (!parser.parseOptionalLParen()) { |
2234 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; |
2235 | llvm::SmallVector<mlir::Type> types; |
2236 | auto loc = parser.getNameLoc(); |
2237 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None) || |
2238 | parser.parseColonTypeList(result&: types) || parser.parseRParen() || |
2239 | parser.resolveOperands(operands, types, loc, result&: result.operands)) |
2240 | return mlir::failure(); |
2241 | } |
2242 | mlir::Type fieldType = TY::get(builder.getContext()); |
2243 | if (parser.addTypeToList(type: fieldType, result&: result.types)) |
2244 | return mlir::failure(); |
2245 | return mlir::success(); |
2246 | } |
2247 | |
2248 | mlir::ParseResult fir::FieldIndexOp::parse(mlir::OpAsmParser &parser, |
2249 | mlir::OperationState &result) { |
2250 | return parseFieldLikeOp<fir::FieldType>(parser, result); |
2251 | } |
2252 | |
2253 | template <typename OP> |
2254 | void printFieldLikeOp(mlir::OpAsmPrinter &p, OP &op) { |
2255 | p << ' ' |
2256 | << op.getOperation() |
2257 | ->template getAttrOfType<mlir::StringAttr>( |
2258 | fir::FieldIndexOp::getFieldAttrName()) |
2259 | .getValue() |
2260 | << ", " << op.getOperation()->getAttr(fir::FieldIndexOp::getTypeAttrName()); |
2261 | if (op.getNumOperands()) { |
2262 | p << '('; |
2263 | p.printOperands(op.getTypeparams()); |
2264 | auto sep = ") : " ; |
2265 | for (auto op : op.getTypeparams()) { |
2266 | p << sep; |
2267 | if (op) |
2268 | p.printType(type: op.getType()); |
2269 | else |
2270 | p << "()" ; |
2271 | sep = ", " ; |
2272 | } |
2273 | } |
2274 | } |
2275 | |
2276 | void fir::FieldIndexOp::print(mlir::OpAsmPrinter &p) { |
2277 | printFieldLikeOp(p, *this); |
2278 | } |
2279 | |
2280 | void fir::FieldIndexOp::build(mlir::OpBuilder &builder, |
2281 | mlir::OperationState &result, |
2282 | llvm::StringRef fieldName, mlir::Type recTy, |
2283 | mlir::ValueRange operands) { |
2284 | result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); |
2285 | result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); |
2286 | result.addOperands(operands); |
2287 | } |
2288 | |
2289 | llvm::SmallVector<mlir::Attribute> fir::FieldIndexOp::getAttributes() { |
2290 | llvm::SmallVector<mlir::Attribute> attrs; |
2291 | attrs.push_back(getFieldIdAttr()); |
2292 | attrs.push_back(getOnTypeAttr()); |
2293 | return attrs; |
2294 | } |
2295 | |
2296 | //===----------------------------------------------------------------------===// |
2297 | // InsertOnRangeOp |
2298 | //===----------------------------------------------------------------------===// |
2299 | |
2300 | static mlir::ParseResult |
2301 | parseCustomRangeSubscript(mlir::OpAsmParser &parser, |
2302 | mlir::DenseIntElementsAttr &coord) { |
2303 | llvm::SmallVector<std::int64_t> lbounds; |
2304 | llvm::SmallVector<std::int64_t> ubounds; |
2305 | if (parser.parseKeyword(keyword: "from" ) || |
2306 | parser.parseCommaSeparatedList( |
2307 | delimiter: mlir::AsmParser::Delimiter::Paren, |
2308 | parseElementFn: [&] { return parser.parseInteger(result&: lbounds.emplace_back(Args: 0)); }) || |
2309 | parser.parseKeyword(keyword: "to" ) || |
2310 | parser.parseCommaSeparatedList(delimiter: mlir::AsmParser::Delimiter::Paren, parseElementFn: [&] { |
2311 | return parser.parseInteger(result&: ubounds.emplace_back(Args: 0)); |
2312 | })) |
2313 | return mlir::failure(); |
2314 | llvm::SmallVector<std::int64_t> zippedBounds; |
2315 | for (auto zip : llvm::zip(t&: lbounds, u&: ubounds)) { |
2316 | zippedBounds.push_back(Elt: std::get<0>(t&: zip)); |
2317 | zippedBounds.push_back(Elt: std::get<1>(t&: zip)); |
2318 | } |
2319 | coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(values: zippedBounds); |
2320 | return mlir::success(); |
2321 | } |
2322 | |
2323 | static void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, |
2324 | fir::InsertOnRangeOp op, |
2325 | mlir::DenseIntElementsAttr coord) { |
2326 | printer << "from (" ; |
2327 | auto enumerate = llvm::enumerate(coord.getValues<std::int64_t>()); |
2328 | // Even entries are the lower bounds. |
2329 | llvm::interleaveComma( |
2330 | make_filter_range( |
2331 | enumerate, |
2332 | [](auto indexed_value) { return indexed_value.index() % 2 == 0; }), |
2333 | printer, [&](auto indexed_value) { printer << indexed_value.value(); }); |
2334 | printer << ") to (" ; |
2335 | // Odd entries are the upper bounds. |
2336 | llvm::interleaveComma( |
2337 | make_filter_range( |
2338 | enumerate, |
2339 | [](auto indexed_value) { return indexed_value.index() % 2 != 0; }), |
2340 | printer, [&](auto indexed_value) { printer << indexed_value.value(); }); |
2341 | printer << ")" ; |
2342 | } |
2343 | |
2344 | /// Range bounds must be nonnegative, and the range must not be empty. |
2345 | llvm::LogicalResult fir::InsertOnRangeOp::verify() { |
2346 | if (fir::hasDynamicSize(getSeq().getType())) |
2347 | return emitOpError("must have constant shape and size" ); |
2348 | mlir::DenseIntElementsAttr coorAttr = getCoor(); |
2349 | if (coorAttr.size() < 2 || coorAttr.size() % 2 != 0) |
2350 | return emitOpError("has uneven number of values in ranges" ); |
2351 | bool rangeIsKnownToBeNonempty = false; |
2352 | for (auto i = coorAttr.getValues<std::int64_t>().end(), |
2353 | b = coorAttr.getValues<std::int64_t>().begin(); |
2354 | i != b;) { |
2355 | int64_t ub = (*--i); |
2356 | int64_t lb = (*--i); |
2357 | if (lb < 0 || ub < 0) |
2358 | return emitOpError("negative range bound" ); |
2359 | if (rangeIsKnownToBeNonempty) |
2360 | continue; |
2361 | if (lb > ub) |
2362 | return emitOpError("empty range" ); |
2363 | rangeIsKnownToBeNonempty = lb < ub; |
2364 | } |
2365 | return mlir::success(); |
2366 | } |
2367 | |
2368 | bool fir::InsertOnRangeOp::isFullRange() { |
2369 | auto extents = getType().getShape(); |
2370 | mlir::DenseIntElementsAttr indexes = getCoor(); |
2371 | if (indexes.size() / 2 != static_cast<int64_t>(extents.size())) |
2372 | return false; |
2373 | auto cur_index = indexes.value_begin<int64_t>(); |
2374 | for (unsigned i = 0; i < indexes.size(); i += 2) { |
2375 | if (*(cur_index++) != 0) |
2376 | return false; |
2377 | if (*(cur_index++) != extents[i / 2] - 1) |
2378 | return false; |
2379 | } |
2380 | return true; |
2381 | } |
2382 | |
2383 | //===----------------------------------------------------------------------===// |
2384 | // InsertValueOp |
2385 | //===----------------------------------------------------------------------===// |
2386 | |
2387 | static bool checkIsIntegerConstant(mlir::Attribute attr, std::int64_t conVal) { |
2388 | if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) |
2389 | return iattr.getInt() == conVal; |
2390 | return false; |
2391 | } |
2392 | |
2393 | static bool isZero(mlir::Attribute a) { return checkIsIntegerConstant(attr: a, conVal: 0); } |
2394 | static bool isOne(mlir::Attribute a) { return checkIsIntegerConstant(attr: a, conVal: 1); } |
2395 | |
2396 | // Undo some complex patterns created in the front-end and turn them back into |
2397 | // complex ops. |
2398 | template <typename FltOp, typename CpxOp> |
2399 | struct UndoComplexPattern : public mlir::RewritePattern { |
2400 | UndoComplexPattern(mlir::MLIRContext *ctx) |
2401 | : mlir::RewritePattern("fir.insert_value" , 2, ctx) {} |
2402 | |
2403 | llvm::LogicalResult |
2404 | matchAndRewrite(mlir::Operation *op, |
2405 | mlir::PatternRewriter &rewriter) const override { |
2406 | auto insval = mlir::dyn_cast_or_null<fir::InsertValueOp>(op); |
2407 | if (!insval || !mlir::isa<mlir::ComplexType>(insval.getType())) |
2408 | return mlir::failure(); |
2409 | auto insval2 = mlir::dyn_cast_or_null<fir::InsertValueOp>( |
2410 | insval.getAdt().getDefiningOp()); |
2411 | if (!insval2) |
2412 | return mlir::failure(); |
2413 | auto binf = mlir::dyn_cast_or_null<FltOp>(insval.getVal().getDefiningOp()); |
2414 | auto binf2 = |
2415 | mlir::dyn_cast_or_null<FltOp>(insval2.getVal().getDefiningOp()); |
2416 | if (!binf || !binf2 || insval.getCoor().size() != 1 || |
2417 | !isOne(insval.getCoor()[0]) || insval2.getCoor().size() != 1 || |
2418 | !isZero(insval2.getCoor()[0])) |
2419 | return mlir::failure(); |
2420 | auto eai = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
2421 | binf.getLhs().getDefiningOp()); |
2422 | auto ebi = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
2423 | binf.getRhs().getDefiningOp()); |
2424 | auto ear = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
2425 | binf2.getLhs().getDefiningOp()); |
2426 | auto ebr = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
2427 | binf2.getRhs().getDefiningOp()); |
2428 | if (!eai || !ebi || !ear || !ebr || ear.getAdt() != eai.getAdt() || |
2429 | ebr.getAdt() != ebi.getAdt() || eai.getCoor().size() != 1 || |
2430 | !isOne(eai.getCoor()[0]) || ebi.getCoor().size() != 1 || |
2431 | !isOne(ebi.getCoor()[0]) || ear.getCoor().size() != 1 || |
2432 | !isZero(ear.getCoor()[0]) || ebr.getCoor().size() != 1 || |
2433 | !isZero(ebr.getCoor()[0])) |
2434 | return mlir::failure(); |
2435 | rewriter.replaceOpWithNewOp<CpxOp>(op, ear.getAdt(), ebr.getAdt()); |
2436 | return mlir::success(); |
2437 | } |
2438 | }; |
2439 | |
2440 | void fir::InsertValueOp::getCanonicalizationPatterns( |
2441 | mlir::RewritePatternSet &results, mlir::MLIRContext *context) { |
2442 | results.insert<UndoComplexPattern<mlir::arith::AddFOp, fir::AddcOp>, |
2443 | UndoComplexPattern<mlir::arith::SubFOp, fir::SubcOp>>(context); |
2444 | } |
2445 | |
2446 | //===----------------------------------------------------------------------===// |
2447 | // IterWhileOp |
2448 | //===----------------------------------------------------------------------===// |
2449 | |
2450 | void fir::IterWhileOp::build(mlir::OpBuilder &builder, |
2451 | mlir::OperationState &result, mlir::Value lb, |
2452 | mlir::Value ub, mlir::Value step, |
2453 | mlir::Value iterate, bool finalCountValue, |
2454 | mlir::ValueRange iterArgs, |
2455 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
2456 | result.addOperands({lb, ub, step, iterate}); |
2457 | if (finalCountValue) { |
2458 | result.addTypes(builder.getIndexType()); |
2459 | result.addAttribute(getFinalValueAttrNameStr(), builder.getUnitAttr()); |
2460 | } |
2461 | result.addTypes(iterate.getType()); |
2462 | result.addOperands(iterArgs); |
2463 | for (auto v : iterArgs) |
2464 | result.addTypes(v.getType()); |
2465 | mlir::Region *bodyRegion = result.addRegion(); |
2466 | bodyRegion->push_back(new mlir::Block{}); |
2467 | bodyRegion->front().addArgument(builder.getIndexType(), result.location); |
2468 | bodyRegion->front().addArgument(iterate.getType(), result.location); |
2469 | bodyRegion->front().addArguments( |
2470 | iterArgs.getTypes(), |
2471 | llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); |
2472 | result.addAttributes(attributes); |
2473 | } |
2474 | |
2475 | mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser, |
2476 | mlir::OperationState &result) { |
2477 | auto &builder = parser.getBuilder(); |
2478 | mlir::OpAsmParser::Argument inductionVariable, iterateVar; |
2479 | mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput; |
2480 | if (parser.parseLParen() || parser.parseArgument(inductionVariable) || |
2481 | parser.parseEqual()) |
2482 | return mlir::failure(); |
2483 | |
2484 | // Parse loop bounds. |
2485 | auto indexType = builder.getIndexType(); |
2486 | auto i1Type = builder.getIntegerType(1); |
2487 | if (parser.parseOperand(lb) || |
2488 | parser.resolveOperand(lb, indexType, result.operands) || |
2489 | parser.parseKeyword("to" ) || parser.parseOperand(ub) || |
2490 | parser.resolveOperand(ub, indexType, result.operands) || |
2491 | parser.parseKeyword("step" ) || parser.parseOperand(step) || |
2492 | parser.parseRParen() || |
2493 | parser.resolveOperand(step, indexType, result.operands) || |
2494 | parser.parseKeyword("and" ) || parser.parseLParen() || |
2495 | parser.parseArgument(iterateVar) || parser.parseEqual() || |
2496 | parser.parseOperand(iterateInput) || parser.parseRParen() || |
2497 | parser.resolveOperand(iterateInput, i1Type, result.operands)) |
2498 | return mlir::failure(); |
2499 | |
2500 | // Parse the initial iteration arguments. |
2501 | auto prependCount = false; |
2502 | |
2503 | // Induction variable. |
2504 | llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs; |
2505 | regionArgs.push_back(inductionVariable); |
2506 | regionArgs.push_back(iterateVar); |
2507 | |
2508 | if (succeeded(parser.parseOptionalKeyword("iter_args" ))) { |
2509 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; |
2510 | llvm::SmallVector<mlir::Type> regionTypes; |
2511 | // Parse assignment list and results type list. |
2512 | if (parser.parseAssignmentList(regionArgs, operands) || |
2513 | parser.parseArrowTypeList(regionTypes)) |
2514 | return mlir::failure(); |
2515 | if (regionTypes.size() == operands.size() + 2) |
2516 | prependCount = true; |
2517 | llvm::ArrayRef<mlir::Type> resTypes = regionTypes; |
2518 | resTypes = prependCount ? resTypes.drop_front(2) : resTypes; |
2519 | // Resolve input operands. |
2520 | for (auto operandType : llvm::zip(operands, resTypes)) |
2521 | if (parser.resolveOperand(std::get<0>(operandType), |
2522 | std::get<1>(operandType), result.operands)) |
2523 | return mlir::failure(); |
2524 | if (prependCount) { |
2525 | result.addTypes(regionTypes); |
2526 | } else { |
2527 | result.addTypes(i1Type); |
2528 | result.addTypes(resTypes); |
2529 | } |
2530 | } else if (succeeded(parser.parseOptionalArrow())) { |
2531 | llvm::SmallVector<mlir::Type> typeList; |
2532 | if (parser.parseLParen() || parser.parseTypeList(typeList) || |
2533 | parser.parseRParen()) |
2534 | return mlir::failure(); |
2535 | // Type list must be "(index, i1)". |
2536 | if (typeList.size() != 2 || !mlir::isa<mlir::IndexType>(typeList[0]) || |
2537 | !typeList[1].isSignlessInteger(1)) |
2538 | return mlir::failure(); |
2539 | result.addTypes(typeList); |
2540 | prependCount = true; |
2541 | } else { |
2542 | result.addTypes(i1Type); |
2543 | } |
2544 | |
2545 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
2546 | return mlir::failure(); |
2547 | |
2548 | llvm::SmallVector<mlir::Type> argTypes; |
2549 | // Induction variable (hidden) |
2550 | if (prependCount) |
2551 | result.addAttribute(IterWhileOp::getFinalValueAttrNameStr(), |
2552 | builder.getUnitAttr()); |
2553 | else |
2554 | argTypes.push_back(indexType); |
2555 | // Loop carried variables (including iterate) |
2556 | argTypes.append(result.types.begin(), result.types.end()); |
2557 | // Parse the body region. |
2558 | auto *body = result.addRegion(); |
2559 | if (regionArgs.size() != argTypes.size()) |
2560 | return parser.emitError( |
2561 | parser.getNameLoc(), |
2562 | "mismatch in number of loop-carried values and defined values" ); |
2563 | |
2564 | for (size_t i = 0, e = regionArgs.size(); i != e; ++i) |
2565 | regionArgs[i].type = argTypes[i]; |
2566 | |
2567 | if (parser.parseRegion(*body, regionArgs)) |
2568 | return mlir::failure(); |
2569 | |
2570 | fir::IterWhileOp::ensureTerminator(*body, builder, result.location); |
2571 | return mlir::success(); |
2572 | } |
2573 | |
2574 | llvm::LogicalResult fir::IterWhileOp::verify() { |
2575 | // Check that the body defines as single block argument for the induction |
2576 | // variable. |
2577 | auto *body = getBody(); |
2578 | if (!body->getArgument(1).getType().isInteger(1)) |
2579 | return emitOpError( |
2580 | "expected body second argument to be an index argument for " |
2581 | "the induction variable" ); |
2582 | if (!body->getArgument(0).getType().isIndex()) |
2583 | return emitOpError( |
2584 | "expected body first argument to be an index argument for " |
2585 | "the induction variable" ); |
2586 | |
2587 | auto opNumResults = getNumResults(); |
2588 | if (getFinalValue()) { |
2589 | // Result type must be "(index, i1, ...)". |
2590 | if (!mlir::isa<mlir::IndexType>(getResult(0).getType())) |
2591 | return emitOpError("result #0 expected to be index" ); |
2592 | if (!getResult(1).getType().isSignlessInteger(1)) |
2593 | return emitOpError("result #1 expected to be i1" ); |
2594 | opNumResults--; |
2595 | } else { |
2596 | // iterate_while always returns the early exit induction value. |
2597 | // Result type must be "(i1, ...)" |
2598 | if (!getResult(0).getType().isSignlessInteger(1)) |
2599 | return emitOpError("result #0 expected to be i1" ); |
2600 | } |
2601 | if (opNumResults == 0) |
2602 | return mlir::failure(); |
2603 | if (getNumIterOperands() != opNumResults) |
2604 | return emitOpError( |
2605 | "mismatch in number of loop-carried values and defined values" ); |
2606 | if (getNumRegionIterArgs() != opNumResults) |
2607 | return emitOpError( |
2608 | "mismatch in number of basic block args and defined values" ); |
2609 | auto iterOperands = getIterOperands(); |
2610 | auto iterArgs = getRegionIterArgs(); |
2611 | auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); |
2612 | unsigned i = 0u; |
2613 | for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { |
2614 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
2615 | return emitOpError() << "types mismatch between " << i |
2616 | << "th iter operand and defined value" ; |
2617 | if (std::get<1>(e).getType() != std::get<2>(e).getType()) |
2618 | return emitOpError() << "types mismatch between " << i |
2619 | << "th iter region arg and defined value" ; |
2620 | |
2621 | i++; |
2622 | } |
2623 | return mlir::success(); |
2624 | } |
2625 | |
2626 | void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) { |
2627 | p << " (" << getInductionVar() << " = " << getLowerBound() << " to " |
2628 | << getUpperBound() << " step " << getStep() << ") and (" ; |
2629 | assert(hasIterOperands()); |
2630 | auto regionArgs = getRegionIterArgs(); |
2631 | auto operands = getIterOperands(); |
2632 | p << regionArgs.front() << " = " << *operands.begin() << ")" ; |
2633 | if (regionArgs.size() > 1) { |
2634 | p << " iter_args(" ; |
2635 | llvm::interleaveComma( |
2636 | llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, |
2637 | [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); |
2638 | p << ") -> (" ; |
2639 | llvm::interleaveComma( |
2640 | llvm::drop_begin(getResultTypes(), getFinalValue() ? 0 : 1), p); |
2641 | p << ")" ; |
2642 | } else if (getFinalValue()) { |
2643 | p << " -> (" << getResultTypes() << ')'; |
2644 | } |
2645 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
2646 | {getFinalValueAttrNameStr()}); |
2647 | p << ' '; |
2648 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
2649 | /*printBlockTerminators=*/true); |
2650 | } |
2651 | |
2652 | llvm::SmallVector<mlir::Region *> fir::IterWhileOp::getLoopRegions() { |
2653 | return {&getRegion()}; |
2654 | } |
2655 | |
2656 | mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { |
2657 | for (auto i : llvm::enumerate(getInitArgs())) |
2658 | if (iterArg == i.value()) |
2659 | return getRegion().front().getArgument(i.index() + 1); |
2660 | return {}; |
2661 | } |
2662 | |
2663 | void fir::IterWhileOp::resultToSourceOps( |
2664 | llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { |
2665 | auto oper = getFinalValue() ? resultNum + 1 : resultNum; |
2666 | auto *term = getRegion().front().getTerminator(); |
2667 | if (oper < term->getNumOperands()) |
2668 | results.push_back(term->getOperand(oper)); |
2669 | } |
2670 | |
2671 | mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { |
2672 | if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) |
2673 | return getInitArgs()[blockArgNum - 1]; |
2674 | return {}; |
2675 | } |
2676 | |
2677 | std::optional<llvm::MutableArrayRef<mlir::OpOperand>> |
2678 | fir::IterWhileOp::getYieldedValuesMutable() { |
2679 | auto *term = getRegion().front().getTerminator(); |
2680 | return getFinalValue() ? term->getOpOperands().drop_front() |
2681 | : term->getOpOperands(); |
2682 | } |
2683 | |
2684 | //===----------------------------------------------------------------------===// |
2685 | // LenParamIndexOp |
2686 | //===----------------------------------------------------------------------===// |
2687 | |
2688 | mlir::ParseResult fir::LenParamIndexOp::parse(mlir::OpAsmParser &parser, |
2689 | mlir::OperationState &result) { |
2690 | return parseFieldLikeOp<fir::LenType>(parser, result); |
2691 | } |
2692 | |
2693 | void fir::LenParamIndexOp::print(mlir::OpAsmPrinter &p) { |
2694 | printFieldLikeOp(p, *this); |
2695 | } |
2696 | |
2697 | void fir::LenParamIndexOp::build(mlir::OpBuilder &builder, |
2698 | mlir::OperationState &result, |
2699 | llvm::StringRef fieldName, mlir::Type recTy, |
2700 | mlir::ValueRange operands) { |
2701 | result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); |
2702 | result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); |
2703 | result.addOperands(operands); |
2704 | } |
2705 | |
2706 | llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() { |
2707 | llvm::SmallVector<mlir::Attribute> attrs; |
2708 | attrs.push_back(getFieldIdAttr()); |
2709 | attrs.push_back(getOnTypeAttr()); |
2710 | return attrs; |
2711 | } |
2712 | |
2713 | //===----------------------------------------------------------------------===// |
2714 | // LoadOp |
2715 | //===----------------------------------------------------------------------===// |
2716 | |
2717 | void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
2718 | mlir::Value refVal) { |
2719 | if (!refVal) { |
2720 | mlir::emitError(result.location, "LoadOp has null argument" ); |
2721 | return; |
2722 | } |
2723 | auto eleTy = fir::dyn_cast_ptrEleTy(refVal.getType()); |
2724 | if (!eleTy) { |
2725 | mlir::emitError(result.location, "not a memory reference type" ); |
2726 | return; |
2727 | } |
2728 | build(builder, result, eleTy, refVal); |
2729 | } |
2730 | |
2731 | void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
2732 | mlir::Type resTy, mlir::Value refVal) { |
2733 | |
2734 | if (!refVal) { |
2735 | mlir::emitError(result.location, "LoadOp has null argument" ); |
2736 | return; |
2737 | } |
2738 | result.addOperands(refVal); |
2739 | result.addTypes(resTy); |
2740 | } |
2741 | |
2742 | mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { |
2743 | if ((ele = fir::dyn_cast_ptrEleTy(ref))) |
2744 | return mlir::success(); |
2745 | return mlir::failure(); |
2746 | } |
2747 | |
2748 | mlir::ParseResult fir::LoadOp::parse(mlir::OpAsmParser &parser, |
2749 | mlir::OperationState &result) { |
2750 | mlir::Type type; |
2751 | mlir::OpAsmParser::UnresolvedOperand oper; |
2752 | if (parser.parseOperand(oper) || |
2753 | parser.parseOptionalAttrDict(result.attributes) || |
2754 | parser.parseColonType(type) || |
2755 | parser.resolveOperand(oper, type, result.operands)) |
2756 | return mlir::failure(); |
2757 | mlir::Type eleTy; |
2758 | if (fir::LoadOp::getElementOf(eleTy, type) || |
2759 | parser.addTypeToList(eleTy, result.types)) |
2760 | return mlir::failure(); |
2761 | return mlir::success(); |
2762 | } |
2763 | |
2764 | void fir::LoadOp::print(mlir::OpAsmPrinter &p) { |
2765 | p << ' '; |
2766 | p.printOperand(getMemref()); |
2767 | p.printOptionalAttrDict(getOperation()->getAttrs(), {}); |
2768 | p << " : " << getMemref().getType(); |
2769 | } |
2770 | |
2771 | void fir::LoadOp::getEffects( |
2772 | llvm::SmallVectorImpl< |
2773 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
2774 | &effects) { |
2775 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &getMemrefMutable(), |
2776 | mlir::SideEffects::DefaultResource::get()); |
2777 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
2778 | } |
2779 | |
2780 | //===----------------------------------------------------------------------===// |
2781 | // DoLoopOp |
2782 | //===----------------------------------------------------------------------===// |
2783 | |
2784 | void fir::DoLoopOp::build(mlir::OpBuilder &builder, |
2785 | mlir::OperationState &result, mlir::Value lb, |
2786 | mlir::Value ub, mlir::Value step, bool unordered, |
2787 | bool finalCountValue, mlir::ValueRange iterArgs, |
2788 | mlir::ValueRange reduceOperands, |
2789 | llvm::ArrayRef<mlir::Attribute> reduceAttrs, |
2790 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
2791 | result.addOperands({lb, ub, step}); |
2792 | result.addOperands(reduceOperands); |
2793 | result.addOperands(iterArgs); |
2794 | result.addAttribute(getOperandSegmentSizeAttr(), |
2795 | builder.getDenseI32ArrayAttr( |
2796 | {1, 1, 1, static_cast<int32_t>(reduceOperands.size()), |
2797 | static_cast<int32_t>(iterArgs.size())})); |
2798 | if (finalCountValue) { |
2799 | result.addTypes(builder.getIndexType()); |
2800 | result.addAttribute(getFinalValueAttrName(result.name), |
2801 | builder.getUnitAttr()); |
2802 | } |
2803 | for (auto v : iterArgs) |
2804 | result.addTypes(v.getType()); |
2805 | mlir::Region *bodyRegion = result.addRegion(); |
2806 | bodyRegion->push_back(new mlir::Block{}); |
2807 | if (iterArgs.empty() && !finalCountValue) |
2808 | fir::DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); |
2809 | bodyRegion->front().addArgument(builder.getIndexType(), result.location); |
2810 | bodyRegion->front().addArguments( |
2811 | iterArgs.getTypes(), |
2812 | llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); |
2813 | if (unordered) |
2814 | result.addAttribute(getUnorderedAttrName(result.name), |
2815 | builder.getUnitAttr()); |
2816 | if (!reduceAttrs.empty()) |
2817 | result.addAttribute(getReduceAttrsAttrName(result.name), |
2818 | builder.getArrayAttr(reduceAttrs)); |
2819 | result.addAttributes(attributes); |
2820 | } |
2821 | |
2822 | mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, |
2823 | mlir::OperationState &result) { |
2824 | auto &builder = parser.getBuilder(); |
2825 | mlir::OpAsmParser::Argument inductionVariable; |
2826 | mlir::OpAsmParser::UnresolvedOperand lb, ub, step; |
2827 | // Parse the induction variable followed by '='. |
2828 | if (parser.parseArgument(inductionVariable) || parser.parseEqual()) |
2829 | return mlir::failure(); |
2830 | |
2831 | // Parse loop bounds. |
2832 | auto indexType = builder.getIndexType(); |
2833 | if (parser.parseOperand(lb) || |
2834 | parser.resolveOperand(lb, indexType, result.operands) || |
2835 | parser.parseKeyword("to" ) || parser.parseOperand(ub) || |
2836 | parser.resolveOperand(ub, indexType, result.operands) || |
2837 | parser.parseKeyword("step" ) || parser.parseOperand(step) || |
2838 | parser.resolveOperand(step, indexType, result.operands)) |
2839 | return mlir::failure(); |
2840 | |
2841 | if (mlir::succeeded(parser.parseOptionalKeyword("unordered" ))) |
2842 | result.addAttribute("unordered" , builder.getUnitAttr()); |
2843 | |
2844 | // Parse the reduction arguments. |
2845 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands; |
2846 | llvm::SmallVector<mlir::Type> reduceArgTypes; |
2847 | if (succeeded(parser.parseOptionalKeyword("reduce" ))) { |
2848 | // Parse reduction attributes and variables. |
2849 | llvm::SmallVector<ReduceAttr> attributes; |
2850 | if (failed(parser.parseCommaSeparatedList( |
2851 | mlir::AsmParser::Delimiter::Paren, [&]() { |
2852 | if (parser.parseAttribute(attributes.emplace_back()) || |
2853 | parser.parseArrow() || |
2854 | parser.parseOperand(reduceOperands.emplace_back()) || |
2855 | parser.parseColonType(reduceArgTypes.emplace_back())) |
2856 | return mlir::failure(); |
2857 | return mlir::success(); |
2858 | }))) |
2859 | return mlir::failure(); |
2860 | // Resolve input operands. |
2861 | for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) |
2862 | if (parser.resolveOperand(std::get<0>(operand_type), |
2863 | std::get<1>(operand_type), result.operands)) |
2864 | return mlir::failure(); |
2865 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
2866 | attributes.end()); |
2867 | result.addAttribute(getReduceAttrsAttrName(result.name), |
2868 | builder.getArrayAttr(arrayAttr)); |
2869 | } |
2870 | |
2871 | // Parse the optional initial iteration arguments. |
2872 | llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs; |
2873 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands; |
2874 | llvm::SmallVector<mlir::Type> argTypes; |
2875 | bool prependCount = false; |
2876 | regionArgs.push_back(inductionVariable); |
2877 | |
2878 | if (succeeded(parser.parseOptionalKeyword("iter_args" ))) { |
2879 | // Parse assignment list and results type list. |
2880 | if (parser.parseAssignmentList(regionArgs, iterOperands) || |
2881 | parser.parseArrowTypeList(result.types)) |
2882 | return mlir::failure(); |
2883 | if (result.types.size() == iterOperands.size() + 1) |
2884 | prependCount = true; |
2885 | // Resolve input operands. |
2886 | llvm::ArrayRef<mlir::Type> resTypes = result.types; |
2887 | for (auto operand_type : llvm::zip( |
2888 | iterOperands, prependCount ? resTypes.drop_front() : resTypes)) |
2889 | if (parser.resolveOperand(std::get<0>(operand_type), |
2890 | std::get<1>(operand_type), result.operands)) |
2891 | return mlir::failure(); |
2892 | } else if (succeeded(parser.parseOptionalArrow())) { |
2893 | if (parser.parseKeyword("index" )) |
2894 | return mlir::failure(); |
2895 | result.types.push_back(indexType); |
2896 | prependCount = true; |
2897 | } |
2898 | |
2899 | // Set the operandSegmentSizes attribute |
2900 | result.addAttribute(getOperandSegmentSizeAttr(), |
2901 | builder.getDenseI32ArrayAttr( |
2902 | {1, 1, 1, static_cast<int32_t>(reduceOperands.size()), |
2903 | static_cast<int32_t>(iterOperands.size())})); |
2904 | |
2905 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
2906 | return mlir::failure(); |
2907 | |
2908 | // Induction variable. |
2909 | if (prependCount) |
2910 | result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name), |
2911 | builder.getUnitAttr()); |
2912 | else |
2913 | argTypes.push_back(indexType); |
2914 | // Loop carried variables |
2915 | argTypes.append(result.types.begin(), result.types.end()); |
2916 | // Parse the body region. |
2917 | auto *body = result.addRegion(); |
2918 | if (regionArgs.size() != argTypes.size()) |
2919 | return parser.emitError( |
2920 | parser.getNameLoc(), |
2921 | "mismatch in number of loop-carried values and defined values" ); |
2922 | for (size_t i = 0, e = regionArgs.size(); i != e; ++i) |
2923 | regionArgs[i].type = argTypes[i]; |
2924 | |
2925 | if (parser.parseRegion(*body, regionArgs)) |
2926 | return mlir::failure(); |
2927 | |
2928 | DoLoopOp::ensureTerminator(*body, builder, result.location); |
2929 | |
2930 | return mlir::success(); |
2931 | } |
2932 | |
2933 | fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { |
2934 | auto ivArg = mlir::dyn_cast<mlir::BlockArgument>(val); |
2935 | if (!ivArg) |
2936 | return {}; |
2937 | assert(ivArg.getOwner() && "unlinked block argument" ); |
2938 | auto *containingInst = ivArg.getOwner()->getParentOp(); |
2939 | return mlir::dyn_cast_or_null<fir::DoLoopOp>(containingInst); |
2940 | } |
2941 | |
2942 | // Lifted from loop.loop |
2943 | llvm::LogicalResult fir::DoLoopOp::verify() { |
2944 | // Check that the body defines as single block argument for the induction |
2945 | // variable. |
2946 | auto *body = getBody(); |
2947 | if (!body->getArgument(0).getType().isIndex()) |
2948 | return emitOpError( |
2949 | "expected body first argument to be an index argument for " |
2950 | "the induction variable" ); |
2951 | |
2952 | auto opNumResults = getNumResults(); |
2953 | if (opNumResults == 0) |
2954 | return mlir::success(); |
2955 | |
2956 | if (getFinalValue()) { |
2957 | if (getUnordered()) |
2958 | return emitOpError("unordered loop has no final value" ); |
2959 | opNumResults--; |
2960 | } |
2961 | if (getNumIterOperands() != opNumResults) |
2962 | return emitOpError( |
2963 | "mismatch in number of loop-carried values and defined values" ); |
2964 | if (getNumRegionIterArgs() != opNumResults) |
2965 | return emitOpError( |
2966 | "mismatch in number of basic block args and defined values" ); |
2967 | auto iterOperands = getIterOperands(); |
2968 | auto iterArgs = getRegionIterArgs(); |
2969 | auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); |
2970 | unsigned i = 0u; |
2971 | for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { |
2972 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
2973 | return emitOpError() << "types mismatch between " << i |
2974 | << "th iter operand and defined value" ; |
2975 | if (std::get<1>(e).getType() != std::get<2>(e).getType()) |
2976 | return emitOpError() << "types mismatch between " << i |
2977 | << "th iter region arg and defined value" ; |
2978 | |
2979 | i++; |
2980 | } |
2981 | auto reduceAttrs = getReduceAttrsAttr(); |
2982 | if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) |
2983 | return emitOpError( |
2984 | "mismatch in number of reduction variables and reduction attributes" ); |
2985 | return mlir::success(); |
2986 | } |
2987 | |
2988 | void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { |
2989 | bool printBlockTerminators = false; |
2990 | p << ' ' << getInductionVar() << " = " << getLowerBound() << " to " |
2991 | << getUpperBound() << " step " << getStep(); |
2992 | if (getUnordered()) |
2993 | p << " unordered" ; |
2994 | if (hasReduceOperands()) { |
2995 | p << " reduce(" ; |
2996 | auto attrs = getReduceAttrsAttr(); |
2997 | auto operands = getReduceOperands(); |
2998 | llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { |
2999 | p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " |
3000 | << std::get<1>(it).getType(); |
3001 | }); |
3002 | p << ')'; |
3003 | printBlockTerminators = true; |
3004 | } |
3005 | if (hasIterOperands()) { |
3006 | p << " iter_args(" ; |
3007 | auto regionArgs = getRegionIterArgs(); |
3008 | auto operands = getIterOperands(); |
3009 | llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { |
3010 | p << std::get<0>(it) << " = " << std::get<1>(it); |
3011 | }); |
3012 | p << ") -> (" << getResultTypes() << ')'; |
3013 | printBlockTerminators = true; |
3014 | } else if (getFinalValue()) { |
3015 | p << " -> " << getResultTypes(); |
3016 | printBlockTerminators = true; |
3017 | } |
3018 | p.printOptionalAttrDictWithKeyword( |
3019 | (*this)->getAttrs(), |
3020 | {"unordered" , "finalValue" , "reduceAttrs" , "operandSegmentSizes" }); |
3021 | p << ' '; |
3022 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
3023 | printBlockTerminators); |
3024 | } |
3025 | |
3026 | llvm::SmallVector<mlir::Region *> fir::DoLoopOp::getLoopRegions() { |
3027 | return {&getRegion()}; |
3028 | } |
3029 | |
3030 | /// Translate a value passed as an iter_arg to the corresponding block |
3031 | /// argument in the body of the loop. |
3032 | mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { |
3033 | for (auto i : llvm::enumerate(getInitArgs())) |
3034 | if (iterArg == i.value()) |
3035 | return getRegion().front().getArgument(i.index() + 1); |
3036 | return {}; |
3037 | } |
3038 | |
3039 | /// Translate the result vector (by index number) to the corresponding value |
3040 | /// to the `fir.result` Op. |
3041 | void fir::DoLoopOp::resultToSourceOps( |
3042 | llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { |
3043 | auto oper = getFinalValue() ? resultNum + 1 : resultNum; |
3044 | auto *term = getRegion().front().getTerminator(); |
3045 | if (oper < term->getNumOperands()) |
3046 | results.push_back(term->getOperand(oper)); |
3047 | } |
3048 | |
3049 | /// Translate the block argument (by index number) to the corresponding value |
3050 | /// passed as an iter_arg to the parent DoLoopOp. |
3051 | mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { |
3052 | if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) |
3053 | return getInitArgs()[blockArgNum - 1]; |
3054 | return {}; |
3055 | } |
3056 | |
3057 | std::optional<llvm::MutableArrayRef<mlir::OpOperand>> |
3058 | fir::DoLoopOp::getYieldedValuesMutable() { |
3059 | auto *term = getRegion().front().getTerminator(); |
3060 | return getFinalValue() ? term->getOpOperands().drop_front() |
3061 | : term->getOpOperands(); |
3062 | } |
3063 | |
3064 | //===----------------------------------------------------------------------===// |
3065 | // DTEntryOp |
3066 | //===----------------------------------------------------------------------===// |
3067 | |
3068 | mlir::ParseResult fir::DTEntryOp::parse(mlir::OpAsmParser &parser, |
3069 | mlir::OperationState &result) { |
3070 | llvm::StringRef methodName; |
3071 | // allow `methodName` or `"methodName"` |
3072 | if (failed(parser.parseOptionalKeyword(&methodName))) { |
3073 | mlir::StringAttr methodAttr; |
3074 | if (parser.parseAttribute(methodAttr, getMethodAttrName(result.name), |
3075 | result.attributes)) |
3076 | return mlir::failure(); |
3077 | } else { |
3078 | result.addAttribute(getMethodAttrName(result.name), |
3079 | parser.getBuilder().getStringAttr(methodName)); |
3080 | } |
3081 | mlir::SymbolRefAttr calleeAttr; |
3082 | if (parser.parseComma() || |
3083 | parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(), |
3084 | result.attributes)) |
3085 | return mlir::failure(); |
3086 | return mlir::success(); |
3087 | } |
3088 | |
3089 | void fir::DTEntryOp::print(mlir::OpAsmPrinter &p) { |
3090 | p << ' ' << getMethodAttr() << ", " << getProcAttr(); |
3091 | } |
3092 | |
3093 | //===----------------------------------------------------------------------===// |
3094 | // ReboxOp |
3095 | //===----------------------------------------------------------------------===// |
3096 | |
3097 | /// Get the scalar type related to a fir.box type. |
3098 | /// Example: return f32 for !fir.box<!fir.heap<!fir.array<?x?xf32>>. |
3099 | static mlir::Type getBoxScalarEleTy(mlir::Type boxTy) { |
3100 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); |
3101 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) |
3102 | return seqTy.getEleTy(); |
3103 | return eleTy; |
3104 | } |
3105 | |
3106 | /// Test if \p t1 and \p t2 are compatible character types (if they can |
3107 | /// represent the same type at runtime). |
3108 | static bool areCompatibleCharacterTypes(mlir::Type t1, mlir::Type t2) { |
3109 | auto c1 = mlir::dyn_cast<fir::CharacterType>(t1); |
3110 | auto c2 = mlir::dyn_cast<fir::CharacterType>(t2); |
3111 | if (!c1 || !c2) |
3112 | return false; |
3113 | if (c1.hasDynamicLen() || c2.hasDynamicLen()) |
3114 | return true; |
3115 | return c1.getLen() == c2.getLen(); |
3116 | } |
3117 | |
/// Verify fir.rebox.
///  - Neither the input box nor the result may have unknown rank or type.
///  - Slicing case: the slice rank must equal the input rank. A shape operand,
///    if present, must be a fir.shift of the input rank. The result rank must
///    equal the slice's output rank when the slice's defining op is visible.
///  - Reshaping case (no slice): the result rank must equal the shape
///    operand's rank, or the input rank when there is no shape. A fir.shift
///    shape must also match the input rank.
///  - Scalar element types must match, except for the cases listed in
///    `typeCanMismatch` below.
llvm::LogicalResult fir::ReboxOp::verify() {
  auto inputBoxTy = getBox().getType();
  if (fir::isa_unknown_size_box(inputBoxTy))
    return emitOpError("box operand must not have unknown rank or type");
  auto outBoxTy = getType();
  if (fir::isa_unknown_size_box(outBoxTy))
    return emitOpError("result type must not have unknown rank or type");
  auto inputRank = fir::getBoxRank(inputBoxTy);
  auto inputEleTy = getBoxScalarEleTy(inputBoxTy);
  auto outRank = fir::getBoxRank(outBoxTy);
  auto outEleTy = getBoxScalarEleTy(outBoxTy);

  if (auto sliceVal = getSlice()) {
    // Slicing case
    if (mlir::cast<fir::SliceType>(sliceVal.getType()).getRank() != inputRank)
      return emitOpError("slice operand rank must match box operand rank");
    // With a slice, the only shape operand accepted is a fir.shift.
    if (auto shapeVal = getShape()) {
      if (auto shiftTy = mlir::dyn_cast<fir::ShiftType>(shapeVal.getType())) {
        if (shiftTy.getRank() != inputRank)
          return emitOpError("shape operand and input box ranks must match "
                             "when there is a slice");
      } else {
        // NOTE(review): "must absent" reads like a typo for "must be absent".
        // It is left unchanged because tests may match this exact message.
        return emitOpError("shape operand must absent or be a fir.shift "
                           "when there is a slice");
      }
    }
    // The output rank can only be checked when the slice value has a
    // defining op (for example, not when it is a block argument).
    if (auto sliceOp = sliceVal.getDefiningOp()) {
      // NOTE(review): assumes the defining op of a !fir.slice value is always
      // a fir.slice; mlir::cast asserts otherwise.
      auto slicedRank = mlir::cast<fir::SliceOp>(sliceOp).getOutRank();
      if (slicedRank != outRank)
        return emitOpError("result type rank and rank after applying slice "
                           "operand must match");
    }
  } else {
    // Reshaping case
    // Without a shape operand, the result must keep the input rank.
    unsigned shapeRank = inputRank;
    if (auto shapeVal = getShape()) {
      auto ty = shapeVal.getType();
      if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty)) {
        shapeRank = shapeTy.getRank();
      } else if (auto shapeShiftTy = mlir::dyn_cast<fir::ShapeShiftType>(ty)) {
        shapeRank = shapeShiftTy.getRank();
      } else {
        // Presumably the operand's type constraint only admits shape,
        // shapeshift or shift here; mlir::cast asserts otherwise.
        auto shiftTy = mlir::cast<fir::ShiftType>(ty);
        shapeRank = shiftTy.getRank();
        if (shapeRank != inputRank)
          return emitOpError("shape operand and input box ranks must match "
                             "when the shape is a fir.shift");
      }
    }
    if (shapeRank != outRank)
      return emitOpError("result type and shape operand ranks must match");
  }

  if (inputEleTy != outEleTy) {
    // TODO: check that outBoxTy is a parent type of inputBoxTy for derived
    // types.
    // Character input and output types with constant length may be different if
    // there is a substring in the slice, otherwise, they must match. If any of
    // the types is a character with dynamic length, the other type can be any
    // character type.
    const bool typeCanMismatch =
        // Derived-type input (see the TODO above).
        mlir::isa<fir::RecordType>(inputEleTy) ||
        // Output element type is `none`.
        mlir::isa<mlir::NoneType>(outEleTy) ||
        // `none` input reboxed to a derived type.
        (mlir::isa<mlir::NoneType>(inputEleTy) &&
         mlir::isa<fir::RecordType>(outEleTy)) ||
        // Sliced character input (substring).
        (getSlice() && mlir::isa<fir::CharacterType>(inputEleTy)) ||
        // Sliced complex input viewed as a floating-point output.
        (getSlice() && fir::isa_complex(inputEleTy) &&
         mlir::isa<mlir::FloatType>(outEleTy)) ||
        areCompatibleCharacterTypes(inputEleTy, outEleTy);
    if (!typeCanMismatch)
      // NOTE(review): emitOpError normally already prefixes "'fir.rebox' op",
      // so the leading "op" in this message may be doubled. It is left
      // unchanged because tests may match the exact text.
      return emitOpError(
          "op input and output element types must match for intrinsic types");
  }
  return mlir::success();
}
3193 | |
3194 | //===----------------------------------------------------------------------===// |
3195 | // ReboxAssumedRankOp |
3196 | //===----------------------------------------------------------------------===// |
3197 | |
3198 | static bool areCompatibleAssumedRankElementType(mlir::Type inputEleTy, |
3199 | mlir::Type outEleTy) { |
3200 | if (inputEleTy == outEleTy) |
3201 | return true; |
3202 | // Output is unlimited polymorphic -> output dynamic type is the same as input |
3203 | // type. |
3204 | if (mlir::isa<mlir::NoneType>(Val: outEleTy)) |
3205 | return true; |
3206 | // Output/Input are derived types. Assuming input extends output type, output |
3207 | // dynamic type is the output static type, unless output is polymorphic. |
3208 | if (mlir::isa<fir::RecordType>(inputEleTy) && |
3209 | mlir::isa<fir::RecordType>(outEleTy)) |
3210 | return true; |
3211 | if (areCompatibleCharacterTypes(t1: inputEleTy, t2: outEleTy)) |
3212 | return true; |
3213 | return false; |
3214 | } |
3215 | |
3216 | llvm::LogicalResult fir::ReboxAssumedRankOp::verify() { |
3217 | mlir::Type inputType = getBox().getType(); |
3218 | if (!mlir::isa<fir::BaseBoxType>(inputType) && !fir::isBoxAddress(inputType)) |
3219 | return emitOpError("input must be a box or box address" ); |
3220 | mlir::Type inputEleTy = |
3221 | mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(inputType)) |
3222 | .unwrapInnerType(); |
3223 | mlir::Type outEleTy = |
3224 | mlir::cast<fir::BaseBoxType>(getType()).unwrapInnerType(); |
3225 | if (!areCompatibleAssumedRankElementType(inputEleTy, outEleTy)) |
3226 | return emitOpError("input and output element types are incompatible" ); |
3227 | return mlir::success(); |
3228 | } |
3229 | |
3230 | void fir::ReboxAssumedRankOp::getEffects( |
3231 | llvm::SmallVectorImpl< |
3232 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
3233 | &effects) { |
3234 | mlir::OpOperand &inputBox = getBoxMutable(); |
3235 | if (fir::isBoxAddress(inputBox.get().getType())) |
3236 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, |
3237 | mlir::SideEffects::DefaultResource::get()); |
3238 | } |
3239 | |
3240 | //===----------------------------------------------------------------------===// |
3241 | // ResultOp |
3242 | //===----------------------------------------------------------------------===// |
3243 | |
3244 | llvm::LogicalResult fir::ResultOp::verify() { |
3245 | auto *parentOp = (*this)->getParentOp(); |
3246 | auto results = parentOp->getResults(); |
3247 | auto operands = (*this)->getOperands(); |
3248 | |
3249 | if (parentOp->getNumResults() != getNumOperands()) |
3250 | return emitOpError() << "parent of result must have same arity" ; |
3251 | for (auto e : llvm::zip(results, operands)) |
3252 | if (std::get<0>(e).getType() != std::get<1>(e).getType()) |
3253 | return emitOpError() << "types mismatch between result op and its parent" ; |
3254 | return mlir::success(); |
3255 | } |
3256 | |
3257 | //===----------------------------------------------------------------------===// |
3258 | // SaveResultOp |
3259 | //===----------------------------------------------------------------------===// |
3260 | |
3261 | llvm::LogicalResult fir::SaveResultOp::verify() { |
3262 | auto resultType = getValue().getType(); |
3263 | if (resultType != fir::dyn_cast_ptrEleTy(getMemref().getType())) |
3264 | return emitOpError("value type must match memory reference type" ); |
3265 | if (fir::isa_unknown_size_box(resultType)) |
3266 | return emitOpError("cannot save !fir.box of unknown rank or type" ); |
3267 | |
3268 | if (mlir::isa<fir::BoxType>(resultType)) { |
3269 | if (getShape() || !getTypeparams().empty()) |
3270 | return emitOpError( |
3271 | "must not have shape or length operands if the value is a fir.box" ); |
3272 | return mlir::success(); |
3273 | } |
3274 | |
3275 | // fir.record or fir.array case. |
3276 | unsigned shapeTyRank = 0; |
3277 | if (auto shapeVal = getShape()) { |
3278 | auto shapeTy = shapeVal.getType(); |
3279 | if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) |
3280 | shapeTyRank = s.getRank(); |
3281 | else |
3282 | shapeTyRank = mlir::cast<fir::ShapeShiftType>(shapeTy).getRank(); |
3283 | } |
3284 | |
3285 | auto eleTy = resultType; |
3286 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(resultType)) { |
3287 | if (seqTy.getDimension() != shapeTyRank) |
3288 | emitOpError("shape operand must be provided and have the value rank " |
3289 | "when the value is a fir.array" ); |
3290 | eleTy = seqTy.getEleTy(); |
3291 | } else { |
3292 | if (shapeTyRank != 0) |
3293 | emitOpError( |
3294 | "shape operand should only be provided if the value is a fir.array" ); |
3295 | } |
3296 | |
3297 | if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) { |
3298 | if (recTy.getNumLenParams() != getTypeparams().size()) |
3299 | emitOpError("length parameters number must match with the value type " |
3300 | "length parameters" ); |
3301 | } else if (auto charTy = mlir::dyn_cast<fir::CharacterType>(eleTy)) { |
3302 | if (getTypeparams().size() > 1) |
3303 | emitOpError("no more than one length parameter must be provided for " |
3304 | "character value" ); |
3305 | } else { |
3306 | if (!getTypeparams().empty()) |
3307 | emitOpError("length parameters must not be provided for this value type" ); |
3308 | } |
3309 | |
3310 | return mlir::success(); |
3311 | } |
3312 | |
3313 | //===----------------------------------------------------------------------===// |
3314 | // IntegralSwitchTerminator |
3315 | //===----------------------------------------------------------------------===// |
3316 | static constexpr llvm::StringRef getCompareOffsetAttr() { |
3317 | return "compare_operand_offsets" ; |
3318 | } |
3319 | |
3320 | static constexpr llvm::StringRef getTargetOffsetAttr() { |
3321 | return "target_operand_offsets" ; |
3322 | } |
3323 | |
3324 | template <typename OpT> |
3325 | static llvm::LogicalResult verifyIntegralSwitchTerminator(OpT op) { |
3326 | if (!mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType>( |
3327 | op.getSelector().getType())) |
3328 | return op.emitOpError("must be an integer" ); |
3329 | auto cases = |
3330 | op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); |
3331 | auto count = op.getNumDest(); |
3332 | if (count == 0) |
3333 | return op.emitOpError("must have at least one successor" ); |
3334 | if (op.getNumConditions() != count) |
3335 | return op.emitOpError("number of cases and targets don't match" ); |
3336 | if (op.targetOffsetSize() != count) |
3337 | return op.emitOpError("incorrect number of successor operand groups" ); |
3338 | for (decltype(count) i = 0; i != count; ++i) { |
3339 | if (!mlir::isa<mlir::IntegerAttr, mlir::UnitAttr>(cases[i])) |
3340 | return op.emitOpError("invalid case alternative" ); |
3341 | } |
3342 | return mlir::success(); |
3343 | } |
3344 | |
3345 | static mlir::ParseResult parseIntegralSwitchTerminator( |
3346 | mlir::OpAsmParser &parser, mlir::OperationState &result, |
3347 | llvm::StringRef casesAttr, llvm::StringRef operandSegmentAttr) { |
3348 | mlir::OpAsmParser::UnresolvedOperand selector; |
3349 | mlir::Type type; |
3350 | if (fir::parseSelector(parser, result, selector, type)) |
3351 | return mlir::failure(); |
3352 | |
3353 | llvm::SmallVector<mlir::Attribute> ivalues; |
3354 | llvm::SmallVector<mlir::Block *> dests; |
3355 | llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; |
3356 | while (true) { |
3357 | mlir::Attribute ivalue; // Integer or Unit |
3358 | mlir::Block *dest; |
3359 | llvm::SmallVector<mlir::Value> destArg; |
3360 | mlir::NamedAttrList temp; |
3361 | if (parser.parseAttribute(result&: ivalue, attrName: "i" , attrs&: temp) || parser.parseComma() || |
3362 | parser.parseSuccessorAndUseList(dest, operands&: destArg)) |
3363 | return mlir::failure(); |
3364 | ivalues.push_back(Elt: ivalue); |
3365 | dests.push_back(Elt: dest); |
3366 | destArgs.push_back(Elt: destArg); |
3367 | if (!parser.parseOptionalRSquare()) |
3368 | break; |
3369 | if (parser.parseComma()) |
3370 | return mlir::failure(); |
3371 | } |
3372 | auto &bld = parser.getBuilder(); |
3373 | result.addAttribute(casesAttr, bld.getArrayAttr(ivalues)); |
3374 | llvm::SmallVector<int32_t> argOffs; |
3375 | int32_t sumArgs = 0; |
3376 | const auto count = dests.size(); |
3377 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
3378 | result.addSuccessors(successor: dests[i]); |
3379 | result.addOperands(newOperands: destArgs[i]); |
3380 | auto argSize = destArgs[i].size(); |
3381 | argOffs.push_back(Elt: argSize); |
3382 | sumArgs += argSize; |
3383 | } |
3384 | result.addAttribute(operandSegmentAttr, |
3385 | bld.getDenseI32ArrayAttr({1, 0, sumArgs})); |
3386 | result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); |
3387 | return mlir::success(); |
3388 | } |
3389 | |
3390 | template <typename OpT> |
3391 | static void printIntegralSwitchTerminator(OpT op, mlir::OpAsmPrinter &p) { |
3392 | p << ' '; |
3393 | p.printOperand(op.getSelector()); |
3394 | p << " : " << op.getSelector().getType() << " [" ; |
3395 | auto cases = |
3396 | op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); |
3397 | auto count = op.getNumConditions(); |
3398 | for (decltype(count) i = 0; i != count; ++i) { |
3399 | if (i) |
3400 | p << ", " ; |
3401 | auto &attr = cases[i]; |
3402 | if (auto intAttr = mlir::dyn_cast_or_null<mlir::IntegerAttr>(attr)) |
3403 | p << intAttr.getValue(); |
3404 | else |
3405 | p.printAttribute(attr); |
3406 | p << ", " ; |
3407 | op.printSuccessorAtIndex(p, i); |
3408 | } |
3409 | p << ']'; |
3410 | p.printOptionalAttrDict( |
3411 | attrs: op->getAttrs(), elidedAttrs: {op.getCasesAttr(), getCompareOffsetAttr(), |
3412 | getTargetOffsetAttr(), op.getOperandSegmentSizeAttr()}); |
3413 | } |
3414 | |
3415 | //===----------------------------------------------------------------------===// |
3416 | // SelectOp |
3417 | //===----------------------------------------------------------------------===// |
3418 | |
3419 | llvm::LogicalResult fir::SelectOp::verify() { |
3420 | return verifyIntegralSwitchTerminator(*this); |
3421 | } |
3422 | |
3423 | mlir::ParseResult fir::SelectOp::parse(mlir::OpAsmParser &parser, |
3424 | mlir::OperationState &result) { |
3425 | return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), |
3426 | getOperandSegmentSizeAttr()); |
3427 | } |
3428 | |
3429 | void fir::SelectOp::print(mlir::OpAsmPrinter &p) { |
3430 | printIntegralSwitchTerminator(*this, p); |
3431 | } |
3432 | |
3433 | template <typename A, typename... AdditionalArgs> |
3434 | static A getSubOperands(unsigned pos, A allArgs, mlir::DenseI32ArrayAttr ranges, |
3435 | AdditionalArgs &&...additionalArgs) { |
3436 | unsigned start = 0; |
3437 | for (unsigned i = 0; i < pos; ++i) |
3438 | start += ranges[i]; |
3439 | return allArgs.slice(start, ranges[pos], |
3440 | std::forward<AdditionalArgs>(additionalArgs)...); |
3441 | } |
3442 | |
3443 | static mlir::MutableOperandRange |
3444 | getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, |
3445 | llvm::StringRef offsetAttr) { |
3446 | mlir::Operation *owner = operands.getOwner(); |
3447 | mlir::NamedAttribute targetOffsetAttr = |
3448 | *owner->getAttrDictionary().getNamed(offsetAttr); |
3449 | return getSubOperands( |
3450 | pos, operands, |
3451 | mlir::cast<mlir::DenseI32ArrayAttr>(targetOffsetAttr.getValue()), |
3452 | mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); |
3453 | } |
3454 | |
3455 | std::optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { |
3456 | return {}; |
3457 | } |
3458 | |
3459 | std::optional<llvm::ArrayRef<mlir::Value>> |
3460 | fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
3461 | return {}; |
3462 | } |
3463 | |
3464 | mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) { |
3465 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
3466 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
3467 | } |
3468 | |
3469 | std::optional<llvm::ArrayRef<mlir::Value>> |
3470 | fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
3471 | unsigned oper) { |
3472 | auto a = |
3473 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3474 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3475 | getOperandSegmentSizeAttr()); |
3476 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3477 | } |
3478 | |
3479 | std::optional<mlir::ValueRange> |
3480 | fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { |
3481 | auto a = |
3482 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3483 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3484 | getOperandSegmentSizeAttr()); |
3485 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3486 | } |
3487 | |
3488 | unsigned fir::SelectOp::targetOffsetSize() { |
3489 | return (*this) |
3490 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
3491 | .size(); |
3492 | } |
3493 | |
3494 | //===----------------------------------------------------------------------===// |
3495 | // SelectCaseOp |
3496 | //===----------------------------------------------------------------------===// |
3497 | |
3498 | std::optional<mlir::OperandRange> |
3499 | fir::SelectCaseOp::getCompareOperands(unsigned cond) { |
3500 | auto a = |
3501 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
3502 | return {getSubOperands(cond, getCompareArgs(), a)}; |
3503 | } |
3504 | |
3505 | std::optional<llvm::ArrayRef<mlir::Value>> |
3506 | fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, |
3507 | unsigned cond) { |
3508 | auto a = |
3509 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
3510 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3511 | getOperandSegmentSizeAttr()); |
3512 | return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; |
3513 | } |
3514 | |
3515 | std::optional<mlir::ValueRange> |
3516 | fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands, |
3517 | unsigned cond) { |
3518 | auto a = |
3519 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
3520 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3521 | getOperandSegmentSizeAttr()); |
3522 | return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; |
3523 | } |
3524 | |
3525 | mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { |
3526 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
3527 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
3528 | } |
3529 | |
3530 | std::optional<llvm::ArrayRef<mlir::Value>> |
3531 | fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
3532 | unsigned oper) { |
3533 | auto a = |
3534 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3535 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3536 | getOperandSegmentSizeAttr()); |
3537 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3538 | } |
3539 | |
3540 | std::optional<mlir::ValueRange> |
3541 | fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands, |
3542 | unsigned oper) { |
3543 | auto a = |
3544 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3545 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3546 | getOperandSegmentSizeAttr()); |
3547 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3548 | } |
3549 | |
// parser for fir.select_case Op
//
// After the selector and its type, parses a case list of the form:
//   case-attr, [operand, [operand,]] ^successor(args...), ... ]
// A UnitAttr case takes no compare operand, a fir::ClosedIntervalAttr case
// takes two, and every other case takes one. All compare operands are
// resolved with the selector's type.
// NOTE(review): the opening `[` is not parsed here; presumably
// fir::parseSelector consumes it.
mlir::ParseResult fir::SelectCaseOp::parse(mlir::OpAsmParser &parser,
                                           mlir::OperationState &result) {
  mlir::OpAsmParser::UnresolvedOperand selector;
  mlir::Type type;
  if (fir::parseSelector(parser, result, selector, type))
    return mlir::failure();

  // Per-case data gathered while parsing and attached to `result` afterwards.
  llvm::SmallVector<mlir::Attribute> attrs;
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> opers;
  llvm::SmallVector<mlir::Block *> dests;
  llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs;
  // Compare-operand count of each case, and the running total of all cases.
  llvm::SmallVector<std::int32_t> argOffs;
  std::int32_t offSize = 0;
  while (true) {
    mlir::Attribute attr;
    mlir::Block *dest;
    llvm::SmallVector<mlir::Value> destArg;
    mlir::NamedAttrList temp;
    // NOTE(review): isValidCaseAttr sits in a failure chain, so a true/failed
    // result presumably means `attr` is NOT a valid case attribute. Confirm
    // against its definition.
    if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) ||
        parser.parseComma())
      return mlir::failure();
    attrs.push_back(attr);
    if (mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)) {
      // Unit case: no compare operand.
      argOffs.push_back(0);
    } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) {
      // Closed interval: two compare operands.
      mlir::OpAsmParser::UnresolvedOperand oper1;
      mlir::OpAsmParser::UnresolvedOperand oper2;
      if (parser.parseOperand(oper1) || parser.parseComma() ||
          parser.parseOperand(oper2) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper1);
      opers.push_back(oper2);
      argOffs.push_back(2);
      offSize += 2;
    } else {
      // Any other case: one compare operand.
      mlir::OpAsmParser::UnresolvedOperand oper;
      if (parser.parseOperand(oper) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper);
      argOffs.push_back(1);
      ++offSize;
    }
    if (parser.parseSuccessorAndUseList(dest, destArg))
      return mlir::failure();
    dests.push_back(dest);
    destArgs.push_back(destArg);
    if (mlir::succeeded(parser.parseOptionalRSquare()))
      break;
    if (parser.parseComma())
      return mlir::failure();
  }
  result.addAttribute(fir::SelectCaseOp::getCasesAttr(),
                      parser.getBuilder().getArrayAttr(attrs));
  // Compare operands are resolved with the selector's type.
  if (parser.resolveOperands(opers, type, result.operands))
    return mlir::failure();
  // Successors and their operands, recording each successor's operand count.
  llvm::SmallVector<int32_t> targOffs;
  int32_t toffSize = 0;
  const auto count = dests.size();
  for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
    result.addSuccessors(dests[i]);
    result.addOperands(destArgs[i]);
    auto argSize = destArgs[i].size();
    targOffs.push_back(argSize);
    toffSize += argSize;
  }
  auto &bld = parser.getBuilder();
  // Segments: selector, all compare operands, all successor operands.
  result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(),
                      bld.getDenseI32ArrayAttr({1, offSize, toffSize}));
  result.addAttribute(getCompareOffsetAttr(),
                      bld.getDenseI32ArrayAttr(argOffs));
  result.addAttribute(getTargetOffsetAttr(),
                      bld.getDenseI32ArrayAttr(targOffs));
  return mlir::success();
}
3625 | |
3626 | void fir::SelectCaseOp::print(mlir::OpAsmPrinter &p) { |
3627 | p << ' '; |
3628 | p.printOperand(getSelector()); |
3629 | p << " : " << getSelector().getType() << " [" ; |
3630 | auto cases = |
3631 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
3632 | auto count = getNumConditions(); |
3633 | for (decltype(count) i = 0; i != count; ++i) { |
3634 | if (i) |
3635 | p << ", " ; |
3636 | p << cases[i] << ", " ; |
3637 | if (!mlir::isa<mlir::UnitAttr>(cases[i])) { |
3638 | auto caseArgs = *getCompareOperands(i); |
3639 | p.printOperand(*caseArgs.begin()); |
3640 | p << ", " ; |
3641 | if (mlir::isa<fir::ClosedIntervalAttr>(cases[i])) { |
3642 | p.printOperand(*(++caseArgs.begin())); |
3643 | p << ", " ; |
3644 | } |
3645 | } |
3646 | printSuccessorAtIndex(p, i); |
3647 | } |
3648 | p << ']'; |
3649 | p.printOptionalAttrDict(getOperation()->getAttrs(), |
3650 | {getCasesAttr(), getCompareOffsetAttr(), |
3651 | getTargetOffsetAttr(), getOperandSegmentSizeAttr()}); |
3652 | } |
3653 | |
3654 | unsigned fir::SelectCaseOp::compareOffsetSize() { |
3655 | return (*this) |
3656 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()) |
3657 | .size(); |
3658 | } |
3659 | |
3660 | unsigned fir::SelectCaseOp::targetOffsetSize() { |
3661 | return (*this) |
3662 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
3663 | .size(); |
3664 | } |
3665 | |
3666 | void fir::SelectCaseOp::build(mlir::OpBuilder &builder, |
3667 | mlir::OperationState &result, |
3668 | mlir::Value selector, |
3669 | llvm::ArrayRef<mlir::Attribute> compareAttrs, |
3670 | llvm::ArrayRef<mlir::ValueRange> cmpOperands, |
3671 | llvm::ArrayRef<mlir::Block *> destinations, |
3672 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
3673 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
3674 | result.addOperands(selector); |
3675 | result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); |
3676 | llvm::SmallVector<int32_t> operOffs; |
3677 | int32_t operSize = 0; |
3678 | for (auto attr : compareAttrs) { |
3679 | if (mlir::isa<fir::ClosedIntervalAttr>(attr)) { |
3680 | operOffs.push_back(2); |
3681 | operSize += 2; |
3682 | } else if (mlir::isa<mlir::UnitAttr>(attr)) { |
3683 | operOffs.push_back(0); |
3684 | } else { |
3685 | operOffs.push_back(1); |
3686 | ++operSize; |
3687 | } |
3688 | } |
3689 | for (auto ops : cmpOperands) |
3690 | result.addOperands(ops); |
3691 | result.addAttribute(getCompareOffsetAttr(), |
3692 | builder.getDenseI32ArrayAttr(operOffs)); |
3693 | const auto count = destinations.size(); |
3694 | for (auto d : destinations) |
3695 | result.addSuccessors(d); |
3696 | const auto opCount = destOperands.size(); |
3697 | llvm::SmallVector<std::int32_t> argOffs; |
3698 | std::int32_t sumArgs = 0; |
3699 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
3700 | if (i < opCount) { |
3701 | result.addOperands(destOperands[i]); |
3702 | const auto argSz = destOperands[i].size(); |
3703 | argOffs.push_back(argSz); |
3704 | sumArgs += argSz; |
3705 | } else { |
3706 | argOffs.push_back(0); |
3707 | } |
3708 | } |
3709 | result.addAttribute(getOperandSegmentSizeAttr(), |
3710 | builder.getDenseI32ArrayAttr({1, operSize, sumArgs})); |
3711 | result.addAttribute(getTargetOffsetAttr(), |
3712 | builder.getDenseI32ArrayAttr(argOffs)); |
3713 | result.addAttributes(attributes); |
3714 | } |
3715 | |
3716 | /// This builder has a slightly simplified interface in that the list of |
3717 | /// operands need not be partitioned by the builder. Instead the operands are |
3718 | /// partitioned here, before being passed to the default builder. This |
3719 | /// partitioning is unchecked, so can go awry on bad input. |
3720 | void fir::SelectCaseOp::build(mlir::OpBuilder &builder, |
3721 | mlir::OperationState &result, |
3722 | mlir::Value selector, |
3723 | llvm::ArrayRef<mlir::Attribute> compareAttrs, |
3724 | llvm::ArrayRef<mlir::Value> cmpOpList, |
3725 | llvm::ArrayRef<mlir::Block *> destinations, |
3726 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
3727 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
3728 | llvm::SmallVector<mlir::ValueRange> cmpOpers; |
3729 | auto iter = cmpOpList.begin(); |
3730 | for (auto &attr : compareAttrs) { |
3731 | if (mlir::isa<fir::ClosedIntervalAttr>(attr)) { |
3732 | cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); |
3733 | iter += 2; |
3734 | } else if (mlir::isa<mlir::UnitAttr>(attr)) { |
3735 | cmpOpers.push_back(mlir::ValueRange{}); |
3736 | } else { |
3737 | cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); |
3738 | ++iter; |
3739 | } |
3740 | } |
3741 | build(builder, result, selector, compareAttrs, cmpOpers, destinations, |
3742 | destOperands, attributes); |
3743 | } |
3744 | |
3745 | llvm::LogicalResult fir::SelectCaseOp::verify() { |
3746 | if (!mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType, |
3747 | fir::LogicalType, fir::CharacterType>(getSelector().getType())) |
3748 | return emitOpError("must be an integer, character, or logical" ); |
3749 | auto cases = |
3750 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
3751 | auto count = getNumDest(); |
3752 | if (count == 0) |
3753 | return emitOpError("must have at least one successor" ); |
3754 | if (getNumConditions() != count) |
3755 | return emitOpError("number of conditions and successors don't match" ); |
3756 | if (compareOffsetSize() != count) |
3757 | return emitOpError("incorrect number of compare operand groups" ); |
3758 | if (targetOffsetSize() != count) |
3759 | return emitOpError("incorrect number of successor operand groups" ); |
3760 | for (decltype(count) i = 0; i != count; ++i) { |
3761 | auto &attr = cases[i]; |
3762 | if (!(mlir::isa<fir::PointIntervalAttr>(attr) || |
3763 | mlir::isa<fir::LowerBoundAttr>(attr) || |
3764 | mlir::isa<fir::UpperBoundAttr>(attr) || |
3765 | mlir::isa<fir::ClosedIntervalAttr>(attr) || |
3766 | mlir::isa<mlir::UnitAttr>(attr))) |
3767 | return emitOpError("incorrect select case attribute type" ); |
3768 | } |
3769 | return mlir::success(); |
3770 | } |
3771 | |
3772 | //===----------------------------------------------------------------------===// |
3773 | // SelectRankOp |
3774 | //===----------------------------------------------------------------------===// |
3775 | |
3776 | llvm::LogicalResult fir::SelectRankOp::verify() { |
3777 | return verifyIntegralSwitchTerminator(*this); |
3778 | } |
3779 | |
3780 | mlir::ParseResult fir::SelectRankOp::parse(mlir::OpAsmParser &parser, |
3781 | mlir::OperationState &result) { |
3782 | return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), |
3783 | getOperandSegmentSizeAttr()); |
3784 | } |
3785 | |
3786 | void fir::SelectRankOp::print(mlir::OpAsmPrinter &p) { |
3787 | printIntegralSwitchTerminator(*this, p); |
3788 | } |
3789 | |
3790 | std::optional<mlir::OperandRange> |
3791 | fir::SelectRankOp::getCompareOperands(unsigned) { |
3792 | return {}; |
3793 | } |
3794 | |
3795 | std::optional<llvm::ArrayRef<mlir::Value>> |
3796 | fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
3797 | return {}; |
3798 | } |
3799 | |
3800 | mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) { |
3801 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
3802 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
3803 | } |
3804 | |
3805 | std::optional<llvm::ArrayRef<mlir::Value>> |
3806 | fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
3807 | unsigned oper) { |
3808 | auto a = |
3809 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3810 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3811 | getOperandSegmentSizeAttr()); |
3812 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3813 | } |
3814 | |
3815 | std::optional<mlir::ValueRange> |
3816 | fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands, |
3817 | unsigned oper) { |
3818 | auto a = |
3819 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3820 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3821 | getOperandSegmentSizeAttr()); |
3822 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3823 | } |
3824 | |
3825 | unsigned fir::SelectRankOp::targetOffsetSize() { |
3826 | return (*this) |
3827 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
3828 | .size(); |
3829 | } |
3830 | |
3831 | //===----------------------------------------------------------------------===// |
3832 | // SelectTypeOp |
3833 | //===----------------------------------------------------------------------===// |
3834 | |
3835 | std::optional<mlir::OperandRange> |
3836 | fir::SelectTypeOp::getCompareOperands(unsigned) { |
3837 | return {}; |
3838 | } |
3839 | |
3840 | std::optional<llvm::ArrayRef<mlir::Value>> |
3841 | fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
3842 | return {}; |
3843 | } |
3844 | |
3845 | mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { |
3846 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
3847 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
3848 | } |
3849 | |
3850 | std::optional<llvm::ArrayRef<mlir::Value>> |
3851 | fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
3852 | unsigned oper) { |
3853 | auto a = |
3854 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3855 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3856 | getOperandSegmentSizeAttr()); |
3857 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3858 | } |
3859 | |
3860 | std::optional<mlir::ValueRange> |
3861 | fir::SelectTypeOp::getSuccessorOperands(mlir::ValueRange operands, |
3862 | unsigned oper) { |
3863 | auto a = |
3864 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3865 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3866 | getOperandSegmentSizeAttr()); |
3867 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3868 | } |
3869 | |
3870 | mlir::ParseResult fir::SelectTypeOp::parse(mlir::OpAsmParser &parser, |
3871 | mlir::OperationState &result) { |
3872 | mlir::OpAsmParser::UnresolvedOperand selector; |
3873 | mlir::Type type; |
3874 | if (fir::parseSelector(parser, result, selector, type)) |
3875 | return mlir::failure(); |
3876 | |
3877 | llvm::SmallVector<mlir::Attribute> attrs; |
3878 | llvm::SmallVector<mlir::Block *> dests; |
3879 | llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; |
3880 | while (true) { |
3881 | mlir::Attribute attr; |
3882 | mlir::Block *dest; |
3883 | llvm::SmallVector<mlir::Value> destArg; |
3884 | mlir::NamedAttrList temp; |
3885 | if (parser.parseAttribute(attr, "a" , temp) || parser.parseComma() || |
3886 | parser.parseSuccessorAndUseList(dest, destArg)) |
3887 | return mlir::failure(); |
3888 | attrs.push_back(attr); |
3889 | dests.push_back(dest); |
3890 | destArgs.push_back(destArg); |
3891 | if (mlir::succeeded(parser.parseOptionalRSquare())) |
3892 | break; |
3893 | if (parser.parseComma()) |
3894 | return mlir::failure(); |
3895 | } |
3896 | auto &bld = parser.getBuilder(); |
3897 | result.addAttribute(fir::SelectTypeOp::getCasesAttr(), |
3898 | bld.getArrayAttr(attrs)); |
3899 | llvm::SmallVector<int32_t> argOffs; |
3900 | int32_t offSize = 0; |
3901 | const auto count = dests.size(); |
3902 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
3903 | result.addSuccessors(dests[i]); |
3904 | result.addOperands(destArgs[i]); |
3905 | auto argSize = destArgs[i].size(); |
3906 | argOffs.push_back(argSize); |
3907 | offSize += argSize; |
3908 | } |
3909 | result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), |
3910 | bld.getDenseI32ArrayAttr({1, 0, offSize})); |
3911 | result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); |
3912 | return mlir::success(); |
3913 | } |
3914 | |
3915 | unsigned fir::SelectTypeOp::targetOffsetSize() { |
3916 | return (*this) |
3917 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
3918 | .size(); |
3919 | } |
3920 | |
3921 | void fir::SelectTypeOp::print(mlir::OpAsmPrinter &p) { |
3922 | p << ' '; |
3923 | p.printOperand(getSelector()); |
3924 | p << " : " << getSelector().getType() << " [" ; |
3925 | auto cases = |
3926 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
3927 | auto count = getNumConditions(); |
3928 | for (decltype(count) i = 0; i != count; ++i) { |
3929 | if (i) |
3930 | p << ", " ; |
3931 | p << cases[i] << ", " ; |
3932 | printSuccessorAtIndex(p, i); |
3933 | } |
3934 | p << ']'; |
3935 | p.printOptionalAttrDict(getOperation()->getAttrs(), |
3936 | {getCasesAttr(), getCompareOffsetAttr(), |
3937 | getTargetOffsetAttr(), |
3938 | fir::SelectTypeOp::getOperandSegmentSizeAttr()}); |
3939 | } |
3940 | |
3941 | llvm::LogicalResult fir::SelectTypeOp::verify() { |
3942 | if (!mlir::isa<fir::BaseBoxType>(getSelector().getType())) |
3943 | return emitOpError("must be a fir.class or fir.box type" ); |
3944 | if (auto boxType = mlir::dyn_cast<fir::BoxType>(getSelector().getType())) |
3945 | if (!mlir::isa<mlir::NoneType>(boxType.getEleTy())) |
3946 | return emitOpError("selector must be polymorphic" ); |
3947 | auto typeGuardAttr = getCases(); |
3948 | for (unsigned idx = 0; idx < typeGuardAttr.size(); ++idx) |
3949 | if (mlir::isa<mlir::UnitAttr>(typeGuardAttr[idx]) && |
3950 | idx != typeGuardAttr.size() - 1) |
3951 | return emitOpError("default must be the last attribute" ); |
3952 | auto count = getNumDest(); |
3953 | if (count == 0) |
3954 | return emitOpError("must have at least one successor" ); |
3955 | if (getNumConditions() != count) |
3956 | return emitOpError("number of conditions and successors don't match" ); |
3957 | if (targetOffsetSize() != count) |
3958 | return emitOpError("incorrect number of successor operand groups" ); |
3959 | for (unsigned i = 0; i != count; ++i) { |
3960 | if (!mlir::isa<fir::ExactTypeAttr, fir::SubclassAttr, mlir::UnitAttr>( |
3961 | typeGuardAttr[i])) |
3962 | return emitOpError("invalid type-case alternative" ); |
3963 | } |
3964 | return mlir::success(); |
3965 | } |
3966 | |
3967 | void fir::SelectTypeOp::build(mlir::OpBuilder &builder, |
3968 | mlir::OperationState &result, |
3969 | mlir::Value selector, |
3970 | llvm::ArrayRef<mlir::Attribute> typeOperands, |
3971 | llvm::ArrayRef<mlir::Block *> destinations, |
3972 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
3973 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
3974 | result.addOperands(selector); |
3975 | result.addAttribute(getCasesAttr(), builder.getArrayAttr(typeOperands)); |
3976 | const auto count = destinations.size(); |
3977 | for (mlir::Block *dest : destinations) |
3978 | result.addSuccessors(dest); |
3979 | const auto opCount = destOperands.size(); |
3980 | llvm::SmallVector<int32_t> argOffs; |
3981 | int32_t sumArgs = 0; |
3982 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
3983 | if (i < opCount) { |
3984 | result.addOperands(destOperands[i]); |
3985 | const auto argSz = destOperands[i].size(); |
3986 | argOffs.push_back(argSz); |
3987 | sumArgs += argSz; |
3988 | } else { |
3989 | argOffs.push_back(0); |
3990 | } |
3991 | } |
3992 | result.addAttribute(getOperandSegmentSizeAttr(), |
3993 | builder.getDenseI32ArrayAttr({1, 0, sumArgs})); |
3994 | result.addAttribute(getTargetOffsetAttr(), |
3995 | builder.getDenseI32ArrayAttr(argOffs)); |
3996 | result.addAttributes(attributes); |
3997 | } |
3998 | |
3999 | //===----------------------------------------------------------------------===// |
4000 | // ShapeOp |
4001 | //===----------------------------------------------------------------------===// |
4002 | |
4003 | llvm::LogicalResult fir::ShapeOp::verify() { |
4004 | auto size = getExtents().size(); |
4005 | auto shapeTy = mlir::dyn_cast<fir::ShapeType>(getType()); |
4006 | assert(shapeTy && "must be a shape type" ); |
4007 | if (shapeTy.getRank() != size) |
4008 | return emitOpError("shape type rank mismatch" ); |
4009 | return mlir::success(); |
4010 | } |
4011 | |
4012 | void fir::ShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
4013 | mlir::ValueRange extents) { |
4014 | auto type = fir::ShapeType::get(builder.getContext(), extents.size()); |
4015 | build(builder, result, type, extents); |
4016 | } |
4017 | |
4018 | //===----------------------------------------------------------------------===// |
4019 | // ShapeShiftOp |
4020 | //===----------------------------------------------------------------------===// |
4021 | |
4022 | llvm::LogicalResult fir::ShapeShiftOp::verify() { |
4023 | auto size = getPairs().size(); |
4024 | if (size < 2 || size > 16 * 2) |
4025 | return emitOpError("incorrect number of args" ); |
4026 | if (size % 2 != 0) |
4027 | return emitOpError("requires a multiple of 2 args" ); |
4028 | auto shapeTy = mlir::dyn_cast<fir::ShapeShiftType>(getType()); |
4029 | assert(shapeTy && "must be a shape shift type" ); |
4030 | if (shapeTy.getRank() * 2 != size) |
4031 | return emitOpError("shape type rank mismatch" ); |
4032 | return mlir::success(); |
4033 | } |
4034 | |
4035 | //===----------------------------------------------------------------------===// |
4036 | // ShiftOp |
4037 | //===----------------------------------------------------------------------===// |
4038 | |
4039 | llvm::LogicalResult fir::ShiftOp::verify() { |
4040 | auto size = getOrigins().size(); |
4041 | auto shiftTy = mlir::dyn_cast<fir::ShiftType>(getType()); |
4042 | assert(shiftTy && "must be a shift type" ); |
4043 | if (shiftTy.getRank() != size) |
4044 | return emitOpError("shift type rank mismatch" ); |
4045 | return mlir::success(); |
4046 | } |
4047 | |
4048 | //===----------------------------------------------------------------------===// |
4049 | // SliceOp |
4050 | //===----------------------------------------------------------------------===// |
4051 | |
4052 | void fir::SliceOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
4053 | mlir::ValueRange trips, mlir::ValueRange path, |
4054 | mlir::ValueRange substr) { |
4055 | const auto rank = trips.size() / 3; |
4056 | auto sliceTy = fir::SliceType::get(builder.getContext(), rank); |
4057 | build(builder, result, sliceTy, trips, path, substr); |
4058 | } |
4059 | |
4060 | /// Return the output rank of a slice op. The output rank must be between 1 and |
4061 | /// the rank of the array being sliced (inclusive). |
4062 | unsigned fir::SliceOp::getOutputRank(mlir::ValueRange triples) { |
4063 | unsigned rank = 0; |
4064 | if (!triples.empty()) { |
4065 | for (unsigned i = 1, end = triples.size(); i < end; i += 3) { |
4066 | auto *op = triples[i].getDefiningOp(); |
4067 | if (!mlir::isa_and_nonnull<fir::UndefOp>(op)) |
4068 | ++rank; |
4069 | } |
4070 | assert(rank > 0); |
4071 | } |
4072 | return rank; |
4073 | } |
4074 | |
4075 | llvm::LogicalResult fir::SliceOp::verify() { |
4076 | auto size = getTriples().size(); |
4077 | if (size < 3 || size > 16 * 3) |
4078 | return emitOpError("incorrect number of args for triple" ); |
4079 | if (size % 3 != 0) |
4080 | return emitOpError("requires a multiple of 3 args" ); |
4081 | auto sliceTy = mlir::dyn_cast<fir::SliceType>(getType()); |
4082 | assert(sliceTy && "must be a slice type" ); |
4083 | if (sliceTy.getRank() * 3 != size) |
4084 | return emitOpError("slice type rank mismatch" ); |
4085 | return mlir::success(); |
4086 | } |
4087 | |
4088 | //===----------------------------------------------------------------------===// |
4089 | // StoreOp |
4090 | //===----------------------------------------------------------------------===// |
4091 | |
4092 | mlir::Type fir::StoreOp::elementType(mlir::Type refType) { |
4093 | return fir::dyn_cast_ptrEleTy(refType); |
4094 | } |
4095 | |
4096 | mlir::ParseResult fir::StoreOp::parse(mlir::OpAsmParser &parser, |
4097 | mlir::OperationState &result) { |
4098 | mlir::Type type; |
4099 | mlir::OpAsmParser::UnresolvedOperand oper; |
4100 | mlir::OpAsmParser::UnresolvedOperand store; |
4101 | if (parser.parseOperand(oper) || parser.parseKeyword("to" ) || |
4102 | parser.parseOperand(store) || |
4103 | parser.parseOptionalAttrDict(result.attributes) || |
4104 | parser.parseColonType(type) || |
4105 | parser.resolveOperand(oper, fir::StoreOp::elementType(type), |
4106 | result.operands) || |
4107 | parser.resolveOperand(store, type, result.operands)) |
4108 | return mlir::failure(); |
4109 | return mlir::success(); |
4110 | } |
4111 | |
4112 | void fir::StoreOp::print(mlir::OpAsmPrinter &p) { |
4113 | p << ' '; |
4114 | p.printOperand(getValue()); |
4115 | p << " to " ; |
4116 | p.printOperand(getMemref()); |
4117 | p.printOptionalAttrDict(getOperation()->getAttrs(), {}); |
4118 | p << " : " << getMemref().getType(); |
4119 | } |
4120 | |
4121 | llvm::LogicalResult fir::StoreOp::verify() { |
4122 | if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType())) |
4123 | return emitOpError("store value type must match memory reference type" ); |
4124 | return mlir::success(); |
4125 | } |
4126 | |
4127 | void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
4128 | mlir::Value value, mlir::Value memref) { |
4129 | build(builder, result, value, memref, {}); |
4130 | } |
4131 | |
4132 | void fir::StoreOp::getEffects( |
4133 | llvm::SmallVectorImpl< |
4134 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
4135 | &effects) { |
4136 | effects.emplace_back(mlir::MemoryEffects::Write::get(), &getMemrefMutable(), |
4137 | mlir::SideEffects::DefaultResource::get()); |
4138 | addVolatileMemoryEffects({getMemref().getType()}, effects); |
4139 | } |
4140 | |
4141 | //===----------------------------------------------------------------------===// |
4142 | // CopyOp |
4143 | //===----------------------------------------------------------------------===// |
4144 | |
4145 | void fir::CopyOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
4146 | mlir::Value source, mlir::Value destination, |
4147 | bool noOverlap) { |
4148 | mlir::UnitAttr noOverlapAttr = |
4149 | noOverlap ? builder.getUnitAttr() : mlir::UnitAttr{}; |
4150 | build(builder, result, source, destination, noOverlapAttr); |
4151 | } |
4152 | |
4153 | llvm::LogicalResult fir::CopyOp::verify() { |
4154 | mlir::Type sourceType = fir::unwrapRefType(getSource().getType()); |
4155 | mlir::Type destinationType = fir::unwrapRefType(getDestination().getType()); |
4156 | if (sourceType != destinationType) |
4157 | return emitOpError("source and destination must have the same value type" ); |
4158 | return mlir::success(); |
4159 | } |
4160 | |
4161 | void fir::CopyOp::getEffects( |
4162 | llvm::SmallVectorImpl< |
4163 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
4164 | &effects) { |
4165 | effects.emplace_back(mlir::MemoryEffects::Read::get(), &getSourceMutable(), |
4166 | mlir::SideEffects::DefaultResource::get()); |
4167 | effects.emplace_back(mlir::MemoryEffects::Write::get(), |
4168 | &getDestinationMutable(), |
4169 | mlir::SideEffects::DefaultResource::get()); |
4170 | addVolatileMemoryEffects({getDestination().getType(), getSource().getType()}, |
4171 | effects); |
4172 | } |
4173 | |
4174 | //===----------------------------------------------------------------------===// |
4175 | // StringLitOp |
4176 | //===----------------------------------------------------------------------===// |
4177 | |
4178 | inline fir::CharacterType::KindTy stringLitOpGetKind(fir::StringLitOp op) { |
4179 | auto eleTy = mlir::cast<fir::SequenceType>(op.getType()).getElementType(); |
4180 | return mlir::cast<fir::CharacterType>(eleTy).getFKind(); |
4181 | } |
4182 | |
4183 | bool fir::StringLitOp::isWideValue() { return stringLitOpGetKind(*this) != 1; } |
4184 | |
4185 | static mlir::NamedAttribute |
4186 | mkNamedIntegerAttr(mlir::OpBuilder &builder, llvm::StringRef name, int64_t v) { |
4187 | assert(v > 0); |
4188 | return builder.getNamedAttr( |
4189 | name, builder.getIntegerAttr(builder.getIntegerType(64), v)); |
4190 | } |
4191 | |
4192 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
4193 | mlir::OperationState &result, |
4194 | fir::CharacterType inType, llvm::StringRef val, |
4195 | std::optional<int64_t> len) { |
4196 | auto valAttr = builder.getNamedAttr(value(), builder.getStringAttr(val)); |
4197 | int64_t length = len ? *len : inType.getLen(); |
4198 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
4199 | result.addAttributes({valAttr, lenAttr}); |
4200 | result.addTypes(inType); |
4201 | } |
4202 | |
4203 | template <typename C> |
4204 | static mlir::ArrayAttr convertToArrayAttr(mlir::OpBuilder &builder, |
4205 | llvm::ArrayRef<C> xlist) { |
4206 | llvm::SmallVector<mlir::Attribute> attrs; |
4207 | auto ty = builder.getIntegerType(8 * sizeof(C)); |
4208 | for (auto ch : xlist) |
4209 | attrs.push_back(Elt: builder.getIntegerAttr(ty, ch)); |
4210 | return builder.getArrayAttr(attrs); |
4211 | } |
4212 | |
4213 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
4214 | mlir::OperationState &result, |
4215 | fir::CharacterType inType, |
4216 | llvm::ArrayRef<char> vlist, |
4217 | std::optional<std::int64_t> len) { |
4218 | auto valAttr = |
4219 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
4220 | std::int64_t length = len ? *len : inType.getLen(); |
4221 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
4222 | result.addAttributes({valAttr, lenAttr}); |
4223 | result.addTypes(inType); |
4224 | } |
4225 | |
4226 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
4227 | mlir::OperationState &result, |
4228 | fir::CharacterType inType, |
4229 | llvm::ArrayRef<char16_t> vlist, |
4230 | std::optional<std::int64_t> len) { |
4231 | auto valAttr = |
4232 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
4233 | std::int64_t length = len ? *len : inType.getLen(); |
4234 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
4235 | result.addAttributes({valAttr, lenAttr}); |
4236 | result.addTypes(inType); |
4237 | } |
4238 | |
4239 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
4240 | mlir::OperationState &result, |
4241 | fir::CharacterType inType, |
4242 | llvm::ArrayRef<char32_t> vlist, |
4243 | std::optional<std::int64_t> len) { |
4244 | auto valAttr = |
4245 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
4246 | std::int64_t length = len ? *len : inType.getLen(); |
4247 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
4248 | result.addAttributes({valAttr, lenAttr}); |
4249 | result.addTypes(inType); |
4250 | } |
4251 | |
4252 | mlir::ParseResult fir::StringLitOp::parse(mlir::OpAsmParser &parser, |
4253 | mlir::OperationState &result) { |
4254 | auto &builder = parser.getBuilder(); |
4255 | mlir::Attribute val; |
4256 | mlir::NamedAttrList attrs; |
4257 | llvm::SMLoc trailingTypeLoc; |
4258 | if (parser.parseAttribute(val, "fake" , attrs)) |
4259 | return mlir::failure(); |
4260 | if (auto v = mlir::dyn_cast<mlir::StringAttr>(val)) |
4261 | result.attributes.push_back( |
4262 | builder.getNamedAttr(fir::StringLitOp::value(), v)); |
4263 | else if (auto v = mlir::dyn_cast<mlir::DenseElementsAttr>(val)) |
4264 | result.attributes.push_back( |
4265 | builder.getNamedAttr(fir::StringLitOp::xlist(), v)); |
4266 | else if (auto v = mlir::dyn_cast<mlir::ArrayAttr>(val)) |
4267 | result.attributes.push_back( |
4268 | builder.getNamedAttr(fir::StringLitOp::xlist(), v)); |
4269 | else |
4270 | return parser.emitError(parser.getCurrentLocation(), |
4271 | "found an invalid constant" ); |
4272 | mlir::IntegerAttr sz; |
4273 | mlir::Type type; |
4274 | if (parser.parseLParen() || |
4275 | parser.parseAttribute(sz, fir::StringLitOp::size(), result.attributes) || |
4276 | parser.parseRParen() || parser.getCurrentLocation(&trailingTypeLoc) || |
4277 | parser.parseColonType(type)) |
4278 | return mlir::failure(); |
4279 | auto charTy = mlir::dyn_cast<fir::CharacterType>(type); |
4280 | if (!charTy) |
4281 | return parser.emitError(trailingTypeLoc, "must have character type" ); |
4282 | type = fir::CharacterType::get(builder.getContext(), charTy.getFKind(), |
4283 | sz.getInt()); |
4284 | if (!type || parser.addTypesToList(type, result.types)) |
4285 | return mlir::failure(); |
4286 | return mlir::success(); |
4287 | } |
4288 | |
4289 | void fir::StringLitOp::print(mlir::OpAsmPrinter &p) { |
4290 | p << ' ' << getValue() << '('; |
4291 | p << mlir::cast<mlir::IntegerAttr>(getSize()).getValue() << ") : " ; |
4292 | p.printType(getType()); |
4293 | } |
4294 | |
4295 | llvm::LogicalResult fir::StringLitOp::verify() { |
4296 | if (mlir::cast<mlir::IntegerAttr>(getSize()).getValue().isNegative()) |
4297 | return emitOpError("size must be non-negative" ); |
4298 | if (auto xl = getOperation()->getAttr(fir::StringLitOp::xlist())) { |
4299 | if (auto xList = mlir::dyn_cast<mlir::ArrayAttr>(xl)) { |
4300 | for (auto a : xList) |
4301 | if (!mlir::isa<mlir::IntegerAttr>(a)) |
4302 | return emitOpError("values in initializer must be integers" ); |
4303 | } else if (mlir::isa<mlir::DenseElementsAttr>(xl)) { |
4304 | // do nothing |
4305 | } else { |
4306 | return emitOpError("has unexpected attribute" ); |
4307 | } |
4308 | } |
4309 | return mlir::success(); |
4310 | } |
4311 | |
4312 | //===----------------------------------------------------------------------===// |
4313 | // UnboxProcOp |
4314 | //===----------------------------------------------------------------------===// |
4315 | |
4316 | llvm::LogicalResult fir::UnboxProcOp::verify() { |
4317 | if (auto eleTy = fir::dyn_cast_ptrEleTy(getRefTuple().getType())) |
4318 | if (mlir::isa<mlir::TupleType>(eleTy)) |
4319 | return mlir::success(); |
4320 | return emitOpError("second output argument has bad type" ); |
4321 | } |
4322 | |
4323 | //===----------------------------------------------------------------------===// |
4324 | // IfOp |
4325 | //===----------------------------------------------------------------------===// |
4326 | |
4327 | void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
4328 | mlir::Value cond, bool withElseRegion) { |
4329 | build(builder, result, std::nullopt, cond, withElseRegion); |
4330 | } |
4331 | |
4332 | void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
4333 | mlir::TypeRange resultTypes, mlir::Value cond, |
4334 | bool withElseRegion) { |
4335 | result.addOperands(cond); |
4336 | result.addTypes(resultTypes); |
4337 | |
4338 | mlir::Region *thenRegion = result.addRegion(); |
4339 | thenRegion->push_back(new mlir::Block()); |
4340 | if (resultTypes.empty()) |
4341 | IfOp::ensureTerminator(*thenRegion, builder, result.location); |
4342 | |
4343 | mlir::Region *elseRegion = result.addRegion(); |
4344 | if (withElseRegion) { |
4345 | elseRegion->push_back(new mlir::Block()); |
4346 | if (resultTypes.empty()) |
4347 | IfOp::ensureTerminator(*elseRegion, builder, result.location); |
4348 | } |
4349 | } |
4350 | |
4351 | // These 3 functions copied from scf.if implementation. |
4352 | |
4353 | /// Given the region at `index`, or the parent operation if `index` is None, |
4354 | /// return the successor regions. These are the regions that may be selected |
4355 | /// during the flow of control. |
4356 | void fir::IfOp::getSuccessorRegions( |
4357 | mlir::RegionBranchPoint point, |
4358 | llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { |
4359 | // The `then` and the `else` region branch back to the parent operation. |
4360 | if (!point.isParent()) { |
4361 | regions.push_back(mlir::RegionSuccessor(getResults())); |
4362 | return; |
4363 | } |
4364 | |
4365 | // Don't consider the else region if it is empty. |
4366 | regions.push_back(mlir::RegionSuccessor(&getThenRegion())); |
4367 | |
4368 | // Don't consider the else region if it is empty. |
4369 | mlir::Region *elseRegion = &this->getElseRegion(); |
4370 | if (elseRegion->empty()) |
4371 | regions.push_back(mlir::RegionSuccessor()); |
4372 | else |
4373 | regions.push_back(mlir::RegionSuccessor(elseRegion)); |
4374 | } |
4375 | |
4376 | void fir::IfOp::getEntrySuccessorRegions( |
4377 | llvm::ArrayRef<mlir::Attribute> operands, |
4378 | llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { |
4379 | FoldAdaptor adaptor(operands); |
4380 | auto boolAttr = |
4381 | mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition()); |
4382 | if (!boolAttr || boolAttr.getValue()) |
4383 | regions.emplace_back(&getThenRegion()); |
4384 | |
4385 | // If the else region is empty, execution continues after the parent op. |
4386 | if (!boolAttr || !boolAttr.getValue()) { |
4387 | if (!getElseRegion().empty()) |
4388 | regions.emplace_back(&getElseRegion()); |
4389 | else |
4390 | regions.emplace_back(getResults()); |
4391 | } |
4392 | } |
4393 | |
4394 | void fir::IfOp::getRegionInvocationBounds( |
4395 | llvm::ArrayRef<mlir::Attribute> operands, |
4396 | llvm::SmallVectorImpl<mlir::InvocationBounds> &invocationBounds) { |
4397 | if (auto cond = mlir::dyn_cast_or_null<mlir::BoolAttr>(operands[0])) { |
4398 | // If the condition is known, then one region is known to be executed once |
4399 | // and the other zero times. |
4400 | invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); |
4401 | invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); |
4402 | } else { |
4403 | // Non-constant condition. Each region may be executed 0 or 1 times. |
4404 | invocationBounds.assign(2, {0, 1}); |
4405 | } |
4406 | } |
4407 | |
4408 | mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser, |
4409 | mlir::OperationState &result) { |
4410 | result.regions.reserve(2); |
4411 | mlir::Region *thenRegion = result.addRegion(); |
4412 | mlir::Region *elseRegion = result.addRegion(); |
4413 | |
4414 | auto &builder = parser.getBuilder(); |
4415 | mlir::OpAsmParser::UnresolvedOperand cond; |
4416 | mlir::Type i1Type = builder.getIntegerType(1); |
4417 | if (parser.parseOperand(cond) || |
4418 | parser.resolveOperand(cond, i1Type, result.operands)) |
4419 | return mlir::failure(); |
4420 | |
4421 | if (parser.parseOptionalArrowTypeList(result.types)) |
4422 | return mlir::failure(); |
4423 | |
4424 | if (parser.parseRegion(*thenRegion, {}, {})) |
4425 | return mlir::failure(); |
4426 | fir::IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), |
4427 | result.location); |
4428 | |
4429 | if (mlir::succeeded(parser.parseOptionalKeyword("else" ))) { |
4430 | if (parser.parseRegion(*elseRegion, {}, {})) |
4431 | return mlir::failure(); |
4432 | fir::IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), |
4433 | result.location); |
4434 | } |
4435 | |
4436 | // Parse the optional attribute list. |
4437 | if (parser.parseOptionalAttrDict(result.attributes)) |
4438 | return mlir::failure(); |
4439 | return mlir::success(); |
4440 | } |
4441 | |
4442 | llvm::LogicalResult fir::IfOp::verify() { |
4443 | if (getNumResults() != 0 && getElseRegion().empty()) |
4444 | return emitOpError("must have an else block if defining values" ); |
4445 | |
4446 | return mlir::success(); |
4447 | } |
4448 | |
4449 | void fir::IfOp::print(mlir::OpAsmPrinter &p) { |
4450 | bool printBlockTerminators = false; |
4451 | p << ' ' << getCondition(); |
4452 | if (!getResults().empty()) { |
4453 | p << " -> (" << getResultTypes() << ')'; |
4454 | printBlockTerminators = true; |
4455 | } |
4456 | p << ' '; |
4457 | p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, |
4458 | printBlockTerminators); |
4459 | |
4460 | // Print the 'else' regions if it exists and has a block. |
4461 | auto &otherReg = getElseRegion(); |
4462 | if (!otherReg.empty()) { |
4463 | p << " else " ; |
4464 | p.printRegion(otherReg, /*printEntryBlockArgs=*/false, |
4465 | printBlockTerminators); |
4466 | } |
4467 | p.printOptionalAttrDict((*this)->getAttrs()); |
4468 | } |
4469 | |
4470 | void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results, |
4471 | unsigned resultNum) { |
4472 | auto *term = getThenRegion().front().getTerminator(); |
4473 | if (resultNum < term->getNumOperands()) |
4474 | results.push_back(term->getOperand(resultNum)); |
4475 | term = getElseRegion().front().getTerminator(); |
4476 | if (resultNum < term->getNumOperands()) |
4477 | results.push_back(term->getOperand(resultNum)); |
4478 | } |
4479 | |
4480 | //===----------------------------------------------------------------------===// |
4481 | // BoxOffsetOp |
4482 | //===----------------------------------------------------------------------===// |
4483 | |
4484 | llvm::LogicalResult fir::BoxOffsetOp::verify() { |
4485 | auto boxType = mlir::dyn_cast_or_null<fir::BaseBoxType>( |
4486 | fir::dyn_cast_ptrEleTy(getBoxRef().getType())); |
4487 | mlir::Type boxCharType; |
4488 | if (!boxType) { |
4489 | boxCharType = mlir::dyn_cast_or_null<fir::BoxCharType>( |
4490 | fir::dyn_cast_ptrEleTy(getBoxRef().getType())); |
4491 | if (!boxCharType) |
4492 | return emitOpError("box_ref operand must have !fir.ref<!fir.box<T>> or " |
4493 | "!fir.ref<!fir.boxchar<k>> type" ); |
4494 | if (getField() == fir::BoxFieldAttr::derived_type) |
4495 | return emitOpError("cannot address derived_type field of a fir.boxchar" ); |
4496 | } |
4497 | if (getField() != fir::BoxFieldAttr::base_addr && |
4498 | getField() != fir::BoxFieldAttr::derived_type) |
4499 | return emitOpError("cannot address provided field" ); |
4500 | if (getField() == fir::BoxFieldAttr::derived_type) { |
4501 | if (!fir::boxHasAddendum(boxType)) |
4502 | return emitOpError("can only address derived_type field of derived type " |
4503 | "or unlimited polymorphic fir.box" ); |
4504 | } |
4505 | return mlir::success(); |
4506 | } |
4507 | |
4508 | void fir::BoxOffsetOp::build(mlir::OpBuilder &builder, |
4509 | mlir::OperationState &result, mlir::Value boxRef, |
4510 | fir::BoxFieldAttr field) { |
4511 | mlir::Type valueType = |
4512 | fir::unwrapPassByRefType(fir::unwrapRefType(boxRef.getType())); |
4513 | mlir::Type resultType = valueType; |
4514 | if (field == fir::BoxFieldAttr::base_addr) |
4515 | resultType = fir::LLVMPointerType::get(fir::ReferenceType::get(valueType)); |
4516 | else if (field == fir::BoxFieldAttr::derived_type) |
4517 | resultType = fir::LLVMPointerType::get( |
4518 | fir::TypeDescType::get(fir::unwrapSequenceType(valueType))); |
4519 | build(builder, result, {resultType}, boxRef, field); |
4520 | } |
4521 | |
4522 | //===----------------------------------------------------------------------===// |
4523 | |
4524 | mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { |
4525 | if (mlir::isa<mlir::UnitAttr, fir::ClosedIntervalAttr, fir::PointIntervalAttr, |
4526 | fir::LowerBoundAttr, fir::UpperBoundAttr>(attr)) |
4527 | return mlir::success(); |
4528 | return mlir::failure(); |
4529 | } |
4530 | |
4531 | unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, |
4532 | unsigned dest) { |
4533 | unsigned o = 0; |
4534 | for (unsigned i = 0; i < dest; ++i) { |
4535 | auto &attr = cases[i]; |
4536 | if (!mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)) { |
4537 | ++o; |
4538 | if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) |
4539 | ++o; |
4540 | } |
4541 | } |
4542 | return o; |
4543 | } |
4544 | |
4545 | mlir::ParseResult |
4546 | fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result, |
4547 | mlir::OpAsmParser::UnresolvedOperand &selector, |
4548 | mlir::Type &type) { |
4549 | if (parser.parseOperand(selector) || parser.parseColonType(type) || |
4550 | parser.resolveOperand(selector, type, result.operands) || |
4551 | parser.parseLSquare()) |
4552 | return mlir::failure(); |
4553 | return mlir::success(); |
4554 | } |
4555 | |
4556 | mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, |
4557 | llvm::StringRef name, |
4558 | mlir::FunctionType type, |
4559 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
4560 | const mlir::SymbolTable *symbolTable) { |
4561 | if (symbolTable) |
4562 | if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name)) { |
4563 | #ifdef EXPENSIVE_CHECKS |
4564 | assert(f == module.lookupSymbol<mlir::func::FuncOp>(name) && |
4565 | "symbolTable and module out of sync" ); |
4566 | #endif |
4567 | return f; |
4568 | } |
4569 | if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name)) |
4570 | return f; |
4571 | mlir::OpBuilder modBuilder(module.getBodyRegion()); |
4572 | modBuilder.setInsertionPointToEnd(module.getBody()); |
4573 | auto result = modBuilder.create<mlir::func::FuncOp>(loc, name, type, attrs); |
4574 | result.setVisibility(mlir::SymbolTable::Visibility::Private); |
4575 | return result; |
4576 | } |
4577 | |
4578 | fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, |
4579 | llvm::StringRef name, mlir::Type type, |
4580 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
4581 | const mlir::SymbolTable *symbolTable) { |
4582 | if (symbolTable) |
4583 | if (auto g = symbolTable->lookup<fir::GlobalOp>(name)) { |
4584 | #ifdef EXPENSIVE_CHECKS |
4585 | assert(g == module.lookupSymbol<fir::GlobalOp>(name) && |
4586 | "symbolTable and module out of sync" ); |
4587 | #endif |
4588 | return g; |
4589 | } |
4590 | if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) |
4591 | return g; |
4592 | mlir::OpBuilder modBuilder(module.getBodyRegion()); |
4593 | auto result = modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); |
4594 | result.setVisibility(mlir::SymbolTable::Visibility::Private); |
4595 | return result; |
4596 | } |
4597 | |
4598 | bool fir::hasHostAssociationArgument(mlir::func::FuncOp func) { |
4599 | if (auto allArgAttrs = func.getAllArgAttrs()) |
4600 | for (auto attr : allArgAttrs) |
4601 | if (auto dict = mlir::dyn_cast_or_null<mlir::DictionaryAttr>(attr)) |
4602 | if (dict.get(fir::getHostAssocAttrName())) |
4603 | return true; |
4604 | return false; |
4605 | } |
4606 | |
4607 | // Test if value's definition has the specified set of |
4608 | // attributeNames. The value's definition is one of the operations |
4609 | // that are able to carry the Fortran variable attributes, e.g. |
4610 | // fir.alloca or fir.allocmem. Function arguments may also represent |
4611 | // value definitions and carry relevant attributes. |
4612 | // |
4613 | // If it is not possible to reach the limited set of definition |
4614 | // entities from the given value, then the function will return |
4615 | // std::nullopt. Otherwise, the definition is known and the return |
4616 | // value is computed as: |
4617 | // * if checkAny is true, then the function will return true |
4618 | // iff any of the attributeNames attributes is set on the definition. |
4619 | // * if checkAny is false, then the function will return true |
4620 | // iff all of the attributeNames attributes are set on the definition. |
4621 | static std::optional<bool> |
4622 | valueCheckFirAttributes(mlir::Value value, |
4623 | llvm::ArrayRef<llvm::StringRef> attributeNames, |
4624 | bool checkAny) { |
4625 | auto testAttributeSets = [&](llvm::ArrayRef<mlir::NamedAttribute> setAttrs, |
4626 | llvm::ArrayRef<llvm::StringRef> checkAttrs) { |
4627 | if (checkAny) { |
4628 | // Return true iff any of checkAttrs attributes is present |
4629 | // in setAttrs set. |
4630 | for (llvm::StringRef checkAttrName : checkAttrs) |
4631 | if (llvm::any_of(Range&: setAttrs, P: [&](mlir::NamedAttribute setAttr) { |
4632 | return setAttr.getName() == checkAttrName; |
4633 | })) |
4634 | return true; |
4635 | |
4636 | return false; |
4637 | } |
4638 | |
4639 | // Return true iff all attributes from checkAttrs are present |
4640 | // in setAttrs set. |
4641 | for (mlir::StringRef checkAttrName : checkAttrs) |
4642 | if (llvm::none_of(Range&: setAttrs, P: [&](mlir::NamedAttribute setAttr) { |
4643 | return setAttr.getName() == checkAttrName; |
4644 | })) |
4645 | return false; |
4646 | |
4647 | return true; |
4648 | }; |
4649 | // If this is a fir.box that was loaded, the fir attributes will be on the |
4650 | // related fir.ref<fir.box> creation. |
4651 | if (mlir::isa<fir::BoxType>(value.getType())) |
4652 | if (auto definingOp = value.getDefiningOp()) |
4653 | if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(definingOp)) |
4654 | value = loadOp.getMemref(); |
4655 | // If this is a function argument, look in the argument attributes. |
4656 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(Val&: value)) { |
4657 | if (blockArg.getOwner() && blockArg.getOwner()->isEntryBlock()) |
4658 | if (auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>( |
4659 | blockArg.getOwner()->getParentOp())) |
4660 | return testAttributeSets( |
4661 | mlir::cast<mlir::FunctionOpInterface>(*funcOp).getArgAttrs( |
4662 | blockArg.getArgNumber()), |
4663 | attributeNames); |
4664 | |
4665 | // If it is not a function argument, the attributes are unknown. |
4666 | return std::nullopt; |
4667 | } |
4668 | |
4669 | if (auto definingOp = value.getDefiningOp()) { |
4670 | // If this is an allocated value, look at the allocation attributes. |
4671 | if (mlir::isa<fir::AllocMemOp>(definingOp) || |
4672 | mlir::isa<fir::AllocaOp>(definingOp)) |
4673 | return testAttributeSets(definingOp->getAttrs(), attributeNames); |
4674 | // If this is an imported global, look at AddrOfOp and GlobalOp attributes. |
4675 | // Both operations are looked at because use/host associated variable (the |
4676 | // AddrOfOp) can have ASYNCHRONOUS/VOLATILE attributes even if the ultimate |
4677 | // entity (the globalOp) does not have them. |
4678 | if (auto addressOfOp = mlir::dyn_cast<fir::AddrOfOp>(definingOp)) { |
4679 | if (testAttributeSets(addressOfOp->getAttrs(), attributeNames)) |
4680 | return true; |
4681 | if (auto module = definingOp->getParentOfType<mlir::ModuleOp>()) |
4682 | if (auto globalOp = |
4683 | module.lookupSymbol<fir::GlobalOp>(addressOfOp.getSymbol())) |
4684 | return testAttributeSets(globalOp->getAttrs(), attributeNames); |
4685 | } |
4686 | } |
4687 | // TODO: Construct associated entities attributes. Decide where the fir |
4688 | // attributes must be placed/looked for in this case. |
4689 | return std::nullopt; |
4690 | } |
4691 | |
4692 | bool fir::valueMayHaveFirAttributes( |
4693 | mlir::Value value, llvm::ArrayRef<llvm::StringRef> attributeNames) { |
4694 | std::optional<bool> mayHaveAttr = |
4695 | valueCheckFirAttributes(value, attributeNames, /*checkAny=*/true); |
4696 | return mayHaveAttr.value_or(true); |
4697 | } |
4698 | |
4699 | bool fir::valueHasFirAttribute(mlir::Value value, |
4700 | llvm::StringRef attributeName) { |
4701 | std::optional<bool> mayHaveAttr = |
4702 | valueCheckFirAttributes(value, {attributeName}, /*checkAny=*/false); |
4703 | return mayHaveAttr.value_or(false); |
4704 | } |
4705 | |
4706 | bool fir::anyFuncArgsHaveAttr(mlir::func::FuncOp func, llvm::StringRef attr) { |
4707 | for (unsigned i = 0, end = func.getNumArguments(); i < end; ++i) |
4708 | if (func.getArgAttr(i, attr)) |
4709 | return true; |
4710 | return false; |
4711 | } |
4712 | |
4713 | std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) { |
4714 | if (auto *definingOp = value.getDefiningOp()) { |
4715 | if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp)) |
4716 | if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(cst.getValue())) |
4717 | return intAttr.getInt(); |
4718 | if (auto llConstOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(definingOp)) |
4719 | if (auto attr = mlir::dyn_cast<mlir::IntegerAttr>(llConstOp.getValue())) |
4720 | return attr.getValue().getSExtValue(); |
4721 | } |
4722 | return {}; |
4723 | } |
4724 | |
4725 | bool fir::isDummyArgument(mlir::Value v) { |
4726 | auto blockArg{mlir::dyn_cast<mlir::BlockArgument>(v)}; |
4727 | if (!blockArg) { |
4728 | auto defOp = v.getDefiningOp(); |
4729 | if (defOp) { |
4730 | if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(defOp)) |
4731 | if (declareOp.getDummyScope()) |
4732 | return true; |
4733 | } |
4734 | return false; |
4735 | } |
4736 | |
4737 | auto *owner{blockArg.getOwner()}; |
4738 | return owner->isEntryBlock() && |
4739 | mlir::isa<mlir::FunctionOpInterface>(owner->getParentOp()); |
4740 | } |
4741 | |
/// Walk `path` through the structure of `eleTy` and return the type reached,
/// or a null type if some path element cannot be applied.
/// Each step consumes path values depending on the current type:
/// - RecordType: one value, which must come from a fir.field_index or an
///   arith.constant index;
/// - SequenceType: one integer-typed value per dimension;
/// - TupleType: one value, which must come from an arith.constant;
/// - ComplexType: one integer-typed value.
/// Any other type ends the walk with a null type.
mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
  // `i` is advanced inside the TypeSwitch cases, which consume as many path
  // values as the current type requires.
  for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
    eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
                .Case<fir::RecordType>([&](fir::RecordType ty) {
                  // Select a component by field name or constant position.
                  if (auto *op = (*i++).getDefiningOp()) {
                    if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op))
                      return ty.getType(off.getFieldName());
                    if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op))
                      return ty.getType(fir::toInt(off));
                  }
                  return mlir::Type{};
                })
                .Case<fir::SequenceType>([&](fir::SequenceType ty) {
                  // Consume one integer index per dimension; running out of
                  // path values or meeting a non-integer fails the walk.
                  bool valid = true;
                  const auto rank = ty.getDimension();
                  for (std::remove_const_t<decltype(rank)> ii = 0;
                       valid && ii < rank; ++ii)
                    valid = i < end && fir::isa_integer((*i++).getType());
                  return valid ? ty.getEleTy() : mlir::Type{};
                })
                .Case<mlir::TupleType>([&](mlir::TupleType ty) {
                  // Tuple members can only be selected by a constant index.
                  if (auto *op = (*i++).getDefiningOp())
                    if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op))
                      return ty.getType(fir::toInt(off));
                  return mlir::Type{};
                })
                .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
                  // Any integer index selects the (real or imaginary) part.
                  if (fir::isa_integer((*i++).getType()))
                    return ty.getElementType();
                  return mlir::Type{};
                })
                .Default([&](const auto &) { return mlir::Type{}; });
  }
  return eleTy;
}
4777 | |
4778 | bool fir::reboxPreservesContinuity(fir::ReboxOp rebox, bool checkWhole) { |
4779 | // If slicing is not involved, then the rebox does not affect |
4780 | // the continuity of the array. |
4781 | auto sliceArg = rebox.getSlice(); |
4782 | if (!sliceArg) |
4783 | return true; |
4784 | |
4785 | if (auto sliceOp = |
4786 | mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) { |
4787 | if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) { |
4788 | // TODO: generalize code for the triples analysis with |
4789 | // hlfir::designatePreservesContinuity, especially when |
4790 | // recognition of the whole dimension slices is added. |
4791 | auto triples = sliceOp.getTriples(); |
4792 | assert((triples.size() % 3) == 0 && "invalid triples size" ); |
4793 | |
4794 | // A slice with step=1 in the innermost dimension preserves |
4795 | // the continuity of the array in the innermost dimension. |
4796 | // If checkWhole is false, then check only the innermost slice triples. |
4797 | std::size_t checkUpTo = checkWhole ? triples.size() : 3; |
4798 | checkUpTo = std::min(checkUpTo, triples.size()); |
4799 | for (std::size_t i = 0; i < checkUpTo; i += 3) { |
4800 | if (triples[i] != triples[i + 1]) { |
4801 | // This is a section of the dimension. Only allow it |
4802 | // to be the first triple. |
4803 | if (i != 0) |
4804 | return false; |
4805 | auto constantStep = fir::getIntIfConstant(triples[i + 2]); |
4806 | if (!constantStep || *constantStep != 1) |
4807 | return false; |
4808 | } |
4809 | } |
4810 | return true; |
4811 | } |
4812 | } |
4813 | return false; |
4814 | } |
4815 | |
4816 | std::optional<int64_t> fir::getAllocaByteSize(fir::AllocaOp alloca, |
4817 | const mlir::DataLayout &dl, |
4818 | const fir::KindMapping &kindMap) { |
4819 | mlir::Type type = alloca.getInType(); |
4820 | // TODO: should use the constant operands when all info is not available in |
4821 | // the type. |
4822 | if (!alloca.isDynamic()) |
4823 | if (auto sizeAndAlignment = |
4824 | getTypeSizeAndAlignment(alloca.getLoc(), type, dl, kindMap)) |
4825 | return sizeAndAlignment->first; |
4826 | return std::nullopt; |
4827 | } |
4828 | |
4829 | //===----------------------------------------------------------------------===// |
4830 | // DeclareOp |
4831 | //===----------------------------------------------------------------------===// |
4832 | |
4833 | llvm::LogicalResult fir::DeclareOp::verify() { |
4834 | auto fortranVar = |
4835 | mlir::cast<fir::FortranVariableOpInterface>(this->getOperation()); |
4836 | return fortranVar.verifyDeclareLikeOpImpl(getMemref()); |
4837 | } |
4838 | |
4839 | //===----------------------------------------------------------------------===// |
4840 | // PackArrayOp |
4841 | //===----------------------------------------------------------------------===// |
4842 | |
4843 | llvm::LogicalResult fir::PackArrayOp::verify() { |
4844 | mlir::Type arrayType = getArray().getType(); |
4845 | if (!validTypeParams(arrayType, getTypeparams(), /*allowParamsForBox=*/true)) |
4846 | return emitOpError("invalid type parameters" ); |
4847 | |
4848 | if (getInnermost() && fir::getBoxRank(arrayType) == 1) |
4849 | return emitOpError( |
4850 | "'innermost' is invalid for 1D arrays, use 'whole' instead" ); |
4851 | return mlir::success(); |
4852 | } |
4853 | |
4854 | void fir::PackArrayOp::getEffects( |
4855 | llvm::SmallVectorImpl< |
4856 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
4857 | &effects) { |
4858 | if (getStack()) |
4859 | effects.emplace_back( |
4860 | mlir::MemoryEffects::Allocate::get(), |
4861 | mlir::SideEffects::AutomaticAllocationScopeResource::get()); |
4862 | else |
4863 | effects.emplace_back(mlir::MemoryEffects::Allocate::get(), |
4864 | mlir::SideEffects::DefaultResource::get()); |
4865 | |
4866 | if (!getNoCopy()) |
4867 | effects.emplace_back(mlir::MemoryEffects::Read::get(), |
4868 | mlir::SideEffects::DefaultResource::get()); |
4869 | } |
4870 | |
4871 | static mlir::ParseResult |
4872 | parsePackArrayConstraints(mlir::OpAsmParser &parser, mlir::IntegerAttr &maxSize, |
4873 | mlir::IntegerAttr &maxElementSize, |
4874 | mlir::IntegerAttr &minStride) { |
4875 | mlir::OperationName opName = mlir::OperationName( |
4876 | fir::PackArrayOp::getOperationName(), parser.getContext()); |
4877 | struct { |
4878 | llvm::StringRef name; |
4879 | mlir::IntegerAttr &ref; |
4880 | } attributes[] = { |
4881 | {fir::PackArrayOp::getMaxSizeAttrName(opName), maxSize}, |
4882 | {fir::PackArrayOp::getMaxElementSizeAttrName(opName), maxElementSize}, |
4883 | {fir::PackArrayOp::getMinStrideAttrName(opName), minStride}}; |
4884 | |
4885 | mlir::NamedAttrList parsedAttrs; |
4886 | if (succeeded(Result: parser.parseOptionalAttrDict(result&: parsedAttrs))) { |
4887 | for (auto parsedAttr : parsedAttrs) { |
4888 | for (auto opAttr : attributes) { |
4889 | if (parsedAttr.getName() == opAttr.name) |
4890 | opAttr.ref = mlir::cast<mlir::IntegerAttr>(parsedAttr.getValue()); |
4891 | } |
4892 | } |
4893 | return mlir::success(); |
4894 | } |
4895 | return mlir::failure(); |
4896 | } |
4897 | |
4898 | static void printPackArrayConstraints(mlir::OpAsmPrinter &p, |
4899 | fir::PackArrayOp &op, |
4900 | const mlir::IntegerAttr &maxSize, |
4901 | const mlir::IntegerAttr &maxElementSize, |
4902 | const mlir::IntegerAttr &minStride) { |
4903 | llvm::SmallVector<mlir::NamedAttribute> attributes; |
4904 | if (maxSize) |
4905 | attributes.emplace_back(op.getMaxSizeAttrName(), maxSize); |
4906 | if (maxElementSize) |
4907 | attributes.emplace_back(op.getMaxElementSizeAttrName(), maxElementSize); |
4908 | if (minStride) |
4909 | attributes.emplace_back(op.getMinStrideAttrName(), minStride); |
4910 | |
4911 | p.printOptionalAttrDict(attrs: attributes); |
4912 | } |
4913 | |
4914 | //===----------------------------------------------------------------------===// |
4915 | // UnpackArrayOp |
4916 | //===----------------------------------------------------------------------===// |
4917 | |
4918 | llvm::LogicalResult fir::UnpackArrayOp::verify() { |
4919 | if (auto packOp = getTemp().getDefiningOp<fir::PackArrayOp>()) |
4920 | if (getStack() != packOp.getStack()) |
4921 | return emitOpError() << "the pack operation uses different memory for " |
4922 | "the temporary (stack vs heap): " |
4923 | << *packOp.getOperation() << "\n" ; |
4924 | return mlir::success(); |
4925 | } |
4926 | |
4927 | void fir::UnpackArrayOp::getEffects( |
4928 | llvm::SmallVectorImpl< |
4929 | mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> |
4930 | &effects) { |
4931 | if (getStack()) |
4932 | effects.emplace_back( |
4933 | mlir::MemoryEffects::Free::get(), |
4934 | mlir::SideEffects::AutomaticAllocationScopeResource::get()); |
4935 | else |
4936 | effects.emplace_back(mlir::MemoryEffects::Free::get(), |
4937 | mlir::SideEffects::DefaultResource::get()); |
4938 | |
4939 | if (!getNoCopy()) |
4940 | effects.emplace_back(mlir::MemoryEffects::Write::get(), |
4941 | mlir::SideEffects::DefaultResource::get()); |
4942 | } |
4943 | |
4944 | //===----------------------------------------------------------------------===// |
4945 | // IsContiguousBoxOp |
4946 | //===----------------------------------------------------------------------===// |
4947 | |
4948 | namespace { |
4949 | struct SimplifyIsContiguousBoxOp |
4950 | : public mlir::OpRewritePattern<fir::IsContiguousBoxOp> { |
4951 | using mlir::OpRewritePattern<fir::IsContiguousBoxOp>::OpRewritePattern; |
4952 | mlir::LogicalResult |
4953 | matchAndRewrite(fir::IsContiguousBoxOp op, |
4954 | mlir::PatternRewriter &rewriter) const override; |
4955 | }; |
4956 | } // namespace |
4957 | |
4958 | mlir::LogicalResult SimplifyIsContiguousBoxOp::matchAndRewrite( |
4959 | fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { |
4960 | auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType()); |
4961 | // Nothing to do for assumed-rank arrays and !fir.box<none>. |
4962 | if (boxType.isAssumedRank() || fir::isBoxNone(boxType)) |
4963 | return mlir::failure(); |
4964 | |
4965 | if (fir::getBoxRank(boxType) == 0) { |
4966 | // Scalars are always contiguous. |
4967 | mlir::Type i1Type = rewriter.getI1Type(); |
4968 | rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( |
4969 | op, i1Type, rewriter.getIntegerAttr(i1Type, 1)); |
4970 | return mlir::success(); |
4971 | } |
4972 | |
4973 | // TODO: support more patterns, e.g. a result of fir.embox without |
4974 | // the slice is contiguous. We can add fir::isSimplyContiguous(box) |
4975 | // that walks def-use to figure it out. |
4976 | return mlir::failure(); |
4977 | } |
4978 | |
4979 | void fir::IsContiguousBoxOp::getCanonicalizationPatterns( |
4980 | mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { |
4981 | patterns.add<SimplifyIsContiguousBoxOp>(context); |
4982 | } |
4983 | |
4984 | //===----------------------------------------------------------------------===// |
4985 | // BoxTotalElementsOp |
4986 | //===----------------------------------------------------------------------===// |
4987 | |
4988 | namespace { |
4989 | struct SimplifyBoxTotalElementsOp |
4990 | : public mlir::OpRewritePattern<fir::BoxTotalElementsOp> { |
4991 | using mlir::OpRewritePattern<fir::BoxTotalElementsOp>::OpRewritePattern; |
4992 | mlir::LogicalResult |
4993 | matchAndRewrite(fir::BoxTotalElementsOp op, |
4994 | mlir::PatternRewriter &rewriter) const override; |
4995 | }; |
4996 | } // namespace |
4997 | |
4998 | mlir::LogicalResult SimplifyBoxTotalElementsOp::matchAndRewrite( |
4999 | fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { |
5000 | auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType()); |
5001 | // Nothing to do for assumed-rank arrays and !fir.box<none>. |
5002 | if (boxType.isAssumedRank() || fir::isBoxNone(boxType)) |
5003 | return mlir::failure(); |
5004 | |
5005 | if (fir::getBoxRank(boxType) == 0) { |
5006 | // Scalar: 1 element. |
5007 | rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( |
5008 | op, op.getType(), rewriter.getIntegerAttr(op.getType(), 1)); |
5009 | return mlir::success(); |
5010 | } |
5011 | |
5012 | // TODO: support more cases, e.g. !fir.box<!fir.array<10xi32>>. |
5013 | return mlir::failure(); |
5014 | } |
5015 | |
5016 | void fir::BoxTotalElementsOp::getCanonicalizationPatterns( |
5017 | mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { |
5018 | patterns.add<SimplifyBoxTotalElementsOp>(context); |
5019 | } |
5020 | |
5021 | //===----------------------------------------------------------------------===// |
5022 | // LocalitySpecifierOp |
5023 | //===----------------------------------------------------------------------===// |
5024 | |
/// Verify the regions of a locality specifier op:
/// - every region argument has the specifier's `argType`;
/// - the optional `init` region takes 2 arguments and its exit blocks yield
///   exactly one `argType` value;
/// - `local` specifiers must not have a `copy` region;
/// - `local_init` specifiers must have a `copy` region, which takes 2
///   arguments and yields one `argType` value;
/// - the optional `dealloc` region takes 1 argument and yields nothing.
llvm::LogicalResult fir::LocalitySpecifierOp::verifyRegions() {
  mlir::Type argType = getArgType();
  // Verify a single block terminator. Terminators of blocks with successors
  // are accepted as-is. Exit blocks must end in fir.yield, yielding nothing
  // when `yieldsValue` is false and exactly one `argType` value otherwise.
  auto verifyTerminator = [&](mlir::Operation *terminator,
                              bool yieldsValue) -> llvm::LogicalResult {
    if (!terminator->getBlock()->getSuccessors().empty())
      return llvm::success();

    if (!llvm::isa<fir::YieldOp>(terminator))
      return mlir::emitError(terminator->getLoc())
             << "expected exit block terminator to be an `fir.yield` op.";

    YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
    mlir::TypeRange yieldedTypes = yieldOp.getResults().getTypes();

    if (!yieldsValue) {
      if (yieldedTypes.empty())
        return llvm::success();

      return mlir::emitError(terminator->getLoc())
             << "Did not expect any values to be yielded.";
    }

    if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
      return llvm::success();

    auto error = mlir::emitError(yieldOp.getLoc())
                 << "Invalid yielded value. Expected type: " << argType
                 << ", got: ";

    if (yieldedTypes.empty())
      error << "None";
    else
      error << yieldedTypes;

    return error;
  };

  // Verify a non-empty region: its argument count must equal
  // `expectedNumArgs`, and every block that may have a terminator must pass
  // verifyTerminator.
  auto verifyRegion = [&](mlir::Region &region, unsigned expectedNumArgs,
                          llvm::StringRef regionName,
                          bool yieldsValue) -> llvm::LogicalResult {
    assert(!region.empty());

    if (region.getNumArguments() != expectedNumArgs)
      return mlir::emitError(region.getLoc())
             << "`" << regionName << "`: "
             << "expected " << expectedNumArgs
             << " region arguments, got: " << region.getNumArguments();

    for (mlir::Block &block : region) {
      // MLIR will verify the absence of the terminator for us.
      if (!block.mightHaveTerminator())
        continue;

      if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
        return llvm::failure();
    }

    return llvm::success();
  };

  // Ensure all of the region arguments have the specifier's `argType`.
  for (mlir::Region *region : getRegions())
    for (mlir::Type ty : region->getArgumentTypes())
      if (ty != argType)
        return emitError() << "Region argument type mismatch: got " << ty
                           << " expected " << argType << ".";

  // The `init` region is optional; verify it only when present.
  mlir::Region &initRegion = getInitRegion();
  if (!initRegion.empty() &&
      failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
                          /*yieldsValue=*/true)))
    return llvm::failure();

  LocalitySpecifierType dsType = getLocalitySpecifierType();

  // `local` must not have a `copy` region; `local_init` requires one.
  if (dsType == LocalitySpecifierType::Local && !getCopyRegion().empty())
    return emitError("`local` specifiers do not require a `copy` region.");

  if (dsType == LocalitySpecifierType::LocalInit && getCopyRegion().empty())
    return emitError(
        "`local_init` specifiers require at least a `copy` region.");

  // At this point a `local_init` copy region is known to be non-empty, as
  // verifyRegion asserts.
  if (dsType == LocalitySpecifierType::LocalInit &&
      failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
                          /*yieldsValue=*/true)))
    return llvm::failure();

  // The `dealloc` region is optional and must not yield values.
  if (!getDeallocRegion().empty() &&
      failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
                          /*yieldsValue=*/false)))
    return llvm::failure();

  return llvm::success();
}
5119 | |
5120 | //===----------------------------------------------------------------------===// |
5121 | // DoConcurrentOp |
5122 | //===----------------------------------------------------------------------===// |
5123 | |
5124 | llvm::LogicalResult fir::DoConcurrentOp::verify() { |
5125 | mlir::Block *body = getBody(); |
5126 | |
5127 | if (body->empty()) |
5128 | return emitOpError("body cannot be empty" ); |
5129 | |
5130 | if (!body->mightHaveTerminator() || |
5131 | !mlir::isa<fir::DoConcurrentLoopOp>(body->getTerminator())) |
5132 | return emitOpError("must be terminated by 'fir.do_concurrent.loop'" ); |
5133 | |
5134 | return mlir::success(); |
5135 | } |
5136 | |
5137 | //===----------------------------------------------------------------------===// |
5138 | // DoConcurrentLoopOp |
5139 | //===----------------------------------------------------------------------===// |
5140 | |
5141 | mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, |
5142 | mlir::OperationState &result) { |
5143 | auto &builder = parser.getBuilder(); |
5144 | // Parse an opening `(` followed by induction variables followed by `)` |
5145 | llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs; |
5146 | |
5147 | if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren)) |
5148 | return mlir::failure(); |
5149 | |
5150 | llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(), |
5151 | builder.getIndexType()); |
5152 | |
5153 | // Parse loop bounds. |
5154 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower; |
5155 | if (parser.parseEqual() || |
5156 | parser.parseOperandList(lower, regionArgs.size(), |
5157 | mlir::OpAsmParser::Delimiter::Paren) || |
5158 | parser.resolveOperands(lower, builder.getIndexType(), result.operands)) |
5159 | return mlir::failure(); |
5160 | |
5161 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper; |
5162 | if (parser.parseKeyword("to" ) || |
5163 | parser.parseOperandList(upper, regionArgs.size(), |
5164 | mlir::OpAsmParser::Delimiter::Paren) || |
5165 | parser.resolveOperands(upper, builder.getIndexType(), result.operands)) |
5166 | return mlir::failure(); |
5167 | |
5168 | // Parse step values. |
5169 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps; |
5170 | if (parser.parseKeyword("step" ) || |
5171 | parser.parseOperandList(steps, regionArgs.size(), |
5172 | mlir::OpAsmParser::Delimiter::Paren) || |
5173 | parser.resolveOperands(steps, builder.getIndexType(), result.operands)) |
5174 | return mlir::failure(); |
5175 | |
5176 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands; |
5177 | llvm::SmallVector<mlir::Type> reduceArgTypes; |
5178 | if (succeeded(parser.parseOptionalKeyword("reduce" ))) { |
5179 | // Parse reduction attributes and variables. |
5180 | llvm::SmallVector<fir::ReduceAttr> attributes; |
5181 | if (failed(parser.parseCommaSeparatedList( |
5182 | mlir::AsmParser::Delimiter::Paren, [&]() { |
5183 | if (parser.parseAttribute(attributes.emplace_back()) || |
5184 | parser.parseArrow() || |
5185 | parser.parseOperand(reduceOperands.emplace_back()) || |
5186 | parser.parseColonType(reduceArgTypes.emplace_back())) |
5187 | return mlir::failure(); |
5188 | return mlir::success(); |
5189 | }))) |
5190 | return mlir::failure(); |
5191 | // Resolve input operands. |
5192 | for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) |
5193 | if (parser.resolveOperand(std::get<0>(operand_type), |
5194 | std::get<1>(operand_type), result.operands)) |
5195 | return mlir::failure(); |
5196 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
5197 | attributes.end()); |
5198 | result.addAttribute(getReduceAttrsAttrName(result.name), |
5199 | builder.getArrayAttr(arrayAttr)); |
5200 | } |
5201 | |
5202 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands; |
5203 | if (succeeded(parser.parseOptionalKeyword("local" ))) { |
5204 | std::size_t oldArgTypesSize = argTypes.size(); |
5205 | if (failed(parser.parseLParen())) |
5206 | return mlir::failure(); |
5207 | |
5208 | llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec; |
5209 | if (failed(parser.parseCommaSeparatedList([&]() { |
5210 | if (failed(parser.parseAttribute(localSymbolVec.emplace_back()))) |
5211 | return mlir::failure(); |
5212 | |
5213 | if (parser.parseOperand(localOperands.emplace_back()) || |
5214 | parser.parseArrow() || |
5215 | parser.parseArgument(regionArgs.emplace_back())) |
5216 | return mlir::failure(); |
5217 | |
5218 | return mlir::success(); |
5219 | }))) |
5220 | return mlir::failure(); |
5221 | |
5222 | if (failed(parser.parseColon())) |
5223 | return mlir::failure(); |
5224 | |
5225 | if (failed(parser.parseCommaSeparatedList([&]() { |
5226 | if (failed(parser.parseType(argTypes.emplace_back()))) |
5227 | return mlir::failure(); |
5228 | |
5229 | return mlir::success(); |
5230 | }))) |
5231 | return mlir::failure(); |
5232 | |
5233 | if (regionArgs.size() != argTypes.size()) |
5234 | return parser.emitError(parser.getNameLoc(), |
5235 | "mismatch in number of local arg and types" ); |
5236 | |
5237 | if (failed(parser.parseRParen())) |
5238 | return mlir::failure(); |
5239 | |
5240 | for (auto operandType : llvm::zip_equal( |
5241 | localOperands, llvm::drop_begin(argTypes, oldArgTypesSize))) |
5242 | if (parser.resolveOperand(std::get<0>(operandType), |
5243 | std::get<1>(operandType), result.operands)) |
5244 | return mlir::failure(); |
5245 | |
5246 | llvm::SmallVector<mlir::Attribute> symbolAttrs(localSymbolVec.begin(), |
5247 | localSymbolVec.end()); |
5248 | result.addAttribute(getLocalSymsAttrName(result.name), |
5249 | builder.getArrayAttr(symbolAttrs)); |
5250 | } |
5251 | |
5252 | // Set `operandSegmentSizes` attribute. |
5253 | result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(), |
5254 | builder.getDenseI32ArrayAttr( |
5255 | {static_cast<int32_t>(lower.size()), |
5256 | static_cast<int32_t>(upper.size()), |
5257 | static_cast<int32_t>(steps.size()), |
5258 | static_cast<int32_t>(reduceOperands.size()), |
5259 | static_cast<int32_t>(localOperands.size())})); |
5260 | |
5261 | // Now parse the body. |
5262 | for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes)) |
5263 | arg.type = type; |
5264 | |
5265 | mlir::Region *body = result.addRegion(); |
5266 | if (parser.parseRegion(*body, regionArgs)) |
5267 | return mlir::failure(); |
5268 | |
5269 | // Parse attributes. |
5270 | if (parser.parseOptionalAttrDict(result.attributes)) |
5271 | return mlir::failure(); |
5272 | |
5273 | return mlir::success(); |
5274 | } |
5275 | |
5276 | void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { |
5277 | p << " (" << getBody()->getArguments().slice(0, getNumInductionVars()) |
5278 | << ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step (" |
5279 | << getStep() << ")" ; |
5280 | |
5281 | if (!getReduceOperands().empty()) { |
5282 | p << " reduce(" ; |
5283 | auto attrs = getReduceAttrsAttr(); |
5284 | auto operands = getReduceOperands(); |
5285 | llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { |
5286 | p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " |
5287 | << std::get<1>(it).getType(); |
5288 | }); |
5289 | p << ')'; |
5290 | } |
5291 | |
5292 | if (!getLocalVars().empty()) { |
5293 | p << " local(" ; |
5294 | llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(), |
5295 | getRegionLocalArgs()), |
5296 | p, [&](auto it) { |
5297 | p << std::get<0>(it) << " " << std::get<1>(it) |
5298 | << " -> " << std::get<2>(it); |
5299 | }); |
5300 | p << " : " ; |
5301 | llvm::interleaveComma(getLocalVars(), p, |
5302 | [&](auto it) { p << it.getType(); }); |
5303 | p << ")" ; |
5304 | } |
5305 | |
5306 | p << ' '; |
5307 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
5308 | p.printOptionalAttrDict( |
5309 | (*this)->getAttrs(), |
5310 | /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(), |
5311 | DoConcurrentLoopOp::getReduceAttrsAttrName(), |
5312 | DoConcurrentLoopOp::getLocalSymsAttrName()}); |
5313 | } |
5314 | |
5315 | llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() { |
5316 | return {&getRegion()}; |
5317 | } |
5318 | |
5319 | llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { |
5320 | mlir::Operation::operand_range lbValues = getLowerBound(); |
5321 | mlir::Operation::operand_range ubValues = getUpperBound(); |
5322 | mlir::Operation::operand_range stepValues = getStep(); |
5323 | mlir::Operation::operand_range localVars = getLocalVars(); |
5324 | |
5325 | if (lbValues.empty()) |
5326 | return emitOpError( |
5327 | "needs at least one tuple element for lowerBound, upperBound and step" ); |
5328 | |
5329 | if (lbValues.size() != ubValues.size() || |
5330 | ubValues.size() != stepValues.size()) |
5331 | return emitOpError("different number of tuple elements for lowerBound, " |
5332 | "upperBound or step" ); |
5333 | |
5334 | // Check that the body defines the same number of block arguments as the |
5335 | // number of tuple elements in step. |
5336 | mlir::Block *body = getBody(); |
5337 | unsigned numIndVarArgs = body->getNumArguments() - localVars.size(); |
5338 | |
5339 | if (numIndVarArgs != stepValues.size()) |
5340 | return emitOpError() << "expects the same number of induction variables: " |
5341 | << body->getNumArguments() |
5342 | << " as bound and step values: " << stepValues.size(); |
5343 | for (auto arg : body->getArguments().slice(0, numIndVarArgs)) |
5344 | if (!arg.getType().isIndex()) |
5345 | return emitOpError( |
5346 | "expects arguments for the induction variable to be of index type" ); |
5347 | |
5348 | auto reduceAttrs = getReduceAttrsAttr(); |
5349 | if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) |
5350 | return emitOpError( |
5351 | "mismatch in number of reduction variables and reduction attributes" ); |
5352 | |
5353 | return mlir::success(); |
5354 | } |
5355 | |
5356 | std::optional<llvm::SmallVector<mlir::Value>> |
5357 | fir::DoConcurrentLoopOp::getLoopInductionVars() { |
5358 | return llvm::SmallVector<mlir::Value>{ |
5359 | getBody()->getArguments().slice(0, getLowerBound().size())}; |
5360 | } |
5361 | |
5362 | //===----------------------------------------------------------------------===// |
5363 | // FIROpsDialect |
5364 | //===----------------------------------------------------------------------===// |
5365 | |
5366 | void fir::FIROpsDialect::registerOpExternalInterfaces() { |
5367 | // Attach default declare target interfaces to operations which can be marked |
5368 | // as declare target. |
5369 | fir::GlobalOp::attachInterface< |
5370 | mlir::omp::DeclareTargetDefaultModel<fir::GlobalOp>>(*getContext()); |
5371 | } |
5372 | |
5373 | // Tablegen operators |
5374 | |
5375 | #define GET_OP_CLASSES |
5376 | #include "flang/Optimizer/Dialect/FIROps.cpp.inc" |
5377 | |