1 | //===-- FIROps.cpp --------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/Dialect/FIROps.h" |
14 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
15 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
16 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
17 | #include "flang/Optimizer/Dialect/FIRType.h" |
18 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
19 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
20 | #include "flang/Optimizer/Support/Utils.h" |
21 | #include "mlir/Dialect/CommonFolders.h" |
22 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
23 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
24 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
25 | #include "mlir/IR/Attributes.h" |
26 | #include "mlir/IR/BuiltinAttributes.h" |
27 | #include "mlir/IR/BuiltinOps.h" |
28 | #include "mlir/IR/Diagnostics.h" |
29 | #include "mlir/IR/Matchers.h" |
30 | #include "mlir/IR/OpDefinition.h" |
31 | #include "mlir/IR/PatternMatch.h" |
32 | #include "llvm/ADT/STLExtras.h" |
33 | #include "llvm/ADT/SmallVector.h" |
34 | #include "llvm/ADT/TypeSwitch.h" |
35 | |
36 | namespace { |
37 | #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc" |
38 | } // namespace |
39 | |
40 | static void propagateAttributes(mlir::Operation *fromOp, |
41 | mlir::Operation *toOp) { |
42 | if (!fromOp || !toOp) |
43 | return; |
44 | |
45 | for (mlir::NamedAttribute attr : fromOp->getAttrs()) { |
46 | if (attr.getName().getValue().starts_with( |
47 | mlir::acc::OpenACCDialect::getDialectNamespace())) |
48 | toOp->setAttr(attr.getName(), attr.getValue()); |
49 | } |
50 | } |
51 | |
52 | /// Return true if a sequence type is of some incomplete size or a record type |
53 | /// is malformed or contains an incomplete sequence type. An incomplete sequence |
54 | /// type is one with more unknown extents in the type than have been provided |
55 | /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by |
56 | /// definition. |
57 | static bool verifyInType(mlir::Type inType, |
58 | llvm::SmallVectorImpl<llvm::StringRef> &visited, |
59 | unsigned dynamicExtents = 0) { |
60 | if (auto st = inType.dyn_cast<fir::SequenceType>()) { |
61 | auto shape = st.getShape(); |
62 | if (shape.size() == 0) |
63 | return true; |
64 | for (std::size_t i = 0, end = shape.size(); i < end; ++i) { |
65 | if (shape[i] != fir::SequenceType::getUnknownExtent()) |
66 | continue; |
67 | if (dynamicExtents-- == 0) |
68 | return true; |
69 | } |
70 | } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { |
71 | // don't recurse if we're already visiting this one |
72 | if (llvm::is_contained(visited, rt.getName())) |
73 | return false; |
74 | // keep track of record types currently being visited |
75 | visited.push_back(Elt: rt.getName()); |
76 | for (auto &field : rt.getTypeList()) |
77 | if (verifyInType(field.second, visited)) |
78 | return true; |
79 | visited.pop_back(); |
80 | } |
81 | return false; |
82 | } |
83 | |
84 | static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { |
85 | auto ty = fir::unwrapSequenceType(inType); |
86 | if (numParams > 0) { |
87 | if (auto recTy = ty.dyn_cast<fir::RecordType>()) |
88 | return numParams != recTy.getNumLenParams(); |
89 | if (auto chrTy = ty.dyn_cast<fir::CharacterType>()) |
90 | return !(numParams == 1 && chrTy.hasDynamicLen()); |
91 | return true; |
92 | } |
93 | if (auto chrTy = ty.dyn_cast<fir::CharacterType>()) |
94 | return !chrTy.hasConstantLen(); |
95 | return false; |
96 | } |
97 | |
98 | /// Parser shared by Alloca and Allocmem |
99 | /// |
100 | /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type |
101 | /// ( `(` $typeparams `)` )? ( `,` $shape )? |
102 | /// attr-dict-without-keyword |
103 | template <typename FN> |
104 | static mlir::ParseResult parseAllocatableOp(FN wrapResultType, |
105 | mlir::OpAsmParser &parser, |
106 | mlir::OperationState &result) { |
107 | mlir::Type intype; |
108 | if (parser.parseType(result&: intype)) |
109 | return mlir::failure(); |
110 | auto &builder = parser.getBuilder(); |
111 | result.addAttribute("in_type" , mlir::TypeAttr::get(intype)); |
112 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; |
113 | llvm::SmallVector<mlir::Type> typeVec; |
114 | bool hasOperands = false; |
115 | std::int32_t typeparamsSize = 0; |
116 | if (!parser.parseOptionalLParen()) { |
117 | // parse the LEN params of the derived type. (<params> : <types>) |
118 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None) || |
119 | parser.parseColonTypeList(result&: typeVec) || parser.parseRParen()) |
120 | return mlir::failure(); |
121 | typeparamsSize = operands.size(); |
122 | hasOperands = true; |
123 | } |
124 | std::int32_t shapeSize = 0; |
125 | if (!parser.parseOptionalComma()) { |
126 | // parse size to scale by, vector of n dimensions of type index |
127 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None)) |
128 | return mlir::failure(); |
129 | shapeSize = operands.size() - typeparamsSize; |
130 | auto idxTy = builder.getIndexType(); |
131 | for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) |
132 | typeVec.push_back(Elt: idxTy); |
133 | hasOperands = true; |
134 | } |
135 | if (hasOperands && |
136 | parser.resolveOperands(operands, types&: typeVec, loc: parser.getNameLoc(), |
137 | result&: result.operands)) |
138 | return mlir::failure(); |
139 | mlir::Type restype = wrapResultType(intype); |
140 | if (!restype) { |
141 | parser.emitError(loc: parser.getNameLoc(), message: "invalid allocate type: " ) << intype; |
142 | return mlir::failure(); |
143 | } |
144 | result.addAttribute("operandSegmentSizes" , builder.getDenseI32ArrayAttr( |
145 | {typeparamsSize, shapeSize})); |
146 | if (parser.parseOptionalAttrDict(result&: result.attributes) || |
147 | parser.addTypeToList(type: restype, result&: result.types)) |
148 | return mlir::failure(); |
149 | return mlir::success(); |
150 | } |
151 | |
152 | template <typename OP> |
153 | static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) { |
154 | p << ' ' << op.getInType(); |
155 | if (!op.getTypeparams().empty()) { |
156 | p << '(' << op.getTypeparams() << " : " << op.getTypeparams().getTypes() |
157 | << ')'; |
158 | } |
159 | // print the shape of the allocation (if any); all must be index type |
160 | for (auto sh : op.getShape()) { |
161 | p << ", " ; |
162 | p.printOperand(sh); |
163 | } |
164 | p.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs: {"in_type" , "operandSegmentSizes" }); |
165 | } |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // AllocaOp |
169 | //===----------------------------------------------------------------------===// |
170 | |
171 | /// Create a legal memory reference as return type |
172 | static mlir::Type wrapAllocaResultType(mlir::Type intype) { |
173 | // FIR semantics: memory references to memory references are disallowed |
174 | if (intype.isa<fir::ReferenceType>()) |
175 | return {}; |
176 | return fir::ReferenceType::get(intype); |
177 | } |
178 | |
179 | mlir::Type fir::AllocaOp::getAllocatedType() { |
180 | return getType().cast<fir::ReferenceType>().getEleTy(); |
181 | } |
182 | |
183 | mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { |
184 | return fir::ReferenceType::get(ty); |
185 | } |
186 | |
187 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
188 | mlir::OperationState &result, mlir::Type inType, |
189 | llvm::StringRef uniqName, mlir::ValueRange typeparams, |
190 | mlir::ValueRange shape, |
191 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
192 | auto nameAttr = builder.getStringAttr(uniqName); |
193 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, |
194 | /*pinned=*/false, typeparams, shape); |
195 | result.addAttributes(attributes); |
196 | } |
197 | |
198 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
199 | mlir::OperationState &result, mlir::Type inType, |
200 | llvm::StringRef uniqName, bool pinned, |
201 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
202 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
203 | auto nameAttr = builder.getStringAttr(uniqName); |
204 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, |
205 | pinned, typeparams, shape); |
206 | result.addAttributes(attributes); |
207 | } |
208 | |
209 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
210 | mlir::OperationState &result, mlir::Type inType, |
211 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
212 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
213 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
214 | auto nameAttr = |
215 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
216 | auto bindcAttr = |
217 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
218 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, |
219 | bindcAttr, /*pinned=*/false, typeparams, shape); |
220 | result.addAttributes(attributes); |
221 | } |
222 | |
223 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
224 | mlir::OperationState &result, mlir::Type inType, |
225 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
226 | bool pinned, mlir::ValueRange typeparams, |
227 | mlir::ValueRange shape, |
228 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
229 | auto nameAttr = |
230 | uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); |
231 | auto bindcAttr = |
232 | bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); |
233 | build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, |
234 | bindcAttr, pinned, typeparams, shape); |
235 | result.addAttributes(attributes); |
236 | } |
237 | |
238 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
239 | mlir::OperationState &result, mlir::Type inType, |
240 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
241 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
242 | build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, |
243 | /*pinned=*/false, typeparams, shape); |
244 | result.addAttributes(attributes); |
245 | } |
246 | |
247 | void fir::AllocaOp::build(mlir::OpBuilder &builder, |
248 | mlir::OperationState &result, mlir::Type inType, |
249 | bool pinned, mlir::ValueRange typeparams, |
250 | mlir::ValueRange shape, |
251 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
252 | build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, pinned, |
253 | typeparams, shape); |
254 | result.addAttributes(attributes); |
255 | } |
256 | |
257 | mlir::ParseResult fir::AllocaOp::parse(mlir::OpAsmParser &parser, |
258 | mlir::OperationState &result) { |
259 | return parseAllocatableOp(wrapAllocaResultType, parser, result); |
260 | } |
261 | |
262 | void fir::AllocaOp::print(mlir::OpAsmPrinter &p) { |
263 | printAllocatableOp(p, *this); |
264 | } |
265 | |
266 | mlir::LogicalResult fir::AllocaOp::verify() { |
267 | llvm::SmallVector<llvm::StringRef> visited; |
268 | if (verifyInType(getInType(), visited, numShapeOperands())) |
269 | return emitOpError("invalid type for allocation" ); |
270 | if (verifyTypeParamCount(getInType(), numLenParams())) |
271 | return emitOpError("LEN params do not correspond to type" ); |
272 | mlir::Type outType = getType(); |
273 | if (!outType.isa<fir::ReferenceType>()) |
274 | return emitOpError("must be a !fir.ref type" ); |
275 | if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) |
276 | return emitOpError("cannot allocate !fir.box of unknown rank or type" ); |
277 | return mlir::success(); |
278 | } |
279 | |
280 | //===----------------------------------------------------------------------===// |
281 | // AllocMemOp |
282 | //===----------------------------------------------------------------------===// |
283 | |
284 | /// Create a legal heap reference as return type |
285 | static mlir::Type wrapAllocMemResultType(mlir::Type intype) { |
286 | // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER |
287 | // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well |
288 | // FIR semantics: one may not allocate a memory reference value |
289 | if (intype.isa<fir::ReferenceType, fir::HeapType, fir::PointerType, |
290 | mlir::FunctionType>()) |
291 | return {}; |
292 | return fir::HeapType::get(intype); |
293 | } |
294 | |
295 | mlir::Type fir::AllocMemOp::getAllocatedType() { |
296 | return getType().cast<fir::HeapType>().getEleTy(); |
297 | } |
298 | |
299 | mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { |
300 | return fir::HeapType::get(ty); |
301 | } |
302 | |
303 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
304 | mlir::OperationState &result, mlir::Type inType, |
305 | llvm::StringRef uniqName, |
306 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
307 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
308 | auto nameAttr = builder.getStringAttr(uniqName); |
309 | build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, {}, |
310 | typeparams, shape); |
311 | result.addAttributes(attributes); |
312 | } |
313 | |
314 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
315 | mlir::OperationState &result, mlir::Type inType, |
316 | llvm::StringRef uniqName, llvm::StringRef bindcName, |
317 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
318 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
319 | auto nameAttr = builder.getStringAttr(uniqName); |
320 | auto bindcAttr = builder.getStringAttr(bindcName); |
321 | build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, |
322 | bindcAttr, typeparams, shape); |
323 | result.addAttributes(attributes); |
324 | } |
325 | |
326 | void fir::AllocMemOp::build(mlir::OpBuilder &builder, |
327 | mlir::OperationState &result, mlir::Type inType, |
328 | mlir::ValueRange typeparams, mlir::ValueRange shape, |
329 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
330 | build(builder, result, wrapAllocMemResultType(inType), inType, {}, {}, |
331 | typeparams, shape); |
332 | result.addAttributes(attributes); |
333 | } |
334 | |
335 | mlir::ParseResult fir::AllocMemOp::parse(mlir::OpAsmParser &parser, |
336 | mlir::OperationState &result) { |
337 | return parseAllocatableOp(wrapAllocMemResultType, parser, result); |
338 | } |
339 | |
340 | void fir::AllocMemOp::print(mlir::OpAsmPrinter &p) { |
341 | printAllocatableOp(p, *this); |
342 | } |
343 | |
344 | mlir::LogicalResult fir::AllocMemOp::verify() { |
345 | llvm::SmallVector<llvm::StringRef> visited; |
346 | if (verifyInType(getInType(), visited, numShapeOperands())) |
347 | return emitOpError("invalid type for allocation" ); |
348 | if (verifyTypeParamCount(getInType(), numLenParams())) |
349 | return emitOpError("LEN params do not correspond to type" ); |
350 | mlir::Type outType = getType(); |
351 | if (!outType.dyn_cast<fir::HeapType>()) |
352 | return emitOpError("must be a !fir.heap type" ); |
353 | if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) |
354 | return emitOpError("cannot allocate !fir.box of unknown rank or type" ); |
355 | return mlir::success(); |
356 | } |
357 | |
358 | //===----------------------------------------------------------------------===// |
359 | // ArrayCoorOp |
360 | //===----------------------------------------------------------------------===// |
361 | |
362 | // CHARACTERs and derived types with LEN PARAMETERs are dependent types that |
363 | // require runtime values to fully define the type of an object. |
364 | static bool validTypeParams(mlir::Type dynTy, mlir::ValueRange typeParams) { |
365 | dynTy = fir::unwrapAllRefAndSeqType(dynTy); |
366 | // A box value will contain type parameter values itself. |
367 | if (dynTy.isa<fir::BoxType>()) |
368 | return typeParams.size() == 0; |
369 | // Derived type must have all type parameters satisfied. |
370 | if (auto recTy = dynTy.dyn_cast<fir::RecordType>()) |
371 | return typeParams.size() == recTy.getNumLenParams(); |
372 | // Characters with non-constant LEN must have a type parameter value. |
373 | if (auto charTy = dynTy.dyn_cast<fir::CharacterType>()) |
374 | if (charTy.hasDynamicLen()) |
375 | return typeParams.size() == 1; |
376 | // Otherwise, any type parameters are invalid. |
377 | return typeParams.size() == 0; |
378 | } |
379 | |
380 | mlir::LogicalResult fir::ArrayCoorOp::verify() { |
381 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
382 | auto arrTy = eleTy.dyn_cast<fir::SequenceType>(); |
383 | if (!arrTy) |
384 | return emitOpError("must be a reference to an array" ); |
385 | auto arrDim = arrTy.getDimension(); |
386 | |
387 | if (auto shapeOp = getShape()) { |
388 | auto shapeTy = shapeOp.getType(); |
389 | unsigned shapeTyRank = 0; |
390 | if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) { |
391 | shapeTyRank = s.getRank(); |
392 | } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) { |
393 | shapeTyRank = ss.getRank(); |
394 | } else { |
395 | auto s = shapeTy.cast<fir::ShiftType>(); |
396 | shapeTyRank = s.getRank(); |
397 | if (!getMemref().getType().isa<fir::BaseBoxType>()) |
398 | return emitOpError("shift can only be provided with fir.box memref" ); |
399 | } |
400 | if (arrDim && arrDim != shapeTyRank) |
401 | return emitOpError("rank of dimension mismatched" ); |
402 | if (shapeTyRank != getIndices().size()) |
403 | return emitOpError("number of indices do not match dim rank" ); |
404 | } |
405 | |
406 | if (auto sliceOp = getSlice()) { |
407 | if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) |
408 | if (!sl.getSubstr().empty()) |
409 | return emitOpError("array_coor cannot take a slice with substring" ); |
410 | if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>()) |
411 | if (sliceTy.getRank() != arrDim) |
412 | return emitOpError("rank of dimension in slice mismatched" ); |
413 | } |
414 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
415 | return emitOpError("invalid type parameters" ); |
416 | |
417 | return mlir::success(); |
418 | } |
419 | |
420 | //===----------------------------------------------------------------------===// |
421 | // ArrayLoadOp |
422 | //===----------------------------------------------------------------------===// |
423 | |
424 | static mlir::Type adjustedElementType(mlir::Type t) { |
425 | if (auto ty = t.dyn_cast<fir::ReferenceType>()) { |
426 | auto eleTy = ty.getEleTy(); |
427 | if (fir::isa_char(eleTy)) |
428 | return eleTy; |
429 | if (fir::isa_derived(eleTy)) |
430 | return eleTy; |
431 | if (eleTy.isa<fir::SequenceType>()) |
432 | return eleTy; |
433 | } |
434 | return t; |
435 | } |
436 | |
437 | std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() { |
438 | if (auto sh = getShape()) |
439 | if (auto *op = sh.getDefiningOp()) { |
440 | if (auto shOp = mlir::dyn_cast<fir::ShapeOp>(op)) { |
441 | auto extents = shOp.getExtents(); |
442 | return {extents.begin(), extents.end()}; |
443 | } |
444 | return mlir::cast<fir::ShapeShiftOp>(op).getExtents(); |
445 | } |
446 | return {}; |
447 | } |
448 | |
449 | mlir::LogicalResult fir::ArrayLoadOp::verify() { |
450 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
451 | auto arrTy = eleTy.dyn_cast<fir::SequenceType>(); |
452 | if (!arrTy) |
453 | return emitOpError("must be a reference to an array" ); |
454 | auto arrDim = arrTy.getDimension(); |
455 | |
456 | if (auto shapeOp = getShape()) { |
457 | auto shapeTy = shapeOp.getType(); |
458 | unsigned shapeTyRank = 0u; |
459 | if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) { |
460 | shapeTyRank = s.getRank(); |
461 | } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) { |
462 | shapeTyRank = ss.getRank(); |
463 | } else { |
464 | auto s = shapeTy.cast<fir::ShiftType>(); |
465 | shapeTyRank = s.getRank(); |
466 | if (!getMemref().getType().isa<fir::BaseBoxType>()) |
467 | return emitOpError("shift can only be provided with fir.box memref" ); |
468 | } |
469 | if (arrDim && arrDim != shapeTyRank) |
470 | return emitOpError("rank of dimension mismatched" ); |
471 | } |
472 | |
473 | if (auto sliceOp = getSlice()) { |
474 | if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) |
475 | if (!sl.getSubstr().empty()) |
476 | return emitOpError("array_load cannot take a slice with substring" ); |
477 | if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>()) |
478 | if (sliceTy.getRank() != arrDim) |
479 | return emitOpError("rank of dimension in slice mismatched" ); |
480 | } |
481 | |
482 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
483 | return emitOpError("invalid type parameters" ); |
484 | |
485 | return mlir::success(); |
486 | } |
487 | |
488 | //===----------------------------------------------------------------------===// |
489 | // ArrayMergeStoreOp |
490 | //===----------------------------------------------------------------------===// |
491 | |
492 | mlir::LogicalResult fir::ArrayMergeStoreOp::verify() { |
493 | if (!mlir::isa<fir::ArrayLoadOp>(getOriginal().getDefiningOp())) |
494 | return emitOpError("operand #0 must be result of a fir.array_load op" ); |
495 | if (auto sl = getSlice()) { |
496 | if (auto sliceOp = |
497 | mlir::dyn_cast_or_null<fir::SliceOp>(sl.getDefiningOp())) { |
498 | if (!sliceOp.getSubstr().empty()) |
499 | return emitOpError( |
500 | "array_merge_store cannot take a slice with substring" ); |
501 | if (!sliceOp.getFields().empty()) { |
502 | // This is an intra-object merge, where the slice is projecting the |
503 | // subfields that are to be overwritten by the merge operation. |
504 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
505 | if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) { |
506 | auto projTy = |
507 | fir::applyPathToType(seqTy.getEleTy(), sliceOp.getFields()); |
508 | if (fir::unwrapSequenceType(getOriginal().getType()) != projTy) |
509 | return emitOpError( |
510 | "type of origin does not match sliced memref type" ); |
511 | if (fir::unwrapSequenceType(getSequence().getType()) != projTy) |
512 | return emitOpError( |
513 | "type of sequence does not match sliced memref type" ); |
514 | return mlir::success(); |
515 | } |
516 | return emitOpError("referenced type is not an array" ); |
517 | } |
518 | } |
519 | return mlir::success(); |
520 | } |
521 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); |
522 | if (getOriginal().getType() != eleTy) |
523 | return emitOpError("type of origin does not match memref element type" ); |
524 | if (getSequence().getType() != eleTy) |
525 | return emitOpError("type of sequence does not match memref element type" ); |
526 | if (!validTypeParams(getMemref().getType(), getTypeparams())) |
527 | return emitOpError("invalid type parameters" ); |
528 | return mlir::success(); |
529 | } |
530 | |
531 | //===----------------------------------------------------------------------===// |
532 | // ArrayFetchOp |
533 | //===----------------------------------------------------------------------===// |
534 | |
535 | // Template function used for both array_fetch and array_update verification. |
536 | template <typename A> |
537 | mlir::Type validArraySubobject(A op) { |
538 | auto ty = op.getSequence().getType(); |
539 | return fir::applyPathToType(ty, op.getIndices()); |
540 | } |
541 | |
542 | mlir::LogicalResult fir::ArrayFetchOp::verify() { |
543 | auto arrTy = getSequence().getType().cast<fir::SequenceType>(); |
544 | auto indSize = getIndices().size(); |
545 | if (indSize < arrTy.getDimension()) |
546 | return emitOpError("number of indices != dimension of array" ); |
547 | if (indSize == arrTy.getDimension() && |
548 | ::adjustedElementType(getElement().getType()) != arrTy.getEleTy()) |
549 | return emitOpError("return type does not match array" ); |
550 | auto ty = validArraySubobject(*this); |
551 | if (!ty || ty != ::adjustedElementType(getType())) |
552 | return emitOpError("return type and/or indices do not type check" ); |
553 | if (!mlir::isa<fir::ArrayLoadOp>(getSequence().getDefiningOp())) |
554 | return emitOpError("argument #0 must be result of fir.array_load" ); |
555 | if (!validTypeParams(arrTy, getTypeparams())) |
556 | return emitOpError("invalid type parameters" ); |
557 | return mlir::success(); |
558 | } |
559 | |
560 | //===----------------------------------------------------------------------===// |
561 | // ArrayAccessOp |
562 | //===----------------------------------------------------------------------===// |
563 | |
564 | mlir::LogicalResult fir::ArrayAccessOp::verify() { |
565 | auto arrTy = getSequence().getType().cast<fir::SequenceType>(); |
566 | std::size_t indSize = getIndices().size(); |
567 | if (indSize < arrTy.getDimension()) |
568 | return emitOpError("number of indices != dimension of array" ); |
569 | if (indSize == arrTy.getDimension() && |
570 | getElement().getType() != fir::ReferenceType::get(arrTy.getEleTy())) |
571 | return emitOpError("return type does not match array" ); |
572 | mlir::Type ty = validArraySubobject(*this); |
573 | if (!ty || fir::ReferenceType::get(ty) != getType()) |
574 | return emitOpError("return type and/or indices do not type check" ); |
575 | if (!validTypeParams(arrTy, getTypeparams())) |
576 | return emitOpError("invalid type parameters" ); |
577 | return mlir::success(); |
578 | } |
579 | |
580 | //===----------------------------------------------------------------------===// |
581 | // ArrayUpdateOp |
582 | //===----------------------------------------------------------------------===// |
583 | |
584 | mlir::LogicalResult fir::ArrayUpdateOp::verify() { |
585 | if (fir::isa_ref_type(getMerge().getType())) |
586 | return emitOpError("does not support reference type for merge" ); |
587 | auto arrTy = getSequence().getType().cast<fir::SequenceType>(); |
588 | auto indSize = getIndices().size(); |
589 | if (indSize < arrTy.getDimension()) |
590 | return emitOpError("number of indices != dimension of array" ); |
591 | if (indSize == arrTy.getDimension() && |
592 | ::adjustedElementType(getMerge().getType()) != arrTy.getEleTy()) |
593 | return emitOpError("merged value does not have element type" ); |
594 | auto ty = validArraySubobject(*this); |
595 | if (!ty || ty != ::adjustedElementType(getMerge().getType())) |
596 | return emitOpError("merged value and/or indices do not type check" ); |
597 | if (!validTypeParams(arrTy, getTypeparams())) |
598 | return emitOpError("invalid type parameters" ); |
599 | return mlir::success(); |
600 | } |
601 | |
602 | //===----------------------------------------------------------------------===// |
603 | // ArrayModifyOp |
604 | //===----------------------------------------------------------------------===// |
605 | |
606 | mlir::LogicalResult fir::ArrayModifyOp::verify() { |
607 | auto arrTy = getSequence().getType().cast<fir::SequenceType>(); |
608 | auto indSize = getIndices().size(); |
609 | if (indSize < arrTy.getDimension()) |
610 | return emitOpError("number of indices must match array dimension" ); |
611 | return mlir::success(); |
612 | } |
613 | |
614 | //===----------------------------------------------------------------------===// |
615 | // BoxAddrOp |
616 | //===----------------------------------------------------------------------===// |
617 | |
618 | void fir::BoxAddrOp::build(mlir::OpBuilder &builder, |
619 | mlir::OperationState &result, mlir::Value val) { |
620 | mlir::Type type = |
621 | llvm::TypeSwitch<mlir::Type, mlir::Type>(val.getType()) |
622 | .Case<fir::BaseBoxType>([&](fir::BaseBoxType ty) -> mlir::Type { |
623 | mlir::Type eleTy = ty.getEleTy(); |
624 | if (fir::isa_ref_type(eleTy)) |
625 | return eleTy; |
626 | return fir::ReferenceType::get(eleTy); |
627 | }) |
628 | .Case<fir::BoxCharType>([&](fir::BoxCharType ty) -> mlir::Type { |
629 | return fir::ReferenceType::get(ty.getEleTy()); |
630 | }) |
631 | .Case<fir::BoxProcType>( |
632 | [&](fir::BoxProcType ty) { return ty.getEleTy(); }) |
633 | .Default([&](const auto &) { return mlir::Type{}; }); |
634 | assert(type && "bad val type" ); |
635 | build(builder, result, type, val); |
636 | } |
637 | |
638 | mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) { |
639 | if (auto *v = getVal().getDefiningOp()) { |
640 | if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) { |
641 | // Fold only if not sliced |
642 | if (!box.getSlice() && box.getMemref().getType() == getType()) { |
643 | propagateAttributes(getOperation(), box.getMemref().getDefiningOp()); |
644 | return box.getMemref(); |
645 | } |
646 | } |
647 | if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) |
648 | if (box.getMemref().getType() == getType()) |
649 | return box.getMemref(); |
650 | } |
651 | return {}; |
652 | } |
653 | |
654 | //===----------------------------------------------------------------------===// |
655 | // BoxCharLenOp |
656 | //===----------------------------------------------------------------------===// |
657 | |
658 | mlir::OpFoldResult fir::BoxCharLenOp::fold(FoldAdaptor adaptor) { |
659 | if (auto v = getVal().getDefiningOp()) { |
660 | if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) |
661 | return box.getLen(); |
662 | } |
663 | return {}; |
664 | } |
665 | |
666 | //===----------------------------------------------------------------------===// |
667 | // BoxDimsOp |
668 | //===----------------------------------------------------------------------===// |
669 | |
670 | /// Get the result types packed in a tuple tuple |
671 | mlir::Type fir::BoxDimsOp::getTupleType() { |
672 | // note: triple, but 4 is nearest power of 2 |
673 | llvm::SmallVector<mlir::Type> triple{ |
674 | getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; |
675 | return mlir::TupleType::get(getContext(), triple); |
676 | } |
677 | |
678 | //===----------------------------------------------------------------------===// |
679 | // CallOp |
680 | //===----------------------------------------------------------------------===// |
681 | |
682 | mlir::FunctionType fir::CallOp::getFunctionType() { |
683 | return mlir::FunctionType::get(getContext(), getOperandTypes(), |
684 | getResultTypes()); |
685 | } |
686 | |
687 | void fir::CallOp::print(mlir::OpAsmPrinter &p) { |
688 | bool isDirect = getCallee().has_value(); |
689 | p << ' '; |
690 | if (isDirect) |
691 | p << *getCallee(); |
692 | else |
693 | p << getOperand(0); |
694 | p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')'; |
695 | |
696 | // Print 'fastmath<...>' (if it has non-default value) before |
697 | // any other attributes. |
698 | mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr(); |
699 | if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) { |
700 | p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic(); |
701 | p.printStrippedAttrOrType(fmfAttr); |
702 | } |
703 | |
704 | p.printOptionalAttrDict( |
705 | (*this)->getAttrs(), |
706 | {fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName()}); |
707 | auto resultTypes{getResultTypes()}; |
708 | llvm::SmallVector<mlir::Type> argTypes( |
709 | llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1)); |
710 | p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes); |
711 | } |
712 | |
/// Custom parser for fir.call; the inverse of fir::CallOp::print. Accepts
///   direct:   `@callee(args) [fastmath<...>] attr-dict : functional-type`
///   indirect: `%fn(args) [fastmath<...>] attr-dict : functional-type`
/// For an indirect call, the callee SSA value is resolved against the whole
/// parsed function type and becomes operand 0; the call arguments follow it
/// and are resolved against the function type's inputs.
mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                                     mlir::OperationState &result) {
  // An indirect call begins with the callee SSA value. If no operand is
  // present here, the call is direct and a callee symbol follows.
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
  if (parser.parseOperandList(operands))
    return mlir::failure();

  mlir::NamedAttrList attrs;
  mlir::SymbolRefAttr funcAttr;
  bool isDirect = operands.empty();
  if (isDirect)
    if (parser.parseAttribute(funcAttr, fir::CallOp::getCalleeAttrNameStr(),
                              attrs))
      return mlir::failure();

  // The parenthesized arguments are appended to `operands`, after the
  // indirect callee (if any).
  mlir::Type type;
  if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
    return mlir::failure();

  // Parse 'fastmath<...>', if present.
  mlir::arith::FastMathFlagsAttr fmfAttr;
  llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
  if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName)))
    if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{},
                                                fmfAttrName, attrs))
      return mlir::failure();

  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
      parser.parseType(type))
    return mlir::failure();

  auto funcType = type.dyn_cast<mlir::FunctionType>();
  if (!funcType)
    return parser.emitError(parser.getNameLoc(), "expected function type" );
  if (isDirect) {
    // Every parsed operand is a call argument.
    if (parser.resolveOperands(operands, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return mlir::failure();
  } else {
    // operands[0] is the callee value (typed by the whole function type);
    // the remaining operands are the call arguments.
    auto funcArgs =
        llvm::ArrayRef<mlir::OpAsmParser::UnresolvedOperand>(operands)
            .drop_front();
    if (parser.resolveOperand(operands[0], funcType, result.operands) ||
        parser.resolveOperands(funcArgs, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return mlir::failure();
  }
  result.addTypes(funcType.getResults());
  // `attrs` collected the callee symbol (direct case), fastmath, and the
  // attribute dictionary.
  result.attributes = attrs;
  return mlir::success();
}
763 | |
764 | void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
765 | mlir::func::FuncOp callee, mlir::ValueRange operands) { |
766 | result.addOperands(operands); |
767 | result.addAttribute(getCalleeAttrNameStr(), mlir::SymbolRefAttr::get(callee)); |
768 | result.addTypes(callee.getFunctionType().getResults()); |
769 | } |
770 | |
771 | void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
772 | mlir::SymbolRefAttr callee, |
773 | llvm::ArrayRef<mlir::Type> results, |
774 | mlir::ValueRange operands) { |
775 | result.addOperands(operands); |
776 | if (callee) |
777 | result.addAttribute(getCalleeAttrNameStr(), callee); |
778 | result.addTypes(results); |
779 | } |
780 | |
781 | //===----------------------------------------------------------------------===// |
782 | // CharConvertOp |
783 | //===----------------------------------------------------------------------===// |
784 | |
785 | mlir::LogicalResult fir::CharConvertOp::verify() { |
786 | auto unwrap = [&](mlir::Type t) { |
787 | t = fir::unwrapSequenceType(fir::dyn_cast_ptrEleTy(t)); |
788 | return t.dyn_cast<fir::CharacterType>(); |
789 | }; |
790 | auto inTy = unwrap(getFrom().getType()); |
791 | auto outTy = unwrap(getTo().getType()); |
792 | if (!(inTy && outTy)) |
793 | return emitOpError("not a reference to a character" ); |
794 | if (inTy.getFKind() == outTy.getFKind()) |
795 | return emitOpError("buffers must have different KIND values" ); |
796 | return mlir::success(); |
797 | } |
798 | |
799 | //===----------------------------------------------------------------------===// |
800 | // CmpOp |
801 | //===----------------------------------------------------------------------===// |
802 | |
803 | template <typename OPTY> |
804 | static void printCmpOp(mlir::OpAsmPrinter &p, OPTY op) { |
805 | p << ' '; |
806 | auto predSym = mlir::arith::symbolizeCmpFPredicate( |
807 | op->template getAttrOfType<mlir::IntegerAttr>( |
808 | OPTY::getPredicateAttrName()) |
809 | .getInt()); |
810 | assert(predSym.has_value() && "invalid symbol value for predicate" ); |
811 | p << '"' << mlir::arith::stringifyCmpFPredicate(predSym.value()) << '"' |
812 | << ", " ; |
813 | p.printOperand(op.getLhs()); |
814 | p << ", " ; |
815 | p.printOperand(op.getRhs()); |
816 | p.printOptionalAttrDict(attrs: op->getAttrs(), |
817 | /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); |
818 | p << " : " << op.getLhs().getType(); |
819 | } |
820 | |
821 | template <typename OPTY> |
822 | static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, |
823 | mlir::OperationState &result) { |
824 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> ops; |
825 | mlir::NamedAttrList attrs; |
826 | mlir::Attribute predicateNameAttr; |
827 | mlir::Type type; |
828 | if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), |
829 | attrs) || |
830 | parser.parseComma() || parser.parseOperandList(result&: ops, requiredOperandCount: 2) || |
831 | parser.parseOptionalAttrDict(result&: attrs) || parser.parseColonType(result&: type) || |
832 | parser.resolveOperands(operands&: ops, type, result&: result.operands)) |
833 | return mlir::failure(); |
834 | |
835 | if (!predicateNameAttr.isa<mlir::StringAttr>()) |
836 | return parser.emitError(loc: parser.getNameLoc(), |
837 | message: "expected string comparison predicate attribute" ); |
838 | |
839 | // Rewrite string attribute to an enum value. |
840 | llvm::StringRef predicateName = |
841 | predicateNameAttr.cast<mlir::StringAttr>().getValue(); |
842 | auto predicate = fir::CmpcOp::getPredicateByName(predicateName); |
843 | auto builder = parser.getBuilder(); |
844 | mlir::Type i1Type = builder.getI1Type(); |
845 | attrs.set(OPTY::getPredicateAttrName(), |
846 | builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); |
847 | result.attributes = attrs; |
848 | result.addTypes(newTypes: {i1Type}); |
849 | return mlir::success(); |
850 | } |
851 | |
852 | //===----------------------------------------------------------------------===// |
853 | // CmpcOp |
854 | //===----------------------------------------------------------------------===// |
855 | |
856 | void fir::buildCmpCOp(mlir::OpBuilder &builder, mlir::OperationState &result, |
857 | mlir::arith::CmpFPredicate predicate, mlir::Value lhs, |
858 | mlir::Value rhs) { |
859 | result.addOperands({lhs, rhs}); |
860 | result.types.push_back(builder.getI1Type()); |
861 | result.addAttribute( |
862 | fir::CmpcOp::getPredicateAttrName(), |
863 | builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); |
864 | } |
865 | |
866 | mlir::arith::CmpFPredicate |
867 | fir::CmpcOp::getPredicateByName(llvm::StringRef name) { |
868 | auto pred = mlir::arith::symbolizeCmpFPredicate(name); |
869 | assert(pred.has_value() && "invalid predicate name" ); |
870 | return pred.value(); |
871 | } |
872 | |
873 | void fir::CmpcOp::print(mlir::OpAsmPrinter &p) { printCmpOp(p, *this); } |
874 | |
875 | mlir::ParseResult fir::CmpcOp::parse(mlir::OpAsmParser &parser, |
876 | mlir::OperationState &result) { |
877 | return parseCmpOp<fir::CmpcOp>(parser, result); |
878 | } |
879 | |
880 | //===----------------------------------------------------------------------===// |
881 | // ConstcOp |
882 | //===----------------------------------------------------------------------===// |
883 | |
884 | mlir::ParseResult fir::ConstcOp::parse(mlir::OpAsmParser &parser, |
885 | mlir::OperationState &result) { |
886 | fir::RealAttr realp; |
887 | fir::RealAttr imagp; |
888 | mlir::Type type; |
889 | if (parser.parseLParen() || |
890 | parser.parseAttribute(realp, fir::ConstcOp::getRealAttrName(), |
891 | result.attributes) || |
892 | parser.parseComma() || |
893 | parser.parseAttribute(imagp, fir::ConstcOp::getImagAttrName(), |
894 | result.attributes) || |
895 | parser.parseRParen() || parser.parseColonType(type) || |
896 | parser.addTypesToList(type, result.types)) |
897 | return mlir::failure(); |
898 | return mlir::success(); |
899 | } |
900 | |
901 | void fir::ConstcOp::print(mlir::OpAsmPrinter &p) { |
902 | p << '('; |
903 | p << getOperation()->getAttr(fir::ConstcOp::getRealAttrName()) << ", " ; |
904 | p << getOperation()->getAttr(fir::ConstcOp::getImagAttrName()) << ") : " ; |
905 | p.printType(getType()); |
906 | } |
907 | |
908 | mlir::LogicalResult fir::ConstcOp::verify() { |
909 | if (!getType().isa<fir::ComplexType>()) |
910 | return emitOpError("must be a !fir.complex type" ); |
911 | return mlir::success(); |
912 | } |
913 | |
914 | //===----------------------------------------------------------------------===// |
915 | // ConvertOp |
916 | //===----------------------------------------------------------------------===// |
917 | |
918 | void fir::ConvertOp::getCanonicalizationPatterns( |
919 | mlir::RewritePatternSet &results, mlir::MLIRContext *context) { |
920 | results.insert<ConvertConvertOptPattern, ConvertAscendingIndexOptPattern, |
921 | ConvertDescendingIndexOptPattern, RedundantConvertOptPattern, |
922 | CombineConvertOptPattern, CombineConvertTruncOptPattern, |
923 | ForwardConstantConvertPattern>(context); |
924 | } |
925 | |
926 | mlir::OpFoldResult fir::ConvertOp::fold(FoldAdaptor adaptor) { |
927 | if (getValue().getType() == getType()) |
928 | return getValue(); |
929 | if (matchPattern(getValue(), mlir::m_Op<fir::ConvertOp>())) { |
930 | auto inner = mlir::cast<fir::ConvertOp>(getValue().getDefiningOp()); |
931 | // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a |
932 | if (auto toTy = getType().dyn_cast<fir::LogicalType>()) |
933 | if (auto fromTy = inner.getValue().getType().dyn_cast<fir::LogicalType>()) |
934 | if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) |
935 | return inner.getValue(); |
936 | // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a |
937 | if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) |
938 | if (auto fromTy = |
939 | inner.getValue().getType().dyn_cast<mlir::IntegerType>()) |
940 | if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && |
941 | (fromTy.getWidth() == 1)) |
942 | return inner.getValue(); |
943 | } |
944 | return {}; |
945 | } |
946 | |
947 | bool fir::ConvertOp::isInteger(mlir::Type ty) { |
948 | return ty.isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType>(); |
949 | } |
950 | |
951 | bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { |
952 | return isInteger(ty) || mlir::isa<fir::LogicalType>(ty); |
953 | } |
954 | |
955 | bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { |
956 | return ty.isa<mlir::FloatType, fir::RealType>(); |
957 | } |
958 | |
959 | bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { |
960 | return ty.isa<fir::ReferenceType, fir::PointerType, fir::HeapType, |
961 | fir::LLVMPointerType, mlir::MemRefType, mlir::FunctionType, |
962 | fir::TypeDescType>(); |
963 | } |
964 | |
965 | static std::optional<mlir::Type> getVectorElementType(mlir::Type ty) { |
966 | mlir::Type elemTy; |
967 | if (mlir::isa<fir::VectorType>(ty)) |
968 | elemTy = mlir::dyn_cast<fir::VectorType>(ty).getEleTy(); |
969 | else if (mlir::isa<mlir::VectorType>(Val: ty)) |
970 | elemTy = mlir::dyn_cast<mlir::VectorType>(ty).getElementType(); |
971 | else |
972 | return std::nullopt; |
973 | |
974 | // e.g. fir.vector<4:ui32> => mlir.vector<4xi32> |
975 | // e.g. mlir.vector<4xui32> => mlir.vector<4xi32> |
976 | if (elemTy.isUnsignedInteger()) { |
977 | elemTy = mlir::IntegerType::get( |
978 | ty.getContext(), mlir::dyn_cast<mlir::IntegerType>(elemTy).getWidth()); |
979 | } |
980 | return elemTy; |
981 | } |
982 | |
983 | static std::optional<uint64_t> getVectorLen(mlir::Type ty) { |
984 | if (mlir::isa<fir::VectorType>(ty)) |
985 | return mlir::dyn_cast<fir::VectorType>(ty).getLen(); |
986 | else if (mlir::isa<mlir::VectorType>(Val: ty)) { |
987 | // fir.vector only supports 1-D vector |
988 | if (!(mlir::dyn_cast<mlir::VectorType>(ty).isScalable())) |
989 | return mlir::dyn_cast<mlir::VectorType>(ty).getShape()[0]; |
990 | } |
991 | |
992 | return std::nullopt; |
993 | } |
994 | |
995 | bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) { |
996 | if (!(mlir::isa<fir::VectorType>(inTy) && |
997 | mlir::isa<mlir::VectorType>(outTy)) && |
998 | !(mlir::isa<mlir::VectorType>(inTy) && mlir::isa<fir::VectorType>(outTy))) |
999 | return false; |
1000 | |
1001 | // Only support integer, unsigned and real vector |
1002 | // Both vectors must have the same element type |
1003 | std::optional<mlir::Type> inElemTy = getVectorElementType(inTy); |
1004 | std::optional<mlir::Type> outElemTy = getVectorElementType(outTy); |
1005 | if (!inElemTy.has_value() || !outElemTy.has_value() || |
1006 | inElemTy.value() != outElemTy.value()) |
1007 | return false; |
1008 | |
1009 | // Both vectors must have the same number of elements |
1010 | std::optional<uint64_t> inLen = getVectorLen(inTy); |
1011 | std::optional<uint64_t> outLen = getVectorLen(outTy); |
1012 | if (!inLen.has_value() || !outLen.has_value() || |
1013 | inLen.value() != outLen.value()) |
1014 | return false; |
1015 | |
1016 | return true; |
1017 | } |
1018 | |
1019 | bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) { |
1020 | if (inType == outType) |
1021 | return true; |
1022 | return (isPointerCompatible(inType) && isPointerCompatible(outType)) || |
1023 | (isIntegerCompatible(inType) && isIntegerCompatible(outType)) || |
1024 | (isInteger(inType) && isFloatCompatible(outType)) || |
1025 | (isFloatCompatible(inType) && isInteger(outType)) || |
1026 | (isFloatCompatible(inType) && isFloatCompatible(outType)) || |
1027 | (isIntegerCompatible(inType) && isPointerCompatible(outType)) || |
1028 | (isPointerCompatible(inType) && isIntegerCompatible(outType)) || |
1029 | (inType.isa<fir::BoxType>() && outType.isa<fir::BoxType>()) || |
1030 | (inType.isa<fir::BoxProcType>() && outType.isa<fir::BoxProcType>()) || |
1031 | (fir::isa_complex(inType) && fir::isa_complex(outType)) || |
1032 | (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) || |
1033 | (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) || |
1034 | (fir::isPolymorphicType(inType) && outType.isa<BoxType>()) || |
1035 | areVectorsCompatible(inType, outType); |
1036 | } |
1037 | |
1038 | mlir::LogicalResult fir::ConvertOp::verify() { |
1039 | if (canBeConverted(getValue().getType(), getType())) |
1040 | return mlir::success(); |
1041 | return emitOpError("invalid type conversion" ); |
1042 | } |
1043 | |
1044 | //===----------------------------------------------------------------------===// |
1045 | // CoordinateOp |
1046 | //===----------------------------------------------------------------------===// |
1047 | |
1048 | void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) { |
1049 | p << ' ' << getRef() << ", " << getCoor(); |
1050 | p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{"baseType" }); |
1051 | p << " : " ; |
1052 | p.printFunctionalType(getOperandTypes(), (*this)->getResultTypes()); |
1053 | } |
1054 | |
1055 | mlir::ParseResult fir::CoordinateOp::parse(mlir::OpAsmParser &parser, |
1056 | mlir::OperationState &result) { |
1057 | mlir::OpAsmParser::UnresolvedOperand memref; |
1058 | if (parser.parseOperand(memref) || parser.parseComma()) |
1059 | return mlir::failure(); |
1060 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> coorOperands; |
1061 | if (parser.parseOperandList(coorOperands)) |
1062 | return mlir::failure(); |
1063 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> allOperands; |
1064 | allOperands.push_back(memref); |
1065 | allOperands.append(coorOperands.begin(), coorOperands.end()); |
1066 | mlir::FunctionType funcTy; |
1067 | auto loc = parser.getCurrentLocation(); |
1068 | if (parser.parseOptionalAttrDict(result.attributes) || |
1069 | parser.parseColonType(funcTy) || |
1070 | parser.resolveOperands(allOperands, funcTy.getInputs(), loc, |
1071 | result.operands) || |
1072 | parser.addTypesToList(funcTy.getResults(), result.types)) |
1073 | return mlir::failure(); |
1074 | result.addAttribute("baseType" , mlir::TypeAttr::get(funcTy.getInput(0))); |
1075 | return mlir::success(); |
1076 | } |
1077 | |
1078 | mlir::LogicalResult fir::CoordinateOp::verify() { |
1079 | const mlir::Type refTy = getRef().getType(); |
1080 | if (fir::isa_ref_type(refTy)) { |
1081 | auto eleTy = fir::dyn_cast_ptrEleTy(refTy); |
1082 | if (auto arrTy = eleTy.dyn_cast<fir::SequenceType>()) { |
1083 | if (arrTy.hasUnknownShape()) |
1084 | return emitOpError("cannot find coordinate in unknown shape" ); |
1085 | if (arrTy.getConstantRows() < arrTy.getDimension() - 1) |
1086 | return emitOpError("cannot find coordinate with unknown extents" ); |
1087 | } |
1088 | if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || |
1089 | fir::isa_char_string(eleTy))) |
1090 | return emitOpError("cannot apply to this element type" ); |
1091 | } |
1092 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(refTy); |
1093 | unsigned dimension = 0; |
1094 | const unsigned numCoors = getCoor().size(); |
1095 | for (auto coorOperand : llvm::enumerate(getCoor())) { |
1096 | auto co = coorOperand.value(); |
1097 | if (dimension == 0 && eleTy.isa<fir::SequenceType>()) { |
1098 | dimension = eleTy.cast<fir::SequenceType>().getDimension(); |
1099 | if (dimension == 0) |
1100 | return emitOpError("cannot apply to array of unknown rank" ); |
1101 | } |
1102 | if (auto *defOp = co.getDefiningOp()) { |
1103 | if (auto index = mlir::dyn_cast<fir::LenParamIndexOp>(defOp)) { |
1104 | // Recovering a LEN type parameter only makes sense from a boxed |
1105 | // value. For a bare reference, the LEN type parameters must be |
1106 | // passed as additional arguments to `index`. |
1107 | if (refTy.isa<fir::BoxType>()) { |
1108 | if (coorOperand.index() != numCoors - 1) |
1109 | return emitOpError("len_param_index must be last argument" ); |
1110 | if (getNumOperands() != 2) |
1111 | return emitOpError("too many operands for len_param_index case" ); |
1112 | } |
1113 | if (eleTy != index.getOnType()) |
1114 | emitOpError( |
1115 | "len_param_index type not compatible with reference type" ); |
1116 | return mlir::success(); |
1117 | } else if (auto index = mlir::dyn_cast<fir::FieldIndexOp>(defOp)) { |
1118 | if (eleTy != index.getOnType()) |
1119 | emitOpError("field_index type not compatible with reference type" ); |
1120 | if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) { |
1121 | eleTy = recTy.getType(index.getFieldName()); |
1122 | continue; |
1123 | } |
1124 | return emitOpError("field_index not applied to !fir.type" ); |
1125 | } |
1126 | } |
1127 | if (dimension) { |
1128 | if (--dimension == 0) |
1129 | eleTy = eleTy.cast<fir::SequenceType>().getEleTy(); |
1130 | } else { |
1131 | if (auto t = eleTy.dyn_cast<mlir::TupleType>()) { |
1132 | // FIXME: Generally, we don't know which field of the tuple is being |
1133 | // referred to unless the operand is a constant. Just assume everything |
1134 | // is good in the tuple case for now. |
1135 | return mlir::success(); |
1136 | } else if (auto t = eleTy.dyn_cast<fir::RecordType>()) { |
1137 | // FIXME: This is the same as the tuple case. |
1138 | return mlir::success(); |
1139 | } else if (auto t = eleTy.dyn_cast<fir::ComplexType>()) { |
1140 | eleTy = t.getElementType(); |
1141 | } else if (auto t = eleTy.dyn_cast<mlir::ComplexType>()) { |
1142 | eleTy = t.getElementType(); |
1143 | } else if (auto t = eleTy.dyn_cast<fir::CharacterType>()) { |
1144 | if (t.getLen() == fir::CharacterType::singleton()) |
1145 | return emitOpError("cannot apply to character singleton" ); |
1146 | eleTy = fir::CharacterType::getSingleton(t.getContext(), t.getFKind()); |
1147 | if (fir::unwrapRefType(getType()) != eleTy) |
1148 | return emitOpError("character type mismatch" ); |
1149 | } else { |
1150 | return emitOpError("invalid parameters (too many)" ); |
1151 | } |
1152 | } |
1153 | } |
1154 | return mlir::success(); |
1155 | } |
1156 | |
1157 | //===----------------------------------------------------------------------===// |
1158 | // DispatchOp |
1159 | //===----------------------------------------------------------------------===// |
1160 | |
1161 | mlir::LogicalResult fir::DispatchOp::verify() { |
1162 | // Check that pass_arg_pos is in range of actual operands. pass_arg_pos is |
1163 | // unsigned so check for less than zero is not needed. |
1164 | if (getPassArgPos() && *getPassArgPos() > (getArgOperands().size() - 1)) |
1165 | return emitOpError( |
1166 | "pass_arg_pos must be smaller than the number of operands" ); |
1167 | |
1168 | // Operand pointed by pass_arg_pos must have polymorphic type. |
1169 | if (getPassArgPos() && |
1170 | !fir::isPolymorphicType(getArgOperands()[*getPassArgPos()].getType())) |
1171 | return emitOpError("pass_arg_pos must be a polymorphic operand" ); |
1172 | return mlir::success(); |
1173 | } |
1174 | |
1175 | mlir::FunctionType fir::DispatchOp::getFunctionType() { |
1176 | return mlir::FunctionType::get(getContext(), getOperandTypes(), |
1177 | getResultTypes()); |
1178 | } |
1179 | |
1180 | //===----------------------------------------------------------------------===// |
1181 | // TypeInfoOp |
1182 | //===----------------------------------------------------------------------===// |
1183 | |
1184 | void fir::TypeInfoOp::build(mlir::OpBuilder &builder, |
1185 | mlir::OperationState &result, fir::RecordType type, |
1186 | fir::RecordType parentType, |
1187 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1188 | result.addRegion(); |
1189 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
1190 | builder.getStringAttr(type.getName())); |
1191 | result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); |
1192 | if (parentType) |
1193 | result.addAttribute(getParentTypeAttrName(result.name), |
1194 | mlir::TypeAttr::get(parentType)); |
1195 | result.addAttributes(attrs); |
1196 | } |
1197 | |
1198 | mlir::LogicalResult fir::TypeInfoOp::verify() { |
1199 | if (!getDispatchTable().empty()) |
1200 | for (auto &op : getDispatchTable().front().without_terminator()) |
1201 | if (!mlir::isa<fir::DTEntryOp>(op)) |
1202 | return op.emitOpError("dispatch table must contain dt_entry" ); |
1203 | |
1204 | if (!mlir::isa<fir::RecordType>(getType())) |
1205 | return emitOpError("type must be a fir.type" ); |
1206 | |
1207 | if (getParentType() && !mlir::isa<fir::RecordType>(*getParentType())) |
1208 | return emitOpError("parent_type must be a fir.type" ); |
1209 | return mlir::success(); |
1210 | } |
1211 | |
1212 | //===----------------------------------------------------------------------===// |
1213 | // EmboxOp |
1214 | //===----------------------------------------------------------------------===// |
1215 | |
1216 | mlir::LogicalResult fir::EmboxOp::verify() { |
1217 | auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); |
1218 | bool isArray = false; |
1219 | if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) { |
1220 | eleTy = seqTy.getEleTy(); |
1221 | isArray = true; |
1222 | } |
1223 | if (hasLenParams()) { |
1224 | auto lenPs = numLenParams(); |
1225 | if (auto rt = eleTy.dyn_cast<fir::RecordType>()) { |
1226 | if (lenPs != rt.getNumLenParams()) |
1227 | return emitOpError("number of LEN params does not correspond" |
1228 | " to the !fir.type type" ); |
1229 | } else if (auto strTy = eleTy.dyn_cast<fir::CharacterType>()) { |
1230 | if (strTy.getLen() != fir::CharacterType::unknownLen()) |
1231 | return emitOpError("CHARACTER already has static LEN" ); |
1232 | } else { |
1233 | return emitOpError("LEN parameters require CHARACTER or derived type" ); |
1234 | } |
1235 | for (auto lp : getTypeparams()) |
1236 | if (!fir::isa_integer(lp.getType())) |
1237 | return emitOpError("LEN parameters must be integral type" ); |
1238 | } |
1239 | if (getShape() && !isArray) |
1240 | return emitOpError("shape must not be provided for a scalar" ); |
1241 | if (getSlice() && !isArray) |
1242 | return emitOpError("slice must not be provided for a scalar" ); |
1243 | if (getSourceBox() && !getResult().getType().isa<fir::ClassType>()) |
1244 | return emitOpError("source_box must be used with fir.class result type" ); |
1245 | return mlir::success(); |
1246 | } |
1247 | |
1248 | //===----------------------------------------------------------------------===// |
1249 | // EmboxCharOp |
1250 | //===----------------------------------------------------------------------===// |
1251 | |
1252 | mlir::LogicalResult fir::EmboxCharOp::verify() { |
1253 | auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); |
1254 | if (!eleTy.dyn_cast_or_null<fir::CharacterType>()) |
1255 | return mlir::failure(); |
1256 | return mlir::success(); |
1257 | } |
1258 | |
1259 | //===----------------------------------------------------------------------===// |
1260 | // EmboxProcOp |
1261 | //===----------------------------------------------------------------------===// |
1262 | |
1263 | mlir::LogicalResult fir::EmboxProcOp::verify() { |
1264 | // host bindings (optional) must be a reference to a tuple |
1265 | if (auto h = getHost()) { |
1266 | if (auto r = h.getType().dyn_cast<fir::ReferenceType>()) |
1267 | if (r.getEleTy().isa<mlir::TupleType>()) |
1268 | return mlir::success(); |
1269 | return mlir::failure(); |
1270 | } |
1271 | return mlir::success(); |
1272 | } |
1273 | |
1274 | //===----------------------------------------------------------------------===// |
1275 | // TypeDescOp |
1276 | //===----------------------------------------------------------------------===// |
1277 | |
1278 | void fir::TypeDescOp::build(mlir::OpBuilder &, mlir::OperationState &result, |
1279 | mlir::TypeAttr inty) { |
1280 | result.addAttribute("in_type" , inty); |
1281 | result.addTypes(TypeDescType::get(inty.getValue())); |
1282 | } |
1283 | |
1284 | mlir::ParseResult fir::TypeDescOp::parse(mlir::OpAsmParser &parser, |
1285 | mlir::OperationState &result) { |
1286 | mlir::Type intype; |
1287 | if (parser.parseType(intype)) |
1288 | return mlir::failure(); |
1289 | result.addAttribute("in_type" , mlir::TypeAttr::get(intype)); |
1290 | mlir::Type restype = fir::TypeDescType::get(intype); |
1291 | if (parser.addTypeToList(restype, result.types)) |
1292 | return mlir::failure(); |
1293 | return mlir::success(); |
1294 | } |
1295 | |
1296 | void fir::TypeDescOp::print(mlir::OpAsmPrinter &p) { |
1297 | p << ' ' << getOperation()->getAttr("in_type" ); |
1298 | p.printOptionalAttrDict(getOperation()->getAttrs(), {"in_type" }); |
1299 | } |
1300 | |
1301 | mlir::LogicalResult fir::TypeDescOp::verify() { |
1302 | mlir::Type resultTy = getType(); |
1303 | if (auto tdesc = resultTy.dyn_cast<fir::TypeDescType>()) { |
1304 | if (tdesc.getOfTy() != getInType()) |
1305 | return emitOpError("wrapped type mismatched" ); |
1306 | return mlir::success(); |
1307 | } |
1308 | return emitOpError("must be !fir.tdesc type" ); |
1309 | } |
1310 | |
1311 | //===----------------------------------------------------------------------===// |
1312 | // GlobalOp |
1313 | //===----------------------------------------------------------------------===// |
1314 | |
1315 | mlir::Type fir::GlobalOp::resultType() { |
1316 | return wrapAllocaResultType(getType()); |
1317 | } |
1318 | |
/// Custom parser for fir.global; the inverse of fir::GlobalOp::print:
///   [linkage] @sym [`(` init-attr `)`] attr-dict [constant] [target] : type
///   [initializer-region]
/// A parenthesized initial-value attribute and an initializer region are
/// alternatives. When the attribute form is used, an empty region is added
/// and no region is parsed.
mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  // Parse the optional linkage
  llvm::StringRef linkage;
  auto &builder = parser.getBuilder();
  if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) {
    // NOTE(review): whether verifyValidLinkage emits a diagnostic is not
    // visible here. If it does not, an invalid linkage fails silently.
    // Confirm.
    if (fir::GlobalOp::verifyValidLinkage(linkage))
      return mlir::failure();
    mlir::StringAttr linkAttr = builder.getStringAttr(linkage);
    result.addAttribute(fir::GlobalOp::getLinkNameAttrName(result.name),
                        linkAttr);
  }

  // Parse the name as a symbol reference attribute.
  mlir::SymbolRefAttr nameAttr;
  if (parser.parseAttribute(nameAttr,
                            fir::GlobalOp::getSymrefAttrName(result.name),
                            result.attributes))
    return mlir::failure();
  // The root of the symbol reference also becomes the op's symbol name.
  result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                      nameAttr.getRootReference());

  // Optional `(attr)` initial value; when present, no initializer region is
  // parsed below.
  bool simpleInitializer = false;
  if (mlir::succeeded(parser.parseOptionalLParen())) {
    mlir::Attribute attr;
    if (parser.parseAttribute(attr, getInitValAttrName(result.name),
                              result.attributes) ||
        parser.parseRParen())
      return mlir::failure();
    simpleInitializer = true;
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return mlir::failure();

  if (succeeded(
          parser.parseOptionalKeyword(getConstantAttrName(result.name)))) {
    // if "constant" keyword then mark this as a constant, not a variable
    result.addAttribute(getConstantAttrName(result.name),
                        builder.getUnitAttr());
  }

  // Optional "target" keyword, recorded as a unit attribute.
  if (succeeded(parser.parseOptionalKeyword(getTargetAttrName(result.name))))
    result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr());

  mlir::Type globalType;
  if (parser.parseColonType(globalType))
    return mlir::failure();

  result.addAttribute(fir::GlobalOp::getTypeAttrName(result.name),
                      mlir::TypeAttr::get(globalType));

  if (simpleInitializer) {
    // The op always carries one region; it stays empty here.
    result.addRegion();
  } else {
    // Parse the optional initializer body.
    auto parseResult =
        parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{});
    if (parseResult.has_value() && mlir::failed(*parseResult))
      return mlir::failure();
  }
  return mlir::success();
}
1382 | |
1383 | void fir::GlobalOp::print(mlir::OpAsmPrinter &p) { |
1384 | if (getLinkName()) |
1385 | p << ' ' << *getLinkName(); |
1386 | p << ' '; |
1387 | p.printAttributeWithoutType(getSymrefAttr()); |
1388 | if (auto val = getValueOrNull()) |
1389 | p << '(' << val << ')'; |
1390 | // Print all other attributes that are not pretty printed here. |
1391 | p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{ |
1392 | getSymNameAttrName(), getSymrefAttrName(), |
1393 | getTypeAttrName(), getConstantAttrName(), |
1394 | getTargetAttrName(), getLinkNameAttrName(), |
1395 | getInitValAttrName()}); |
1396 | if (getOperation()->getAttr(getConstantAttrName())) |
1397 | p << " " << getConstantAttrName().strref(); |
1398 | if (getOperation()->getAttr(getTargetAttrName())) |
1399 | p << " " << getTargetAttrName().strref(); |
1400 | p << " : " ; |
1401 | p.printType(getType()); |
1402 | if (hasInitializationBody()) { |
1403 | p << ' '; |
1404 | p.printRegion(getOperation()->getRegion(0), |
1405 | /*printEntryBlockArgs=*/false, |
1406 | /*printBlockTerminators=*/true); |
1407 | } |
1408 | } |
1409 | |
1410 | void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { |
1411 | getBlock().getOperations().push_back(op); |
1412 | } |
1413 | |
1414 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
1415 | mlir::OperationState &result, llvm::StringRef name, |
1416 | bool isConstant, bool isTarget, mlir::Type type, |
1417 | mlir::Attribute initialVal, mlir::StringAttr linkage, |
1418 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1419 | result.addRegion(); |
1420 | result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); |
1421 | result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), |
1422 | builder.getStringAttr(name)); |
1423 | result.addAttribute(getSymrefAttrName(result.name), |
1424 | mlir::SymbolRefAttr::get(builder.getContext(), name)); |
1425 | if (isConstant) |
1426 | result.addAttribute(getConstantAttrName(result.name), |
1427 | builder.getUnitAttr()); |
1428 | if (isTarget) |
1429 | result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); |
1430 | if (initialVal) |
1431 | result.addAttribute(getInitValAttrName(result.name), initialVal); |
1432 | if (linkage) |
1433 | result.addAttribute(getLinkNameAttrName(result.name), linkage); |
1434 | result.attributes.append(attrs.begin(), attrs.end()); |
1435 | } |
1436 | |
1437 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
1438 | mlir::OperationState &result, llvm::StringRef name, |
1439 | mlir::Type type, mlir::Attribute initialVal, |
1440 | mlir::StringAttr linkage, |
1441 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1442 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
1443 | {}, linkage, attrs); |
1444 | } |
1445 | |
1446 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
1447 | mlir::OperationState &result, llvm::StringRef name, |
1448 | bool isConstant, bool isTarget, mlir::Type type, |
1449 | mlir::StringAttr linkage, |
1450 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1451 | build(builder, result, name, isConstant, isTarget, type, {}, linkage, attrs); |
1452 | } |
1453 | |
1454 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
1455 | mlir::OperationState &result, llvm::StringRef name, |
1456 | mlir::Type type, mlir::StringAttr linkage, |
1457 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1458 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
1459 | {}, linkage, attrs); |
1460 | } |
1461 | |
1462 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
1463 | mlir::OperationState &result, llvm::StringRef name, |
1464 | bool isConstant, bool isTarget, mlir::Type type, |
1465 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1466 | build(builder, result, name, isConstant, isTarget, type, mlir::StringAttr{}, |
1467 | attrs); |
1468 | } |
1469 | |
1470 | void fir::GlobalOp::build(mlir::OpBuilder &builder, |
1471 | mlir::OperationState &result, llvm::StringRef name, |
1472 | mlir::Type type, |
1473 | llvm::ArrayRef<mlir::NamedAttribute> attrs) { |
1474 | build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, |
1475 | attrs); |
1476 | } |
1477 | |
1478 | mlir::ParseResult fir::GlobalOp::verifyValidLinkage(llvm::StringRef linkage) { |
1479 | // Supporting only a subset of the LLVM linkage types for now |
1480 | static const char *validNames[] = {"common" , "internal" , "linkonce" , |
1481 | "linkonce_odr" , "weak" }; |
1482 | return mlir::success(llvm::is_contained(validNames, linkage)); |
1483 | } |
1484 | |
1485 | //===----------------------------------------------------------------------===// |
1486 | // GlobalLenOp |
1487 | //===----------------------------------------------------------------------===// |
1488 | |
1489 | mlir::ParseResult fir::GlobalLenOp::parse(mlir::OpAsmParser &parser, |
1490 | mlir::OperationState &result) { |
1491 | llvm::StringRef fieldName; |
1492 | if (failed(parser.parseOptionalKeyword(&fieldName))) { |
1493 | mlir::StringAttr fieldAttr; |
1494 | if (parser.parseAttribute(fieldAttr, |
1495 | fir::GlobalLenOp::getLenParamAttrName(), |
1496 | result.attributes)) |
1497 | return mlir::failure(); |
1498 | } else { |
1499 | result.addAttribute(fir::GlobalLenOp::getLenParamAttrName(), |
1500 | parser.getBuilder().getStringAttr(fieldName)); |
1501 | } |
1502 | mlir::IntegerAttr constant; |
1503 | if (parser.parseComma() || |
1504 | parser.parseAttribute(constant, fir::GlobalLenOp::getIntAttrName(), |
1505 | result.attributes)) |
1506 | return mlir::failure(); |
1507 | return mlir::success(); |
1508 | } |
1509 | |
1510 | void fir::GlobalLenOp::print(mlir::OpAsmPrinter &p) { |
1511 | p << ' ' << getOperation()->getAttr(fir::GlobalLenOp::getLenParamAttrName()) |
1512 | << ", " << getOperation()->getAttr(fir::GlobalLenOp::getIntAttrName()); |
1513 | } |
1514 | |
1515 | //===----------------------------------------------------------------------===// |
1516 | // FieldIndexOp |
1517 | //===----------------------------------------------------------------------===// |
1518 | |
1519 | template <typename TY> |
1520 | mlir::ParseResult parseFieldLikeOp(mlir::OpAsmParser &parser, |
1521 | mlir::OperationState &result) { |
1522 | llvm::StringRef fieldName; |
1523 | auto &builder = parser.getBuilder(); |
1524 | mlir::Type recty; |
1525 | if (parser.parseOptionalKeyword(keyword: &fieldName) || parser.parseComma() || |
1526 | parser.parseType(result&: recty)) |
1527 | return mlir::failure(); |
1528 | result.addAttribute(fir::FieldIndexOp::getFieldAttrName(), |
1529 | builder.getStringAttr(fieldName)); |
1530 | if (!recty.dyn_cast<fir::RecordType>()) |
1531 | return mlir::failure(); |
1532 | result.addAttribute(fir::FieldIndexOp::getTypeAttrName(), |
1533 | mlir::TypeAttr::get(recty)); |
1534 | if (!parser.parseOptionalLParen()) { |
1535 | llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; |
1536 | llvm::SmallVector<mlir::Type> types; |
1537 | auto loc = parser.getNameLoc(); |
1538 | if (parser.parseOperandList(result&: operands, delimiter: mlir::OpAsmParser::Delimiter::None) || |
1539 | parser.parseColonTypeList(result&: types) || parser.parseRParen() || |
1540 | parser.resolveOperands(operands, types, loc, result&: result.operands)) |
1541 | return mlir::failure(); |
1542 | } |
1543 | mlir::Type fieldType = TY::get(builder.getContext()); |
1544 | if (parser.addTypeToList(type: fieldType, result&: result.types)) |
1545 | return mlir::failure(); |
1546 | return mlir::success(); |
1547 | } |
1548 | |
1549 | mlir::ParseResult fir::FieldIndexOp::parse(mlir::OpAsmParser &parser, |
1550 | mlir::OperationState &result) { |
1551 | return parseFieldLikeOp<fir::FieldType>(parser, result); |
1552 | } |
1553 | |
1554 | template <typename OP> |
1555 | void printFieldLikeOp(mlir::OpAsmPrinter &p, OP &op) { |
1556 | p << ' ' |
1557 | << op.getOperation() |
1558 | ->template getAttrOfType<mlir::StringAttr>( |
1559 | fir::FieldIndexOp::getFieldAttrName()) |
1560 | .getValue() |
1561 | << ", " << op.getOperation()->getAttr(fir::FieldIndexOp::getTypeAttrName()); |
1562 | if (op.getNumOperands()) { |
1563 | p << '('; |
1564 | p.printOperands(op.getTypeparams()); |
1565 | auto sep = ") : " ; |
1566 | for (auto op : op.getTypeparams()) { |
1567 | p << sep; |
1568 | if (op) |
1569 | p.printType(type: op.getType()); |
1570 | else |
1571 | p << "()" ; |
1572 | sep = ", " ; |
1573 | } |
1574 | } |
1575 | } |
1576 | |
1577 | void fir::FieldIndexOp::print(mlir::OpAsmPrinter &p) { |
1578 | printFieldLikeOp(p, *this); |
1579 | } |
1580 | |
1581 | void fir::FieldIndexOp::build(mlir::OpBuilder &builder, |
1582 | mlir::OperationState &result, |
1583 | llvm::StringRef fieldName, mlir::Type recTy, |
1584 | mlir::ValueRange operands) { |
1585 | result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); |
1586 | result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); |
1587 | result.addOperands(operands); |
1588 | } |
1589 | |
1590 | llvm::SmallVector<mlir::Attribute> fir::FieldIndexOp::getAttributes() { |
1591 | llvm::SmallVector<mlir::Attribute> attrs; |
1592 | attrs.push_back(getFieldIdAttr()); |
1593 | attrs.push_back(getOnTypeAttr()); |
1594 | return attrs; |
1595 | } |
1596 | |
1597 | //===----------------------------------------------------------------------===// |
1598 | // InsertOnRangeOp |
1599 | //===----------------------------------------------------------------------===// |
1600 | |
1601 | static mlir::ParseResult |
1602 | parseCustomRangeSubscript(mlir::OpAsmParser &parser, |
1603 | mlir::DenseIntElementsAttr &coord) { |
1604 | llvm::SmallVector<std::int64_t> lbounds; |
1605 | llvm::SmallVector<std::int64_t> ubounds; |
1606 | if (parser.parseKeyword(keyword: "from" ) || |
1607 | parser.parseCommaSeparatedList( |
1608 | delimiter: mlir::AsmParser::Delimiter::Paren, |
1609 | parseElementFn: [&] { return parser.parseInteger(result&: lbounds.emplace_back(Args: 0)); }) || |
1610 | parser.parseKeyword(keyword: "to" ) || |
1611 | parser.parseCommaSeparatedList(delimiter: mlir::AsmParser::Delimiter::Paren, parseElementFn: [&] { |
1612 | return parser.parseInteger(result&: ubounds.emplace_back(Args: 0)); |
1613 | })) |
1614 | return mlir::failure(); |
1615 | llvm::SmallVector<std::int64_t> zippedBounds; |
1616 | for (auto zip : llvm::zip(t&: lbounds, u&: ubounds)) { |
1617 | zippedBounds.push_back(Elt: std::get<0>(t&: zip)); |
1618 | zippedBounds.push_back(Elt: std::get<1>(t&: zip)); |
1619 | } |
1620 | coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(values: zippedBounds); |
1621 | return mlir::success(); |
1622 | } |
1623 | |
1624 | static void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, |
1625 | fir::InsertOnRangeOp op, |
1626 | mlir::DenseIntElementsAttr coord) { |
1627 | printer << "from (" ; |
1628 | auto enumerate = llvm::enumerate(coord.getValues<std::int64_t>()); |
1629 | // Even entries are the lower bounds. |
1630 | llvm::interleaveComma( |
1631 | make_filter_range( |
1632 | enumerate, |
1633 | [](auto indexed_value) { return indexed_value.index() % 2 == 0; }), |
1634 | printer, [&](auto indexed_value) { printer << indexed_value.value(); }); |
1635 | printer << ") to (" ; |
1636 | // Odd entries are the upper bounds. |
1637 | llvm::interleaveComma( |
1638 | make_filter_range( |
1639 | enumerate, |
1640 | [](auto indexed_value) { return indexed_value.index() % 2 != 0; }), |
1641 | printer, [&](auto indexed_value) { printer << indexed_value.value(); }); |
1642 | printer << ")" ; |
1643 | } |
1644 | |
1645 | /// Range bounds must be nonnegative, and the range must not be empty. |
1646 | mlir::LogicalResult fir::InsertOnRangeOp::verify() { |
1647 | if (fir::hasDynamicSize(getSeq().getType())) |
1648 | return emitOpError("must have constant shape and size" ); |
1649 | mlir::DenseIntElementsAttr coorAttr = getCoor(); |
1650 | if (coorAttr.size() < 2 || coorAttr.size() % 2 != 0) |
1651 | return emitOpError("has uneven number of values in ranges" ); |
1652 | bool rangeIsKnownToBeNonempty = false; |
1653 | for (auto i = coorAttr.getValues<std::int64_t>().end(), |
1654 | b = coorAttr.getValues<std::int64_t>().begin(); |
1655 | i != b;) { |
1656 | int64_t ub = (*--i); |
1657 | int64_t lb = (*--i); |
1658 | if (lb < 0 || ub < 0) |
1659 | return emitOpError("negative range bound" ); |
1660 | if (rangeIsKnownToBeNonempty) |
1661 | continue; |
1662 | if (lb > ub) |
1663 | return emitOpError("empty range" ); |
1664 | rangeIsKnownToBeNonempty = lb < ub; |
1665 | } |
1666 | return mlir::success(); |
1667 | } |
1668 | |
1669 | //===----------------------------------------------------------------------===// |
1670 | // InsertValueOp |
1671 | //===----------------------------------------------------------------------===// |
1672 | |
1673 | static bool checkIsIntegerConstant(mlir::Attribute attr, std::int64_t conVal) { |
1674 | if (auto iattr = attr.dyn_cast<mlir::IntegerAttr>()) |
1675 | return iattr.getInt() == conVal; |
1676 | return false; |
1677 | } |
1678 | |
1679 | static bool isZero(mlir::Attribute a) { return checkIsIntegerConstant(attr: a, conVal: 0); } |
1680 | static bool isOne(mlir::Attribute a) { return checkIsIntegerConstant(attr: a, conVal: 1); } |
1681 | |
1682 | // Undo some complex patterns created in the front-end and turn them back into |
1683 | // complex ops. |
1684 | template <typename FltOp, typename CpxOp> |
1685 | struct UndoComplexPattern : public mlir::RewritePattern { |
1686 | UndoComplexPattern(mlir::MLIRContext *ctx) |
1687 | : mlir::RewritePattern("fir.insert_value" , 2, ctx) {} |
1688 | |
1689 | mlir::LogicalResult |
1690 | matchAndRewrite(mlir::Operation *op, |
1691 | mlir::PatternRewriter &rewriter) const override { |
1692 | auto insval = mlir::dyn_cast_or_null<fir::InsertValueOp>(op); |
1693 | if (!insval || !insval.getType().isa<fir::ComplexType>()) |
1694 | return mlir::failure(); |
1695 | auto insval2 = mlir::dyn_cast_or_null<fir::InsertValueOp>( |
1696 | insval.getAdt().getDefiningOp()); |
1697 | if (!insval2) |
1698 | return mlir::failure(); |
1699 | auto binf = mlir::dyn_cast_or_null<FltOp>(insval.getVal().getDefiningOp()); |
1700 | auto binf2 = |
1701 | mlir::dyn_cast_or_null<FltOp>(insval2.getVal().getDefiningOp()); |
1702 | if (!binf || !binf2 || insval.getCoor().size() != 1 || |
1703 | !isOne(insval.getCoor()[0]) || insval2.getCoor().size() != 1 || |
1704 | !isZero(insval2.getCoor()[0])) |
1705 | return mlir::failure(); |
1706 | auto eai = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
1707 | binf.getLhs().getDefiningOp()); |
1708 | auto ebi = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
1709 | binf.getRhs().getDefiningOp()); |
1710 | auto ear = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
1711 | binf2.getLhs().getDefiningOp()); |
1712 | auto ebr = mlir::dyn_cast_or_null<fir::ExtractValueOp>( |
1713 | binf2.getRhs().getDefiningOp()); |
1714 | if (!eai || !ebi || !ear || !ebr || ear.getAdt() != eai.getAdt() || |
1715 | ebr.getAdt() != ebi.getAdt() || eai.getCoor().size() != 1 || |
1716 | !isOne(eai.getCoor()[0]) || ebi.getCoor().size() != 1 || |
1717 | !isOne(ebi.getCoor()[0]) || ear.getCoor().size() != 1 || |
1718 | !isZero(ear.getCoor()[0]) || ebr.getCoor().size() != 1 || |
1719 | !isZero(ebr.getCoor()[0])) |
1720 | return mlir::failure(); |
1721 | rewriter.replaceOpWithNewOp<CpxOp>(op, ear.getAdt(), ebr.getAdt()); |
1722 | return mlir::success(); |
1723 | } |
1724 | }; |
1725 | |
1726 | void fir::InsertValueOp::getCanonicalizationPatterns( |
1727 | mlir::RewritePatternSet &results, mlir::MLIRContext *context) { |
1728 | results.insert<UndoComplexPattern<mlir::arith::AddFOp, fir::AddcOp>, |
1729 | UndoComplexPattern<mlir::arith::SubFOp, fir::SubcOp>>(context); |
1730 | } |
1731 | |
1732 | //===----------------------------------------------------------------------===// |
1733 | // IterWhileOp |
1734 | //===----------------------------------------------------------------------===// |
1735 | |
1736 | void fir::IterWhileOp::build(mlir::OpBuilder &builder, |
1737 | mlir::OperationState &result, mlir::Value lb, |
1738 | mlir::Value ub, mlir::Value step, |
1739 | mlir::Value iterate, bool finalCountValue, |
1740 | mlir::ValueRange iterArgs, |
1741 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
1742 | result.addOperands({lb, ub, step, iterate}); |
1743 | if (finalCountValue) { |
1744 | result.addTypes(builder.getIndexType()); |
1745 | result.addAttribute(getFinalValueAttrNameStr(), builder.getUnitAttr()); |
1746 | } |
1747 | result.addTypes(iterate.getType()); |
1748 | result.addOperands(iterArgs); |
1749 | for (auto v : iterArgs) |
1750 | result.addTypes(v.getType()); |
1751 | mlir::Region *bodyRegion = result.addRegion(); |
1752 | bodyRegion->push_back(new mlir::Block{}); |
1753 | bodyRegion->front().addArgument(builder.getIndexType(), result.location); |
1754 | bodyRegion->front().addArgument(iterate.getType(), result.location); |
1755 | bodyRegion->front().addArguments( |
1756 | iterArgs.getTypes(), |
1757 | llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); |
1758 | result.addAttributes(attributes); |
1759 | } |
1760 | |
/// Parses fir.iterate_while. The syntax read below is:
///
///   `(` %iv `=` %lb `to` %ub `step` %step `)`
///   `and` `(` %iterateVar `=` %iterateInput `)`
///   ( `iter_args` `(` %arg `=` %init (`,` ...)* `)` `->` type-list
///   | `->` `(` `index` `,` `i1` `)` )?
///   attr-dict-with-keyword? region
///
/// Bounds and step are resolved as `index`, and the iterate input as `i1`.
/// The parsed type list is taken as the full result list when it has two more
/// entries than the iter_args operands, or when it is `-> (index, i1)`. In
/// that case its leading entry is the final count and the final-value unit
/// attribute is set. Otherwise the results are `i1` followed by the parsed
/// types. Region arguments are the induction variable, the iterate variable,
/// and then the iter_args.
mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
                                          mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  mlir::OpAsmParser::Argument inductionVariable, iterateVar;
  mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput;
  if (parser.parseLParen() || parser.parseArgument(inductionVariable) ||
      parser.parseEqual())
    return mlir::failure();

  // Parse loop bounds.
  auto indexType = builder.getIndexType();
  auto i1Type = builder.getIntegerType(1);
  if (parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.parseRParen() ||
      parser.resolveOperand(step, indexType, result.operands) ||
      parser.parseKeyword("and") || parser.parseLParen() ||
      parser.parseArgument(iterateVar) || parser.parseEqual() ||
      parser.parseOperand(iterateInput) || parser.parseRParen() ||
      parser.resolveOperand(iterateInput, i1Type, result.operands))
    return mlir::failure();

  // Parse the initial iteration arguments.
  // prependCount: the result list explicitly starts with the final-count
  // (index) result, so the final-value attribute must be set.
  auto prependCount = false;

  // Region arguments: the induction variable, then the iterate variable.
  llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
  regionArgs.push_back(inductionVariable);
  regionArgs.push_back(iterateVar);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
    llvm::SmallVector<mlir::Type> regionTypes;
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(regionTypes))
      return mlir::failure();
    // Two extra types means the list spells out the leading (index, i1)
    // results explicitly.
    if (regionTypes.size() == operands.size() + 2)
      prependCount = true;
    llvm::ArrayRef<mlir::Type> resTypes = regionTypes;
    resTypes = prependCount ? resTypes.drop_front(2) : resTypes;
    // Resolve input operands.
    for (auto operandType : llvm::zip(operands, resTypes))
      if (parser.resolveOperand(std::get<0>(operandType),
                                std::get<1>(operandType), result.operands))
        return mlir::failure();
    if (prependCount) {
      result.addTypes(regionTypes);
    } else {
      // The implicit leading i1 result is not spelled in the type list.
      result.addTypes(i1Type);
      result.addTypes(resTypes);
    }
  } else if (succeeded(parser.parseOptionalArrow())) {
    llvm::SmallVector<mlir::Type> typeList;
    if (parser.parseLParen() || parser.parseTypeList(typeList) ||
        parser.parseRParen())
      return mlir::failure();
    // Type list must be "(index, i1)".
    // NOTE(review): this rejects the input without emitting a diagnostic.
    if (typeList.size() != 2 || !typeList[0].isa<mlir::IndexType>() ||
        !typeList[1].isSignlessInteger(1))
      return mlir::failure();
    result.addTypes(typeList);
    prependCount = true;
  } else {
    // No iter_args and no arrow: the only result is the i1 value.
    result.addTypes(i1Type);
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return mlir::failure();

  llvm::SmallVector<mlir::Type> argTypes;
  // Induction variable (hidden). With prependCount, result.types already
  // begins with `index`, which then types the induction variable below.
  // Otherwise the index type is prepended explicitly.
  if (prependCount)
    result.addAttribute(IterWhileOp::getFinalValueAttrNameStr(),
                        builder.getUnitAttr());
  else
    argTypes.push_back(indexType);
  // Loop carried variables (including iterate)
  argTypes.append(result.types.begin(), result.types.end());
  // Parse the body region.
  auto *body = result.addRegion();
  if (regionArgs.size() != argTypes.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");

  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = argTypes[i];

  if (parser.parseRegion(*body, regionArgs))
    return mlir::failure();

  fir::IterWhileOp::ensureTerminator(*body, builder, result.location);
  return mlir::success();
}
1859 | |
1860 | mlir::LogicalResult fir::IterWhileOp::verify() { |
1861 | // Check that the body defines as single block argument for the induction |
1862 | // variable. |
1863 | auto *body = getBody(); |
1864 | if (!body->getArgument(1).getType().isInteger(1)) |
1865 | return emitOpError( |
1866 | "expected body second argument to be an index argument for " |
1867 | "the induction variable" ); |
1868 | if (!body->getArgument(0).getType().isIndex()) |
1869 | return emitOpError( |
1870 | "expected body first argument to be an index argument for " |
1871 | "the induction variable" ); |
1872 | |
1873 | auto opNumResults = getNumResults(); |
1874 | if (getFinalValue()) { |
1875 | // Result type must be "(index, i1, ...)". |
1876 | if (!getResult(0).getType().isa<mlir::IndexType>()) |
1877 | return emitOpError("result #0 expected to be index" ); |
1878 | if (!getResult(1).getType().isSignlessInteger(1)) |
1879 | return emitOpError("result #1 expected to be i1" ); |
1880 | opNumResults--; |
1881 | } else { |
1882 | // iterate_while always returns the early exit induction value. |
1883 | // Result type must be "(i1, ...)" |
1884 | if (!getResult(0).getType().isSignlessInteger(1)) |
1885 | return emitOpError("result #0 expected to be i1" ); |
1886 | } |
1887 | if (opNumResults == 0) |
1888 | return mlir::failure(); |
1889 | if (getNumIterOperands() != opNumResults) |
1890 | return emitOpError( |
1891 | "mismatch in number of loop-carried values and defined values" ); |
1892 | if (getNumRegionIterArgs() != opNumResults) |
1893 | return emitOpError( |
1894 | "mismatch in number of basic block args and defined values" ); |
1895 | auto iterOperands = getIterOperands(); |
1896 | auto iterArgs = getRegionIterArgs(); |
1897 | auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); |
1898 | unsigned i = 0u; |
1899 | for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { |
1900 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
1901 | return emitOpError() << "types mismatch between " << i |
1902 | << "th iter operand and defined value" ; |
1903 | if (std::get<1>(e).getType() != std::get<2>(e).getType()) |
1904 | return emitOpError() << "types mismatch between " << i |
1905 | << "th iter region arg and defined value" ; |
1906 | |
1907 | i++; |
1908 | } |
1909 | return mlir::success(); |
1910 | } |
1911 | |
/// Prints fir.iterate_while in the form read by IterWhileOp::parse:
///
///   (%iv = %lb to %ub step %step) and (%iterateVar = %iterateInit)
///   [ iter_args(%arg = %init, ...) -> (types) | -> (result types) ]
///   attr-dict-with-keyword region
///
/// The final-value attribute is not printed. When set, the result type list
/// is printed including its leading count, and the parser recovers the
/// attribute from that.
void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) {
  p << " (" << getInductionVar() << " = " << getLowerBound() << " to "
    << getUpperBound() << " step " << getStep() << ") and (";
  // The first loop-carried operand (the iterate value) must exist.
  assert(hasIterOperands());
  auto regionArgs = getRegionIterArgs();
  auto operands = getIterOperands();
  p << regionArgs.front() << " = " << *operands.begin() << ")";
  if (regionArgs.size() > 1) {
    // The remaining region args are the iter_args, shown with their inits.
    p << " iter_args(";
    llvm::interleaveComma(
        llvm::zip(regionArgs.drop_front(), operands.drop_front()), p,
        [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); });
    p << ") -> (";
    // With a final value, print every result type (index, i1, ...).
    // Otherwise skip the implicit leading i1 result.
    llvm::interleaveComma(
        llvm::drop_begin(getResultTypes(), getFinalValue() ? 0 : 1), p);
    p << ")";
  } else if (getFinalValue()) {
    p << " -> (" << getResultTypes() << ')';
  }
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
                                     {getFinalValueAttrNameStr()});
  p << ' ';
  // Entry block arguments were already printed inline above.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}
1937 | |
1938 | llvm::SmallVector<mlir::Region *> fir::IterWhileOp::getLoopRegions() { |
1939 | return {&getRegion()}; |
1940 | } |
1941 | |
1942 | mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { |
1943 | for (auto i : llvm::enumerate(getInitArgs())) |
1944 | if (iterArg == i.value()) |
1945 | return getRegion().front().getArgument(i.index() + 1); |
1946 | return {}; |
1947 | } |
1948 | |
1949 | void fir::IterWhileOp::resultToSourceOps( |
1950 | llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { |
1951 | auto oper = getFinalValue() ? resultNum + 1 : resultNum; |
1952 | auto *term = getRegion().front().getTerminator(); |
1953 | if (oper < term->getNumOperands()) |
1954 | results.push_back(term->getOperand(oper)); |
1955 | } |
1956 | |
1957 | mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { |
1958 | if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) |
1959 | return getInitArgs()[blockArgNum - 1]; |
1960 | return {}; |
1961 | } |
1962 | |
1963 | std::optional<llvm::MutableArrayRef<mlir::OpOperand>> |
1964 | fir::IterWhileOp::getYieldedValuesMutable() { |
1965 | auto *term = getRegion().front().getTerminator(); |
1966 | return getFinalValue() ? term->getOpOperands().drop_front() |
1967 | : term->getOpOperands(); |
1968 | } |
1969 | |
1970 | //===----------------------------------------------------------------------===// |
1971 | // LenParamIndexOp |
1972 | //===----------------------------------------------------------------------===// |
1973 | |
1974 | mlir::ParseResult fir::LenParamIndexOp::parse(mlir::OpAsmParser &parser, |
1975 | mlir::OperationState &result) { |
1976 | return parseFieldLikeOp<fir::LenType>(parser, result); |
1977 | } |
1978 | |
1979 | void fir::LenParamIndexOp::print(mlir::OpAsmPrinter &p) { |
1980 | printFieldLikeOp(p, *this); |
1981 | } |
1982 | |
1983 | void fir::LenParamIndexOp::build(mlir::OpBuilder &builder, |
1984 | mlir::OperationState &result, |
1985 | llvm::StringRef fieldName, mlir::Type recTy, |
1986 | mlir::ValueRange operands) { |
1987 | result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); |
1988 | result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); |
1989 | result.addOperands(operands); |
1990 | } |
1991 | |
1992 | llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() { |
1993 | llvm::SmallVector<mlir::Attribute> attrs; |
1994 | attrs.push_back(getFieldIdAttr()); |
1995 | attrs.push_back(getOnTypeAttr()); |
1996 | return attrs; |
1997 | } |
1998 | |
1999 | //===----------------------------------------------------------------------===// |
2000 | // LoadOp |
2001 | //===----------------------------------------------------------------------===// |
2002 | |
2003 | void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
2004 | mlir::Value refVal) { |
2005 | if (!refVal) { |
2006 | mlir::emitError(result.location, "LoadOp has null argument" ); |
2007 | return; |
2008 | } |
2009 | auto eleTy = fir::dyn_cast_ptrEleTy(refVal.getType()); |
2010 | if (!eleTy) { |
2011 | mlir::emitError(result.location, "not a memory reference type" ); |
2012 | return; |
2013 | } |
2014 | build(builder, result, eleTy, refVal); |
2015 | } |
2016 | |
2017 | void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
2018 | mlir::Type resTy, mlir::Value refVal) { |
2019 | |
2020 | if (!refVal) { |
2021 | mlir::emitError(result.location, "LoadOp has null argument" ); |
2022 | return; |
2023 | } |
2024 | result.addOperands(refVal); |
2025 | result.addTypes(resTy); |
2026 | } |
2027 | |
2028 | mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { |
2029 | if ((ele = fir::dyn_cast_ptrEleTy(ref))) |
2030 | return mlir::success(); |
2031 | return mlir::failure(); |
2032 | } |
2033 | |
2034 | mlir::ParseResult fir::LoadOp::parse(mlir::OpAsmParser &parser, |
2035 | mlir::OperationState &result) { |
2036 | mlir::Type type; |
2037 | mlir::OpAsmParser::UnresolvedOperand oper; |
2038 | if (parser.parseOperand(oper) || |
2039 | parser.parseOptionalAttrDict(result.attributes) || |
2040 | parser.parseColonType(type) || |
2041 | parser.resolveOperand(oper, type, result.operands)) |
2042 | return mlir::failure(); |
2043 | mlir::Type eleTy; |
2044 | if (fir::LoadOp::getElementOf(eleTy, type) || |
2045 | parser.addTypeToList(eleTy, result.types)) |
2046 | return mlir::failure(); |
2047 | return mlir::success(); |
2048 | } |
2049 | |
2050 | void fir::LoadOp::print(mlir::OpAsmPrinter &p) { |
2051 | p << ' '; |
2052 | p.printOperand(getMemref()); |
2053 | p.printOptionalAttrDict(getOperation()->getAttrs(), {}); |
2054 | p << " : " << getMemref().getType(); |
2055 | } |
2056 | |
2057 | //===----------------------------------------------------------------------===// |
2058 | // DoLoopOp |
2059 | //===----------------------------------------------------------------------===// |
2060 | |
2061 | void fir::DoLoopOp::build(mlir::OpBuilder &builder, |
2062 | mlir::OperationState &result, mlir::Value lb, |
2063 | mlir::Value ub, mlir::Value step, bool unordered, |
2064 | bool finalCountValue, mlir::ValueRange iterArgs, |
2065 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
2066 | result.addOperands({lb, ub, step}); |
2067 | result.addOperands(iterArgs); |
2068 | if (finalCountValue) { |
2069 | result.addTypes(builder.getIndexType()); |
2070 | result.addAttribute(getFinalValueAttrName(result.name), |
2071 | builder.getUnitAttr()); |
2072 | } |
2073 | for (auto v : iterArgs) |
2074 | result.addTypes(v.getType()); |
2075 | mlir::Region *bodyRegion = result.addRegion(); |
2076 | bodyRegion->push_back(new mlir::Block{}); |
2077 | if (iterArgs.empty() && !finalCountValue) |
2078 | fir::DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); |
2079 | bodyRegion->front().addArgument(builder.getIndexType(), result.location); |
2080 | bodyRegion->front().addArguments( |
2081 | iterArgs.getTypes(), |
2082 | llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); |
2083 | if (unordered) |
2084 | result.addAttribute(getUnorderedAttrName(result.name), |
2085 | builder.getUnitAttr()); |
2086 | result.addAttributes(attributes); |
2087 | } |
2088 | |
/// Parses fir.do_loop. The syntax read below is:
///
///   %iv `=` %lb `to` %ub `step` %step `unordered`?
///   ( `iter_args` `(` %arg `=` %init (`,` ...)* `)` `->` type-list
///   | `->` `index` )?
///   attr-dict-with-keyword? region
///
/// Bounds and step are resolved as `index`. The leading result type is taken
/// to be the final induction value, and the final-value unit attribute is set,
/// in two cases: the parsed result list has one more entry than the iter_args
/// operands, or the list is `-> index`. Region arguments are the induction
/// variable followed by the iter_args.
mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  mlir::OpAsmParser::Argument inductionVariable;
  mlir::OpAsmParser::UnresolvedOperand lb, ub, step;
  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
    return mlir::failure();

  // Parse loop bounds.
  auto indexType = builder.getIndexType();
  if (parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.resolveOperand(step, indexType, result.operands))
    return mlir::failure();

  // NOTE(review): the attribute name is spelled as a literal here; presumably
  // it matches getUnorderedAttrName() used by the builder. Confirm.
  if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
    result.addAttribute("unordered", builder.getUnitAttr());

  // Parse the optional initial iteration arguments.
  llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
  llvm::SmallVector<mlir::Type> argTypes;
  // prependCount: the result list starts with the final induction value.
  bool prependCount = false;
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return mlir::failure();
    // One extra result type means a leading final-count result was spelled.
    if (result.types.size() == operands.size() + 1)
      prependCount = true;
    // Resolve input operands.
    llvm::ArrayRef<mlir::Type> resTypes = result.types;
    for (auto operand_type :
         llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
      if (parser.resolveOperand(std::get<0>(operand_type),
                                std::get<1>(operand_type), result.operands))
        return mlir::failure();
  } else if (succeeded(parser.parseOptionalArrow())) {
    // No iter_args: the only permitted result is the final count, `index`.
    if (parser.parseKeyword("index"))
      return mlir::failure();
    result.types.push_back(indexType);
    prependCount = true;
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return mlir::failure();

  // Induction variable. With prependCount, result.types already begins with
  // the index type, which then types the induction variable below. Otherwise
  // the index type is prepended explicitly.
  if (prependCount)
    result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name),
                        builder.getUnitAttr());
  else
    argTypes.push_back(indexType);
  // Loop carried variables
  argTypes.append(result.types.begin(), result.types.end());
  // Parse the body region.
  auto *body = result.addRegion();
  if (regionArgs.size() != argTypes.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = argTypes[i];

  if (parser.parseRegion(*body, regionArgs))
    return mlir::failure();

  DoLoopOp::ensureTerminator(*body, builder, result.location);

  return mlir::success();
}
2166 | |
2167 | fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { |
2168 | auto ivArg = mlir::dyn_cast<mlir::BlockArgument>(val); |
2169 | if (!ivArg) |
2170 | return {}; |
2171 | assert(ivArg.getOwner() && "unlinked block argument" ); |
2172 | auto *containingInst = ivArg.getOwner()->getParentOp(); |
2173 | return mlir::dyn_cast_or_null<fir::DoLoopOp>(containingInst); |
2174 | } |
2175 | |
2176 | // Lifted from loop.loop |
2177 | mlir::LogicalResult fir::DoLoopOp::verify() { |
2178 | // Check that the body defines as single block argument for the induction |
2179 | // variable. |
2180 | auto *body = getBody(); |
2181 | if (!body->getArgument(0).getType().isIndex()) |
2182 | return emitOpError( |
2183 | "expected body first argument to be an index argument for " |
2184 | "the induction variable" ); |
2185 | |
2186 | auto opNumResults = getNumResults(); |
2187 | if (opNumResults == 0) |
2188 | return mlir::success(); |
2189 | |
2190 | if (getFinalValue()) { |
2191 | if (getUnordered()) |
2192 | return emitOpError("unordered loop has no final value" ); |
2193 | opNumResults--; |
2194 | } |
2195 | if (getNumIterOperands() != opNumResults) |
2196 | return emitOpError( |
2197 | "mismatch in number of loop-carried values and defined values" ); |
2198 | if (getNumRegionIterArgs() != opNumResults) |
2199 | return emitOpError( |
2200 | "mismatch in number of basic block args and defined values" ); |
2201 | auto iterOperands = getIterOperands(); |
2202 | auto iterArgs = getRegionIterArgs(); |
2203 | auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); |
2204 | unsigned i = 0u; |
2205 | for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { |
2206 | if (std::get<0>(e).getType() != std::get<2>(e).getType()) |
2207 | return emitOpError() << "types mismatch between " << i |
2208 | << "th iter operand and defined value" ; |
2209 | if (std::get<1>(e).getType() != std::get<2>(e).getType()) |
2210 | return emitOpError() << "types mismatch between " << i |
2211 | << "th iter region arg and defined value" ; |
2212 | |
2213 | i++; |
2214 | } |
2215 | return mlir::success(); |
2216 | } |
2217 | |
2218 | void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { |
2219 | bool printBlockTerminators = false; |
2220 | p << ' ' << getInductionVar() << " = " << getLowerBound() << " to " |
2221 | << getUpperBound() << " step " << getStep(); |
2222 | if (getUnordered()) |
2223 | p << " unordered" ; |
2224 | if (hasIterOperands()) { |
2225 | p << " iter_args(" ; |
2226 | auto regionArgs = getRegionIterArgs(); |
2227 | auto operands = getIterOperands(); |
2228 | llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { |
2229 | p << std::get<0>(it) << " = " << std::get<1>(it); |
2230 | }); |
2231 | p << ") -> (" << getResultTypes() << ')'; |
2232 | printBlockTerminators = true; |
2233 | } else if (getFinalValue()) { |
2234 | p << " -> " << getResultTypes(); |
2235 | printBlockTerminators = true; |
2236 | } |
2237 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
2238 | {"unordered" , "finalValue" }); |
2239 | p << ' '; |
2240 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
2241 | printBlockTerminators); |
2242 | } |
2243 | |
2244 | llvm::SmallVector<mlir::Region *> fir::DoLoopOp::getLoopRegions() { |
2245 | return {&getRegion()}; |
2246 | } |
2247 | |
2248 | /// Translate a value passed as an iter_arg to the corresponding block |
2249 | /// argument in the body of the loop. |
2250 | mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { |
2251 | for (auto i : llvm::enumerate(getInitArgs())) |
2252 | if (iterArg == i.value()) |
2253 | return getRegion().front().getArgument(i.index() + 1); |
2254 | return {}; |
2255 | } |
2256 | |
2257 | /// Translate the result vector (by index number) to the corresponding value |
2258 | /// to the `fir.result` Op. |
2259 | void fir::DoLoopOp::resultToSourceOps( |
2260 | llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { |
2261 | auto oper = getFinalValue() ? resultNum + 1 : resultNum; |
2262 | auto *term = getRegion().front().getTerminator(); |
2263 | if (oper < term->getNumOperands()) |
2264 | results.push_back(term->getOperand(oper)); |
2265 | } |
2266 | |
2267 | /// Translate the block argument (by index number) to the corresponding value |
2268 | /// passed as an iter_arg to the parent DoLoopOp. |
2269 | mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { |
2270 | if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) |
2271 | return getInitArgs()[blockArgNum - 1]; |
2272 | return {}; |
2273 | } |
2274 | |
2275 | std::optional<llvm::MutableArrayRef<mlir::OpOperand>> |
2276 | fir::DoLoopOp::getYieldedValuesMutable() { |
2277 | auto *term = getRegion().front().getTerminator(); |
2278 | return getFinalValue() ? term->getOpOperands().drop_front() |
2279 | : term->getOpOperands(); |
2280 | } |
2281 | |
2282 | //===----------------------------------------------------------------------===// |
2283 | // DTEntryOp |
2284 | //===----------------------------------------------------------------------===// |
2285 | |
2286 | mlir::ParseResult fir::DTEntryOp::parse(mlir::OpAsmParser &parser, |
2287 | mlir::OperationState &result) { |
2288 | llvm::StringRef methodName; |
2289 | // allow `methodName` or `"methodName"` |
2290 | if (failed(parser.parseOptionalKeyword(&methodName))) { |
2291 | mlir::StringAttr methodAttr; |
2292 | if (parser.parseAttribute(methodAttr, getMethodAttrName(result.name), |
2293 | result.attributes)) |
2294 | return mlir::failure(); |
2295 | } else { |
2296 | result.addAttribute(getMethodAttrName(result.name), |
2297 | parser.getBuilder().getStringAttr(methodName)); |
2298 | } |
2299 | mlir::SymbolRefAttr calleeAttr; |
2300 | if (parser.parseComma() || |
2301 | parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(), |
2302 | result.attributes)) |
2303 | return mlir::failure(); |
2304 | return mlir::success(); |
2305 | } |
2306 | |
2307 | void fir::DTEntryOp::print(mlir::OpAsmPrinter &p) { |
2308 | p << ' ' << getMethodAttr() << ", " << getProcAttr(); |
2309 | } |
2310 | |
2311 | //===----------------------------------------------------------------------===// |
2312 | // ReboxOp |
2313 | //===----------------------------------------------------------------------===// |
2314 | |
2315 | /// Get the scalar type related to a fir.box type. |
2316 | /// Example: return f32 for !fir.box<!fir.heap<!fir.array<?x?xf32>>. |
2317 | static mlir::Type getBoxScalarEleTy(mlir::Type boxTy) { |
2318 | auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); |
2319 | if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) |
2320 | return seqTy.getEleTy(); |
2321 | return eleTy; |
2322 | } |
2323 | |
2324 | /// Test if \p t1 and \p t2 are compatible character types (if they can |
2325 | /// represent the same type at runtime). |
2326 | static bool areCompatibleCharacterTypes(mlir::Type t1, mlir::Type t2) { |
2327 | auto c1 = t1.dyn_cast<fir::CharacterType>(); |
2328 | auto c2 = t2.dyn_cast<fir::CharacterType>(); |
2329 | if (!c1 || !c2) |
2330 | return false; |
2331 | if (c1.hasDynamicLen() || c2.hasDynamicLen()) |
2332 | return true; |
2333 | return c1.getLen() == c2.getLen(); |
2334 | } |
2335 | |
2336 | mlir::LogicalResult fir::ReboxOp::verify() { |
2337 | auto inputBoxTy = getBox().getType(); |
2338 | if (fir::isa_unknown_size_box(inputBoxTy)) |
2339 | return emitOpError("box operand must not have unknown rank or type" ); |
2340 | auto outBoxTy = getType(); |
2341 | if (fir::isa_unknown_size_box(outBoxTy)) |
2342 | return emitOpError("result type must not have unknown rank or type" ); |
2343 | auto inputRank = fir::getBoxRank(inputBoxTy); |
2344 | auto inputEleTy = getBoxScalarEleTy(inputBoxTy); |
2345 | auto outRank = fir::getBoxRank(outBoxTy); |
2346 | auto outEleTy = getBoxScalarEleTy(outBoxTy); |
2347 | |
2348 | if (auto sliceVal = getSlice()) { |
2349 | // Slicing case |
2350 | if (sliceVal.getType().cast<fir::SliceType>().getRank() != inputRank) |
2351 | return emitOpError("slice operand rank must match box operand rank" ); |
2352 | if (auto shapeVal = getShape()) { |
2353 | if (auto shiftTy = shapeVal.getType().dyn_cast<fir::ShiftType>()) { |
2354 | if (shiftTy.getRank() != inputRank) |
2355 | return emitOpError("shape operand and input box ranks must match " |
2356 | "when there is a slice" ); |
2357 | } else { |
2358 | return emitOpError("shape operand must absent or be a fir.shift " |
2359 | "when there is a slice" ); |
2360 | } |
2361 | } |
2362 | if (auto sliceOp = sliceVal.getDefiningOp()) { |
2363 | auto slicedRank = mlir::cast<fir::SliceOp>(sliceOp).getOutRank(); |
2364 | if (slicedRank != outRank) |
2365 | return emitOpError("result type rank and rank after applying slice " |
2366 | "operand must match" ); |
2367 | } |
2368 | } else { |
2369 | // Reshaping case |
2370 | unsigned shapeRank = inputRank; |
2371 | if (auto shapeVal = getShape()) { |
2372 | auto ty = shapeVal.getType(); |
2373 | if (auto shapeTy = ty.dyn_cast<fir::ShapeType>()) { |
2374 | shapeRank = shapeTy.getRank(); |
2375 | } else if (auto shapeShiftTy = ty.dyn_cast<fir::ShapeShiftType>()) { |
2376 | shapeRank = shapeShiftTy.getRank(); |
2377 | } else { |
2378 | auto shiftTy = ty.cast<fir::ShiftType>(); |
2379 | shapeRank = shiftTy.getRank(); |
2380 | if (shapeRank != inputRank) |
2381 | return emitOpError("shape operand and input box ranks must match " |
2382 | "when the shape is a fir.shift" ); |
2383 | } |
2384 | } |
2385 | if (shapeRank != outRank) |
2386 | return emitOpError("result type and shape operand ranks must match" ); |
2387 | } |
2388 | |
2389 | if (inputEleTy != outEleTy) { |
2390 | // TODO: check that outBoxTy is a parent type of inputBoxTy for derived |
2391 | // types. |
2392 | // Character input and output types with constant length may be different if |
2393 | // there is a substring in the slice, otherwise, they must match. If any of |
2394 | // the types is a character with dynamic length, the other type can be any |
2395 | // character type. |
2396 | const bool typeCanMismatch = |
2397 | inputEleTy.isa<fir::RecordType>() || outEleTy.isa<mlir::NoneType>() || |
2398 | (inputEleTy.isa<mlir::NoneType>() && outEleTy.isa<fir::RecordType>()) || |
2399 | (getSlice() && inputEleTy.isa<fir::CharacterType>()) || |
2400 | (getSlice() && fir::isa_complex(inputEleTy) && |
2401 | outEleTy.isa<mlir::FloatType>()) || |
2402 | areCompatibleCharacterTypes(inputEleTy, outEleTy); |
2403 | if (!typeCanMismatch) |
2404 | return emitOpError( |
2405 | "op input and output element types must match for intrinsic types" ); |
2406 | } |
2407 | return mlir::success(); |
2408 | } |
2409 | |
2410 | //===----------------------------------------------------------------------===// |
2411 | // ResultOp |
2412 | //===----------------------------------------------------------------------===// |
2413 | |
2414 | mlir::LogicalResult fir::ResultOp::verify() { |
2415 | auto *parentOp = (*this)->getParentOp(); |
2416 | auto results = parentOp->getResults(); |
2417 | auto operands = (*this)->getOperands(); |
2418 | |
2419 | if (parentOp->getNumResults() != getNumOperands()) |
2420 | return emitOpError() << "parent of result must have same arity" ; |
2421 | for (auto e : llvm::zip(results, operands)) |
2422 | if (std::get<0>(e).getType() != std::get<1>(e).getType()) |
2423 | return emitOpError() << "types mismatch between result op and its parent" ; |
2424 | return mlir::success(); |
2425 | } |
2426 | |
2427 | //===----------------------------------------------------------------------===// |
2428 | // SaveResultOp |
2429 | //===----------------------------------------------------------------------===// |
2430 | |
2431 | mlir::LogicalResult fir::SaveResultOp::verify() { |
2432 | auto resultType = getValue().getType(); |
2433 | if (resultType != fir::dyn_cast_ptrEleTy(getMemref().getType())) |
2434 | return emitOpError("value type must match memory reference type" ); |
2435 | if (fir::isa_unknown_size_box(resultType)) |
2436 | return emitOpError("cannot save !fir.box of unknown rank or type" ); |
2437 | |
2438 | if (resultType.isa<fir::BoxType>()) { |
2439 | if (getShape() || !getTypeparams().empty()) |
2440 | return emitOpError( |
2441 | "must not have shape or length operands if the value is a fir.box" ); |
2442 | return mlir::success(); |
2443 | } |
2444 | |
2445 | // fir.record or fir.array case. |
2446 | unsigned shapeTyRank = 0; |
2447 | if (auto shapeVal = getShape()) { |
2448 | auto shapeTy = shapeVal.getType(); |
2449 | if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) |
2450 | shapeTyRank = s.getRank(); |
2451 | else |
2452 | shapeTyRank = shapeTy.cast<fir::ShapeShiftType>().getRank(); |
2453 | } |
2454 | |
2455 | auto eleTy = resultType; |
2456 | if (auto seqTy = resultType.dyn_cast<fir::SequenceType>()) { |
2457 | if (seqTy.getDimension() != shapeTyRank) |
2458 | emitOpError("shape operand must be provided and have the value rank " |
2459 | "when the value is a fir.array" ); |
2460 | eleTy = seqTy.getEleTy(); |
2461 | } else { |
2462 | if (shapeTyRank != 0) |
2463 | emitOpError( |
2464 | "shape operand should only be provided if the value is a fir.array" ); |
2465 | } |
2466 | |
2467 | if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) { |
2468 | if (recTy.getNumLenParams() != getTypeparams().size()) |
2469 | emitOpError("length parameters number must match with the value type " |
2470 | "length parameters" ); |
2471 | } else if (auto charTy = eleTy.dyn_cast<fir::CharacterType>()) { |
2472 | if (getTypeparams().size() > 1) |
2473 | emitOpError("no more than one length parameter must be provided for " |
2474 | "character value" ); |
2475 | } else { |
2476 | if (!getTypeparams().empty()) |
2477 | emitOpError("length parameters must not be provided for this value type" ); |
2478 | } |
2479 | |
2480 | return mlir::success(); |
2481 | } |
2482 | |
2483 | //===----------------------------------------------------------------------===// |
2484 | // IntegralSwitchTerminator |
2485 | //===----------------------------------------------------------------------===// |
2486 | static constexpr llvm::StringRef getCompareOffsetAttr() { |
2487 | return "compare_operand_offsets" ; |
2488 | } |
2489 | |
2490 | static constexpr llvm::StringRef getTargetOffsetAttr() { |
2491 | return "target_operand_offsets" ; |
2492 | } |
2493 | |
2494 | template <typename OpT> |
2495 | static mlir::LogicalResult verifyIntegralSwitchTerminator(OpT op) { |
2496 | if (!op.getSelector() |
2497 | .getType() |
2498 | .template isa<mlir::IntegerType, mlir::IndexType, |
2499 | fir::IntegerType>()) |
2500 | return op.emitOpError("must be an integer" ); |
2501 | auto cases = |
2502 | op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); |
2503 | auto count = op.getNumDest(); |
2504 | if (count == 0) |
2505 | return op.emitOpError("must have at least one successor" ); |
2506 | if (op.getNumConditions() != count) |
2507 | return op.emitOpError("number of cases and targets don't match" ); |
2508 | if (op.targetOffsetSize() != count) |
2509 | return op.emitOpError("incorrect number of successor operand groups" ); |
2510 | for (decltype(count) i = 0; i != count; ++i) { |
2511 | if (!cases[i].template isa<mlir::IntegerAttr, mlir::UnitAttr>()) |
2512 | return op.emitOpError("invalid case alternative" ); |
2513 | } |
2514 | return mlir::success(); |
2515 | } |
2516 | |
2517 | static mlir::ParseResult parseIntegralSwitchTerminator( |
2518 | mlir::OpAsmParser &parser, mlir::OperationState &result, |
2519 | llvm::StringRef casesAttr, llvm::StringRef operandSegmentAttr) { |
2520 | mlir::OpAsmParser::UnresolvedOperand selector; |
2521 | mlir::Type type; |
2522 | if (fir::parseSelector(parser, result, selector, type)) |
2523 | return mlir::failure(); |
2524 | |
2525 | llvm::SmallVector<mlir::Attribute> ivalues; |
2526 | llvm::SmallVector<mlir::Block *> dests; |
2527 | llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; |
2528 | while (true) { |
2529 | mlir::Attribute ivalue; // Integer or Unit |
2530 | mlir::Block *dest; |
2531 | llvm::SmallVector<mlir::Value> destArg; |
2532 | mlir::NamedAttrList temp; |
2533 | if (parser.parseAttribute(result&: ivalue, attrName: "i" , attrs&: temp) || parser.parseComma() || |
2534 | parser.parseSuccessorAndUseList(dest, operands&: destArg)) |
2535 | return mlir::failure(); |
2536 | ivalues.push_back(Elt: ivalue); |
2537 | dests.push_back(Elt: dest); |
2538 | destArgs.push_back(Elt: destArg); |
2539 | if (!parser.parseOptionalRSquare()) |
2540 | break; |
2541 | if (parser.parseComma()) |
2542 | return mlir::failure(); |
2543 | } |
2544 | auto &bld = parser.getBuilder(); |
2545 | result.addAttribute(casesAttr, bld.getArrayAttr(ivalues)); |
2546 | llvm::SmallVector<int32_t> argOffs; |
2547 | int32_t sumArgs = 0; |
2548 | const auto count = dests.size(); |
2549 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
2550 | result.addSuccessors(successor: dests[i]); |
2551 | result.addOperands(newOperands: destArgs[i]); |
2552 | auto argSize = destArgs[i].size(); |
2553 | argOffs.push_back(Elt: argSize); |
2554 | sumArgs += argSize; |
2555 | } |
2556 | result.addAttribute(operandSegmentAttr, |
2557 | bld.getDenseI32ArrayAttr({1, 0, sumArgs})); |
2558 | result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); |
2559 | return mlir::success(); |
2560 | } |
2561 | |
2562 | template <typename OpT> |
2563 | static void printIntegralSwitchTerminator(OpT op, mlir::OpAsmPrinter &p) { |
2564 | p << ' '; |
2565 | p.printOperand(op.getSelector()); |
2566 | p << " : " << op.getSelector().getType() << " [" ; |
2567 | auto cases = |
2568 | op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); |
2569 | auto count = op.getNumConditions(); |
2570 | for (decltype(count) i = 0; i != count; ++i) { |
2571 | if (i) |
2572 | p << ", " ; |
2573 | auto &attr = cases[i]; |
2574 | if (auto intAttr = attr.template dyn_cast_or_null<mlir::IntegerAttr>()) |
2575 | p << intAttr.getValue(); |
2576 | else |
2577 | p.printAttribute(attr); |
2578 | p << ", " ; |
2579 | op.printSuccessorAtIndex(p, i); |
2580 | } |
2581 | p << ']'; |
2582 | p.printOptionalAttrDict( |
2583 | attrs: op->getAttrs(), elidedAttrs: {op.getCasesAttr(), getCompareOffsetAttr(), |
2584 | getTargetOffsetAttr(), op.getOperandSegmentSizeAttr()}); |
2585 | } |
2586 | |
2587 | //===----------------------------------------------------------------------===// |
2588 | // SelectOp |
2589 | //===----------------------------------------------------------------------===// |
2590 | |
2591 | mlir::LogicalResult fir::SelectOp::verify() { |
2592 | return verifyIntegralSwitchTerminator(*this); |
2593 | } |
2594 | |
2595 | mlir::ParseResult fir::SelectOp::parse(mlir::OpAsmParser &parser, |
2596 | mlir::OperationState &result) { |
2597 | return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), |
2598 | getOperandSegmentSizeAttr()); |
2599 | } |
2600 | |
2601 | void fir::SelectOp::print(mlir::OpAsmPrinter &p) { |
2602 | printIntegralSwitchTerminator(*this, p); |
2603 | } |
2604 | |
2605 | template <typename A, typename... AdditionalArgs> |
2606 | static A getSubOperands(unsigned pos, A allArgs, mlir::DenseI32ArrayAttr ranges, |
2607 | AdditionalArgs &&...additionalArgs) { |
2608 | unsigned start = 0; |
2609 | for (unsigned i = 0; i < pos; ++i) |
2610 | start += ranges[i]; |
2611 | return allArgs.slice(start, ranges[pos], |
2612 | std::forward<AdditionalArgs>(additionalArgs)...); |
2613 | } |
2614 | |
2615 | static mlir::MutableOperandRange |
2616 | getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, |
2617 | llvm::StringRef offsetAttr) { |
2618 | mlir::Operation *owner = operands.getOwner(); |
2619 | mlir::NamedAttribute targetOffsetAttr = |
2620 | *owner->getAttrDictionary().getNamed(offsetAttr); |
2621 | return getSubOperands( |
2622 | pos, operands, |
2623 | targetOffsetAttr.getValue().cast<mlir::DenseI32ArrayAttr>(), |
2624 | mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); |
2625 | } |
2626 | |
2627 | std::optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { |
2628 | return {}; |
2629 | } |
2630 | |
2631 | std::optional<llvm::ArrayRef<mlir::Value>> |
2632 | fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
2633 | return {}; |
2634 | } |
2635 | |
2636 | mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) { |
2637 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
2638 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
2639 | } |
2640 | |
2641 | std::optional<llvm::ArrayRef<mlir::Value>> |
2642 | fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
2643 | unsigned oper) { |
2644 | auto a = |
2645 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
2646 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2647 | getOperandSegmentSizeAttr()); |
2648 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
2649 | } |
2650 | |
2651 | std::optional<mlir::ValueRange> |
2652 | fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { |
2653 | auto a = |
2654 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
2655 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2656 | getOperandSegmentSizeAttr()); |
2657 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
2658 | } |
2659 | |
2660 | unsigned fir::SelectOp::targetOffsetSize() { |
2661 | return (*this) |
2662 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
2663 | .size(); |
2664 | } |
2665 | |
2666 | //===----------------------------------------------------------------------===// |
2667 | // SelectCaseOp |
2668 | //===----------------------------------------------------------------------===// |
2669 | |
2670 | std::optional<mlir::OperandRange> |
2671 | fir::SelectCaseOp::getCompareOperands(unsigned cond) { |
2672 | auto a = |
2673 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
2674 | return {getSubOperands(cond, getCompareArgs(), a)}; |
2675 | } |
2676 | |
2677 | std::optional<llvm::ArrayRef<mlir::Value>> |
2678 | fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, |
2679 | unsigned cond) { |
2680 | auto a = |
2681 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
2682 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2683 | getOperandSegmentSizeAttr()); |
2684 | return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; |
2685 | } |
2686 | |
2687 | std::optional<mlir::ValueRange> |
2688 | fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands, |
2689 | unsigned cond) { |
2690 | auto a = |
2691 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); |
2692 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2693 | getOperandSegmentSizeAttr()); |
2694 | return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; |
2695 | } |
2696 | |
2697 | mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { |
2698 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
2699 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
2700 | } |
2701 | |
2702 | std::optional<llvm::ArrayRef<mlir::Value>> |
2703 | fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
2704 | unsigned oper) { |
2705 | auto a = |
2706 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
2707 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2708 | getOperandSegmentSizeAttr()); |
2709 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
2710 | } |
2711 | |
2712 | std::optional<mlir::ValueRange> |
2713 | fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands, |
2714 | unsigned oper) { |
2715 | auto a = |
2716 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
2717 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2718 | getOperandSegmentSizeAttr()); |
2719 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
2720 | } |
2721 | |
// Parser for the fir.select_case op:
//
//   %selector : type [ case, compare-operands..., ^dest(args...), ... ]
//
// A unit case takes no compare operands, a fir::ClosedIntervalAttr case takes
// two, and every other case kind takes one. A case attribute rejected by
// isValidCaseAttr aborts parsing. Compare operands are resolved with the type
// returned by fir::parseSelector. The resulting operand layout is
// [selector, compare operands..., successor operands...], recorded in the
// operand segment sizes {1, offSize, toffSize}. Per-case and per-successor
// operand counts go into the compare/target offset attributes.
// NOTE(review): presumably fir::parseSelector also consumes the opening `[`
// and adds the selector operand; confirm against its definition.
mlir::ParseResult fir::SelectCaseOp::parse(mlir::OpAsmParser &parser,
                                           mlir::OperationState &result) {
  mlir::OpAsmParser::UnresolvedOperand selector;
  mlir::Type type;
  if (fir::parseSelector(parser, result, selector, type))
    return mlir::failure();

  llvm::SmallVector<mlir::Attribute> attrs;
  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> opers;
  llvm::SmallVector<mlir::Block *> dests;
  llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs;
  // Number of compare operands used by each case, and their running total.
  llvm::SmallVector<std::int32_t> argOffs;
  std::int32_t offSize = 0;
  while (true) {
    mlir::Attribute attr;
    mlir::Block *dest;
    llvm::SmallVector<mlir::Value> destArg;
    // Scratch list for the parseAttribute API; case attributes are collected
    // in `attrs` and stored as one array attribute after the loop.
    mlir::NamedAttrList temp;
    if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) ||
        parser.parseComma())
      return mlir::failure();
    attrs.push_back(attr);
    if (attr.dyn_cast_or_null<mlir::UnitAttr>()) {
      // Default case: no compare operands.
      argOffs.push_back(0);
    } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
      // Closed interval: lower and upper bound operands.
      mlir::OpAsmParser::UnresolvedOperand oper1;
      mlir::OpAsmParser::UnresolvedOperand oper2;
      if (parser.parseOperand(oper1) || parser.parseComma() ||
          parser.parseOperand(oper2) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper1);
      opers.push_back(oper2);
      argOffs.push_back(2);
      offSize += 2;
    } else {
      // All other case kinds take a single compare operand.
      mlir::OpAsmParser::UnresolvedOperand oper;
      if (parser.parseOperand(oper) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper);
      argOffs.push_back(1);
      ++offSize;
    }
    if (parser.parseSuccessorAndUseList(dest, destArg))
      return mlir::failure();
    dests.push_back(dest);
    destArgs.push_back(destArg);
    // A closing `]` ends the case list; otherwise a comma separates cases.
    if (mlir::succeeded(parser.parseOptionalRSquare()))
      break;
    if (parser.parseComma())
      return mlir::failure();
  }
  result.addAttribute(fir::SelectCaseOp::getCasesAttr(),
                      parser.getBuilder().getArrayAttr(attrs));
  // Compare operands are added before the successor operands, matching the
  // operand segment order recorded below.
  if (parser.resolveOperands(opers, type, result.operands))
    return mlir::failure();
  llvm::SmallVector<int32_t> targOffs;
  int32_t toffSize = 0;
  const auto count = dests.size();
  for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
    result.addSuccessors(dests[i]);
    result.addOperands(destArgs[i]);
    auto argSize = destArgs[i].size();
    targOffs.push_back(argSize);
    toffSize += argSize;
  }
  auto &bld = parser.getBuilder();
  result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(),
                      bld.getDenseI32ArrayAttr({1, offSize, toffSize}));
  result.addAttribute(getCompareOffsetAttr(),
                      bld.getDenseI32ArrayAttr(argOffs));
  result.addAttribute(getTargetOffsetAttr(),
                      bld.getDenseI32ArrayAttr(targOffs));
  return mlir::success();
}
2797 | |
2798 | void fir::SelectCaseOp::print(mlir::OpAsmPrinter &p) { |
2799 | p << ' '; |
2800 | p.printOperand(getSelector()); |
2801 | p << " : " << getSelector().getType() << " [" ; |
2802 | auto cases = |
2803 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
2804 | auto count = getNumConditions(); |
2805 | for (decltype(count) i = 0; i != count; ++i) { |
2806 | if (i) |
2807 | p << ", " ; |
2808 | p << cases[i] << ", " ; |
2809 | if (!cases[i].isa<mlir::UnitAttr>()) { |
2810 | auto caseArgs = *getCompareOperands(i); |
2811 | p.printOperand(*caseArgs.begin()); |
2812 | p << ", " ; |
2813 | if (cases[i].isa<fir::ClosedIntervalAttr>()) { |
2814 | p.printOperand(*(++caseArgs.begin())); |
2815 | p << ", " ; |
2816 | } |
2817 | } |
2818 | printSuccessorAtIndex(p, i); |
2819 | } |
2820 | p << ']'; |
2821 | p.printOptionalAttrDict(getOperation()->getAttrs(), |
2822 | {getCasesAttr(), getCompareOffsetAttr(), |
2823 | getTargetOffsetAttr(), getOperandSegmentSizeAttr()}); |
2824 | } |
2825 | |
2826 | unsigned fir::SelectCaseOp::compareOffsetSize() { |
2827 | return (*this) |
2828 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()) |
2829 | .size(); |
2830 | } |
2831 | |
2832 | unsigned fir::SelectCaseOp::targetOffsetSize() { |
2833 | return (*this) |
2834 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
2835 | .size(); |
2836 | } |
2837 | |
2838 | void fir::SelectCaseOp::build(mlir::OpBuilder &builder, |
2839 | mlir::OperationState &result, |
2840 | mlir::Value selector, |
2841 | llvm::ArrayRef<mlir::Attribute> compareAttrs, |
2842 | llvm::ArrayRef<mlir::ValueRange> cmpOperands, |
2843 | llvm::ArrayRef<mlir::Block *> destinations, |
2844 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
2845 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
2846 | result.addOperands(selector); |
2847 | result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); |
2848 | llvm::SmallVector<int32_t> operOffs; |
2849 | int32_t operSize = 0; |
2850 | for (auto attr : compareAttrs) { |
2851 | if (attr.isa<fir::ClosedIntervalAttr>()) { |
2852 | operOffs.push_back(2); |
2853 | operSize += 2; |
2854 | } else if (attr.isa<mlir::UnitAttr>()) { |
2855 | operOffs.push_back(0); |
2856 | } else { |
2857 | operOffs.push_back(1); |
2858 | ++operSize; |
2859 | } |
2860 | } |
2861 | for (auto ops : cmpOperands) |
2862 | result.addOperands(ops); |
2863 | result.addAttribute(getCompareOffsetAttr(), |
2864 | builder.getDenseI32ArrayAttr(operOffs)); |
2865 | const auto count = destinations.size(); |
2866 | for (auto d : destinations) |
2867 | result.addSuccessors(d); |
2868 | const auto opCount = destOperands.size(); |
2869 | llvm::SmallVector<std::int32_t> argOffs; |
2870 | std::int32_t sumArgs = 0; |
2871 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
2872 | if (i < opCount) { |
2873 | result.addOperands(destOperands[i]); |
2874 | const auto argSz = destOperands[i].size(); |
2875 | argOffs.push_back(argSz); |
2876 | sumArgs += argSz; |
2877 | } else { |
2878 | argOffs.push_back(0); |
2879 | } |
2880 | } |
2881 | result.addAttribute(getOperandSegmentSizeAttr(), |
2882 | builder.getDenseI32ArrayAttr({1, operSize, sumArgs})); |
2883 | result.addAttribute(getTargetOffsetAttr(), |
2884 | builder.getDenseI32ArrayAttr(argOffs)); |
2885 | result.addAttributes(attributes); |
2886 | } |
2887 | |
2888 | /// This builder has a slightly simplified interface in that the list of |
2889 | /// operands need not be partitioned by the builder. Instead the operands are |
2890 | /// partitioned here, before being passed to the default builder. This |
2891 | /// partitioning is unchecked, so can go awry on bad input. |
2892 | void fir::SelectCaseOp::build(mlir::OpBuilder &builder, |
2893 | mlir::OperationState &result, |
2894 | mlir::Value selector, |
2895 | llvm::ArrayRef<mlir::Attribute> compareAttrs, |
2896 | llvm::ArrayRef<mlir::Value> cmpOpList, |
2897 | llvm::ArrayRef<mlir::Block *> destinations, |
2898 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
2899 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
2900 | llvm::SmallVector<mlir::ValueRange> cmpOpers; |
2901 | auto iter = cmpOpList.begin(); |
2902 | for (auto &attr : compareAttrs) { |
2903 | if (attr.isa<fir::ClosedIntervalAttr>()) { |
2904 | cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); |
2905 | iter += 2; |
2906 | } else if (attr.isa<mlir::UnitAttr>()) { |
2907 | cmpOpers.push_back(mlir::ValueRange{}); |
2908 | } else { |
2909 | cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); |
2910 | ++iter; |
2911 | } |
2912 | } |
2913 | build(builder, result, selector, compareAttrs, cmpOpers, destinations, |
2914 | destOperands, attributes); |
2915 | } |
2916 | |
2917 | mlir::LogicalResult fir::SelectCaseOp::verify() { |
2918 | if (!getSelector() |
2919 | .getType() |
2920 | .isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType, |
2921 | fir::LogicalType, fir::CharacterType>()) |
2922 | return emitOpError("must be an integer, character, or logical" ); |
2923 | auto cases = |
2924 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
2925 | auto count = getNumDest(); |
2926 | if (count == 0) |
2927 | return emitOpError("must have at least one successor" ); |
2928 | if (getNumConditions() != count) |
2929 | return emitOpError("number of conditions and successors don't match" ); |
2930 | if (compareOffsetSize() != count) |
2931 | return emitOpError("incorrect number of compare operand groups" ); |
2932 | if (targetOffsetSize() != count) |
2933 | return emitOpError("incorrect number of successor operand groups" ); |
2934 | for (decltype(count) i = 0; i != count; ++i) { |
2935 | auto &attr = cases[i]; |
2936 | if (!(attr.isa<fir::PointIntervalAttr>() || |
2937 | attr.isa<fir::LowerBoundAttr>() || attr.isa<fir::UpperBoundAttr>() || |
2938 | attr.isa<fir::ClosedIntervalAttr>() || attr.isa<mlir::UnitAttr>())) |
2939 | return emitOpError("incorrect select case attribute type" ); |
2940 | } |
2941 | return mlir::success(); |
2942 | } |
2943 | |
2944 | //===----------------------------------------------------------------------===// |
2945 | // SelectRankOp |
2946 | //===----------------------------------------------------------------------===// |
2947 | |
2948 | mlir::LogicalResult fir::SelectRankOp::verify() { |
2949 | return verifyIntegralSwitchTerminator(*this); |
2950 | } |
2951 | |
2952 | mlir::ParseResult fir::SelectRankOp::parse(mlir::OpAsmParser &parser, |
2953 | mlir::OperationState &result) { |
2954 | return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), |
2955 | getOperandSegmentSizeAttr()); |
2956 | } |
2957 | |
2958 | void fir::SelectRankOp::print(mlir::OpAsmPrinter &p) { |
2959 | printIntegralSwitchTerminator(*this, p); |
2960 | } |
2961 | |
2962 | std::optional<mlir::OperandRange> |
2963 | fir::SelectRankOp::getCompareOperands(unsigned) { |
2964 | return {}; |
2965 | } |
2966 | |
2967 | std::optional<llvm::ArrayRef<mlir::Value>> |
2968 | fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
2969 | return {}; |
2970 | } |
2971 | |
2972 | mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) { |
2973 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
2974 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
2975 | } |
2976 | |
2977 | std::optional<llvm::ArrayRef<mlir::Value>> |
2978 | fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
2979 | unsigned oper) { |
2980 | auto a = |
2981 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
2982 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2983 | getOperandSegmentSizeAttr()); |
2984 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
2985 | } |
2986 | |
2987 | std::optional<mlir::ValueRange> |
2988 | fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands, |
2989 | unsigned oper) { |
2990 | auto a = |
2991 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
2992 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
2993 | getOperandSegmentSizeAttr()); |
2994 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
2995 | } |
2996 | |
2997 | unsigned fir::SelectRankOp::targetOffsetSize() { |
2998 | return (*this) |
2999 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
3000 | .size(); |
3001 | } |
3002 | |
3003 | //===----------------------------------------------------------------------===// |
3004 | // SelectTypeOp |
3005 | //===----------------------------------------------------------------------===// |
3006 | |
3007 | std::optional<mlir::OperandRange> |
3008 | fir::SelectTypeOp::getCompareOperands(unsigned) { |
3009 | return {}; |
3010 | } |
3011 | |
3012 | std::optional<llvm::ArrayRef<mlir::Value>> |
3013 | fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { |
3014 | return {}; |
3015 | } |
3016 | |
3017 | mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { |
3018 | return mlir::SuccessorOperands(::getMutableSuccessorOperands( |
3019 | oper, getTargetArgsMutable(), getTargetOffsetAttr())); |
3020 | } |
3021 | |
3022 | std::optional<llvm::ArrayRef<mlir::Value>> |
3023 | fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, |
3024 | unsigned oper) { |
3025 | auto a = |
3026 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3027 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3028 | getOperandSegmentSizeAttr()); |
3029 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3030 | } |
3031 | |
3032 | std::optional<mlir::ValueRange> |
3033 | fir::SelectTypeOp::getSuccessorOperands(mlir::ValueRange operands, |
3034 | unsigned oper) { |
3035 | auto a = |
3036 | (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); |
3037 | auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( |
3038 | getOperandSegmentSizeAttr()); |
3039 | return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; |
3040 | } |
3041 | |
3042 | mlir::ParseResult fir::SelectTypeOp::parse(mlir::OpAsmParser &parser, |
3043 | mlir::OperationState &result) { |
3044 | mlir::OpAsmParser::UnresolvedOperand selector; |
3045 | mlir::Type type; |
3046 | if (fir::parseSelector(parser, result, selector, type)) |
3047 | return mlir::failure(); |
3048 | |
3049 | llvm::SmallVector<mlir::Attribute> attrs; |
3050 | llvm::SmallVector<mlir::Block *> dests; |
3051 | llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; |
3052 | while (true) { |
3053 | mlir::Attribute attr; |
3054 | mlir::Block *dest; |
3055 | llvm::SmallVector<mlir::Value> destArg; |
3056 | mlir::NamedAttrList temp; |
3057 | if (parser.parseAttribute(attr, "a" , temp) || parser.parseComma() || |
3058 | parser.parseSuccessorAndUseList(dest, destArg)) |
3059 | return mlir::failure(); |
3060 | attrs.push_back(attr); |
3061 | dests.push_back(dest); |
3062 | destArgs.push_back(destArg); |
3063 | if (mlir::succeeded(parser.parseOptionalRSquare())) |
3064 | break; |
3065 | if (parser.parseComma()) |
3066 | return mlir::failure(); |
3067 | } |
3068 | auto &bld = parser.getBuilder(); |
3069 | result.addAttribute(fir::SelectTypeOp::getCasesAttr(), |
3070 | bld.getArrayAttr(attrs)); |
3071 | llvm::SmallVector<int32_t> argOffs; |
3072 | int32_t offSize = 0; |
3073 | const auto count = dests.size(); |
3074 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
3075 | result.addSuccessors(dests[i]); |
3076 | result.addOperands(destArgs[i]); |
3077 | auto argSize = destArgs[i].size(); |
3078 | argOffs.push_back(argSize); |
3079 | offSize += argSize; |
3080 | } |
3081 | result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), |
3082 | bld.getDenseI32ArrayAttr({1, 0, offSize})); |
3083 | result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); |
3084 | return mlir::success(); |
3085 | } |
3086 | |
3087 | unsigned fir::SelectTypeOp::targetOffsetSize() { |
3088 | return (*this) |
3089 | ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) |
3090 | .size(); |
3091 | } |
3092 | |
3093 | void fir::SelectTypeOp::print(mlir::OpAsmPrinter &p) { |
3094 | p << ' '; |
3095 | p.printOperand(getSelector()); |
3096 | p << " : " << getSelector().getType() << " [" ; |
3097 | auto cases = |
3098 | getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); |
3099 | auto count = getNumConditions(); |
3100 | for (decltype(count) i = 0; i != count; ++i) { |
3101 | if (i) |
3102 | p << ", " ; |
3103 | p << cases[i] << ", " ; |
3104 | printSuccessorAtIndex(p, i); |
3105 | } |
3106 | p << ']'; |
3107 | p.printOptionalAttrDict(getOperation()->getAttrs(), |
3108 | {getCasesAttr(), getCompareOffsetAttr(), |
3109 | getTargetOffsetAttr(), |
3110 | fir::SelectTypeOp::getOperandSegmentSizeAttr()}); |
3111 | } |
3112 | |
3113 | mlir::LogicalResult fir::SelectTypeOp::verify() { |
3114 | if (!(getSelector().getType().isa<fir::BaseBoxType>())) |
3115 | return emitOpError("must be a fir.class or fir.box type" ); |
3116 | if (auto boxType = getSelector().getType().dyn_cast<fir::BoxType>()) |
3117 | if (!boxType.getEleTy().isa<mlir::NoneType>()) |
3118 | return emitOpError("selector must be polymorphic" ); |
3119 | auto typeGuardAttr = getCases(); |
3120 | for (unsigned idx = 0; idx < typeGuardAttr.size(); ++idx) |
3121 | if (typeGuardAttr[idx].isa<mlir::UnitAttr>() && |
3122 | idx != typeGuardAttr.size() - 1) |
3123 | return emitOpError("default must be the last attribute" ); |
3124 | auto count = getNumDest(); |
3125 | if (count == 0) |
3126 | return emitOpError("must have at least one successor" ); |
3127 | if (getNumConditions() != count) |
3128 | return emitOpError("number of conditions and successors don't match" ); |
3129 | if (targetOffsetSize() != count) |
3130 | return emitOpError("incorrect number of successor operand groups" ); |
3131 | for (unsigned i = 0; i != count; ++i) { |
3132 | if (!(typeGuardAttr[i].isa<fir::ExactTypeAttr>() || |
3133 | typeGuardAttr[i].isa<fir::SubclassAttr>() || |
3134 | typeGuardAttr[i].isa<mlir::UnitAttr>())) |
3135 | return emitOpError("invalid type-case alternative" ); |
3136 | } |
3137 | return mlir::success(); |
3138 | } |
3139 | |
3140 | void fir::SelectTypeOp::build(mlir::OpBuilder &builder, |
3141 | mlir::OperationState &result, |
3142 | mlir::Value selector, |
3143 | llvm::ArrayRef<mlir::Attribute> typeOperands, |
3144 | llvm::ArrayRef<mlir::Block *> destinations, |
3145 | llvm::ArrayRef<mlir::ValueRange> destOperands, |
3146 | llvm::ArrayRef<mlir::NamedAttribute> attributes) { |
3147 | result.addOperands(selector); |
3148 | result.addAttribute(getCasesAttr(), builder.getArrayAttr(typeOperands)); |
3149 | const auto count = destinations.size(); |
3150 | for (mlir::Block *dest : destinations) |
3151 | result.addSuccessors(dest); |
3152 | const auto opCount = destOperands.size(); |
3153 | llvm::SmallVector<int32_t> argOffs; |
3154 | int32_t sumArgs = 0; |
3155 | for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { |
3156 | if (i < opCount) { |
3157 | result.addOperands(destOperands[i]); |
3158 | const auto argSz = destOperands[i].size(); |
3159 | argOffs.push_back(argSz); |
3160 | sumArgs += argSz; |
3161 | } else { |
3162 | argOffs.push_back(0); |
3163 | } |
3164 | } |
3165 | result.addAttribute(getOperandSegmentSizeAttr(), |
3166 | builder.getDenseI32ArrayAttr({1, 0, sumArgs})); |
3167 | result.addAttribute(getTargetOffsetAttr(), |
3168 | builder.getDenseI32ArrayAttr(argOffs)); |
3169 | result.addAttributes(attributes); |
3170 | } |
3171 | |
3172 | //===----------------------------------------------------------------------===// |
3173 | // ShapeOp |
3174 | //===----------------------------------------------------------------------===// |
3175 | |
3176 | mlir::LogicalResult fir::ShapeOp::verify() { |
3177 | auto size = getExtents().size(); |
3178 | auto shapeTy = getType().dyn_cast<fir::ShapeType>(); |
3179 | assert(shapeTy && "must be a shape type" ); |
3180 | if (shapeTy.getRank() != size) |
3181 | return emitOpError("shape type rank mismatch" ); |
3182 | return mlir::success(); |
3183 | } |
3184 | |
3185 | void fir::ShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
3186 | mlir::ValueRange extents) { |
3187 | auto type = fir::ShapeType::get(builder.getContext(), extents.size()); |
3188 | build(builder, result, type, extents); |
3189 | } |
3190 | |
3191 | //===----------------------------------------------------------------------===// |
3192 | // ShapeShiftOp |
3193 | //===----------------------------------------------------------------------===// |
3194 | |
3195 | mlir::LogicalResult fir::ShapeShiftOp::verify() { |
3196 | auto size = getPairs().size(); |
3197 | if (size < 2 || size > 16 * 2) |
3198 | return emitOpError("incorrect number of args" ); |
3199 | if (size % 2 != 0) |
3200 | return emitOpError("requires a multiple of 2 args" ); |
3201 | auto shapeTy = getType().dyn_cast<fir::ShapeShiftType>(); |
3202 | assert(shapeTy && "must be a shape shift type" ); |
3203 | if (shapeTy.getRank() * 2 != size) |
3204 | return emitOpError("shape type rank mismatch" ); |
3205 | return mlir::success(); |
3206 | } |
3207 | |
3208 | //===----------------------------------------------------------------------===// |
3209 | // ShiftOp |
3210 | //===----------------------------------------------------------------------===// |
3211 | |
3212 | mlir::LogicalResult fir::ShiftOp::verify() { |
3213 | auto size = getOrigins().size(); |
3214 | auto shiftTy = getType().dyn_cast<fir::ShiftType>(); |
3215 | assert(shiftTy && "must be a shift type" ); |
3216 | if (shiftTy.getRank() != size) |
3217 | return emitOpError("shift type rank mismatch" ); |
3218 | return mlir::success(); |
3219 | } |
3220 | |
3221 | //===----------------------------------------------------------------------===// |
3222 | // SliceOp |
3223 | //===----------------------------------------------------------------------===// |
3224 | |
3225 | void fir::SliceOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
3226 | mlir::ValueRange trips, mlir::ValueRange path, |
3227 | mlir::ValueRange substr) { |
3228 | const auto rank = trips.size() / 3; |
3229 | auto sliceTy = fir::SliceType::get(builder.getContext(), rank); |
3230 | build(builder, result, sliceTy, trips, path, substr); |
3231 | } |
3232 | |
3233 | /// Return the output rank of a slice op. The output rank must be between 1 and |
3234 | /// the rank of the array being sliced (inclusive). |
3235 | unsigned fir::SliceOp::getOutputRank(mlir::ValueRange triples) { |
3236 | unsigned rank = 0; |
3237 | if (!triples.empty()) { |
3238 | for (unsigned i = 1, end = triples.size(); i < end; i += 3) { |
3239 | auto *op = triples[i].getDefiningOp(); |
3240 | if (!mlir::isa_and_nonnull<fir::UndefOp>(op)) |
3241 | ++rank; |
3242 | } |
3243 | assert(rank > 0); |
3244 | } |
3245 | return rank; |
3246 | } |
3247 | |
3248 | mlir::LogicalResult fir::SliceOp::verify() { |
3249 | auto size = getTriples().size(); |
3250 | if (size < 3 || size > 16 * 3) |
3251 | return emitOpError("incorrect number of args for triple" ); |
3252 | if (size % 3 != 0) |
3253 | return emitOpError("requires a multiple of 3 args" ); |
3254 | auto sliceTy = getType().dyn_cast<fir::SliceType>(); |
3255 | assert(sliceTy && "must be a slice type" ); |
3256 | if (sliceTy.getRank() * 3 != size) |
3257 | return emitOpError("slice type rank mismatch" ); |
3258 | return mlir::success(); |
3259 | } |
3260 | |
3261 | //===----------------------------------------------------------------------===// |
3262 | // StoreOp |
3263 | //===----------------------------------------------------------------------===// |
3264 | |
3265 | mlir::Type fir::StoreOp::elementType(mlir::Type refType) { |
3266 | return fir::dyn_cast_ptrEleTy(refType); |
3267 | } |
3268 | |
3269 | mlir::ParseResult fir::StoreOp::parse(mlir::OpAsmParser &parser, |
3270 | mlir::OperationState &result) { |
3271 | mlir::Type type; |
3272 | mlir::OpAsmParser::UnresolvedOperand oper; |
3273 | mlir::OpAsmParser::UnresolvedOperand store; |
3274 | if (parser.parseOperand(oper) || parser.parseKeyword("to" ) || |
3275 | parser.parseOperand(store) || |
3276 | parser.parseOptionalAttrDict(result.attributes) || |
3277 | parser.parseColonType(type) || |
3278 | parser.resolveOperand(oper, fir::StoreOp::elementType(type), |
3279 | result.operands) || |
3280 | parser.resolveOperand(store, type, result.operands)) |
3281 | return mlir::failure(); |
3282 | return mlir::success(); |
3283 | } |
3284 | |
3285 | void fir::StoreOp::print(mlir::OpAsmPrinter &p) { |
3286 | p << ' '; |
3287 | p.printOperand(getValue()); |
3288 | p << " to " ; |
3289 | p.printOperand(getMemref()); |
3290 | p.printOptionalAttrDict(getOperation()->getAttrs(), {}); |
3291 | p << " : " << getMemref().getType(); |
3292 | } |
3293 | |
3294 | mlir::LogicalResult fir::StoreOp::verify() { |
3295 | if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType())) |
3296 | return emitOpError("store value type must match memory reference type" ); |
3297 | if (fir::isa_unknown_size_box(getValue().getType())) |
3298 | return emitOpError("cannot store !fir.box of unknown rank or type" ); |
3299 | return mlir::success(); |
3300 | } |
3301 | |
/// Convenience builder that stores `value` to `memref`. It forwards to the
/// generated builder with an empty trailing `{}` argument.
/// NOTE(review): which parameter `{}` fills depends on the generated builder's
/// signature (presumably an optional/defaulted argument); confirm in FIROps.td.
void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         mlir::Value value, mlir::Value memref) {
  build(builder, result, value, memref, {});
}
3306 | |
3307 | //===----------------------------------------------------------------------===// |
3308 | // StringLitOp |
3309 | //===----------------------------------------------------------------------===// |
3310 | |
3311 | inline fir::CharacterType::KindTy stringLitOpGetKind(fir::StringLitOp op) { |
3312 | auto eleTy = op.getType().cast<fir::SequenceType>().getEleTy(); |
3313 | return eleTy.cast<fir::CharacterType>().getFKind(); |
3314 | } |
3315 | |
3316 | bool fir::StringLitOp::isWideValue() { return stringLitOpGetKind(*this) != 1; } |
3317 | |
3318 | static mlir::NamedAttribute |
3319 | mkNamedIntegerAttr(mlir::OpBuilder &builder, llvm::StringRef name, int64_t v) { |
3320 | assert(v > 0); |
3321 | return builder.getNamedAttr( |
3322 | name, builder.getIntegerAttr(builder.getIntegerType(64), v)); |
3323 | } |
3324 | |
3325 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
3326 | mlir::OperationState &result, |
3327 | fir::CharacterType inType, llvm::StringRef val, |
3328 | std::optional<int64_t> len) { |
3329 | auto valAttr = builder.getNamedAttr(value(), builder.getStringAttr(val)); |
3330 | int64_t length = len ? *len : inType.getLen(); |
3331 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
3332 | result.addAttributes({valAttr, lenAttr}); |
3333 | result.addTypes(inType); |
3334 | } |
3335 | |
3336 | template <typename C> |
3337 | static mlir::ArrayAttr convertToArrayAttr(mlir::OpBuilder &builder, |
3338 | llvm::ArrayRef<C> xlist) { |
3339 | llvm::SmallVector<mlir::Attribute> attrs; |
3340 | auto ty = builder.getIntegerType(8 * sizeof(C)); |
3341 | for (auto ch : xlist) |
3342 | attrs.push_back(Elt: builder.getIntegerAttr(ty, ch)); |
3343 | return builder.getArrayAttr(attrs); |
3344 | } |
3345 | |
3346 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
3347 | mlir::OperationState &result, |
3348 | fir::CharacterType inType, |
3349 | llvm::ArrayRef<char> vlist, |
3350 | std::optional<std::int64_t> len) { |
3351 | auto valAttr = |
3352 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
3353 | std::int64_t length = len ? *len : inType.getLen(); |
3354 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
3355 | result.addAttributes({valAttr, lenAttr}); |
3356 | result.addTypes(inType); |
3357 | } |
3358 | |
3359 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
3360 | mlir::OperationState &result, |
3361 | fir::CharacterType inType, |
3362 | llvm::ArrayRef<char16_t> vlist, |
3363 | std::optional<std::int64_t> len) { |
3364 | auto valAttr = |
3365 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
3366 | std::int64_t length = len ? *len : inType.getLen(); |
3367 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
3368 | result.addAttributes({valAttr, lenAttr}); |
3369 | result.addTypes(inType); |
3370 | } |
3371 | |
3372 | void fir::StringLitOp::build(mlir::OpBuilder &builder, |
3373 | mlir::OperationState &result, |
3374 | fir::CharacterType inType, |
3375 | llvm::ArrayRef<char32_t> vlist, |
3376 | std::optional<std::int64_t> len) { |
3377 | auto valAttr = |
3378 | builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); |
3379 | std::int64_t length = len ? *len : inType.getLen(); |
3380 | auto lenAttr = mkNamedIntegerAttr(builder, size(), length); |
3381 | result.addAttributes({valAttr, lenAttr}); |
3382 | result.addTypes(inType); |
3383 | } |
3384 | |
3385 | mlir::ParseResult fir::StringLitOp::parse(mlir::OpAsmParser &parser, |
3386 | mlir::OperationState &result) { |
3387 | auto &builder = parser.getBuilder(); |
3388 | mlir::Attribute val; |
3389 | mlir::NamedAttrList attrs; |
3390 | llvm::SMLoc trailingTypeLoc; |
3391 | if (parser.parseAttribute(val, "fake" , attrs)) |
3392 | return mlir::failure(); |
3393 | if (auto v = val.dyn_cast<mlir::StringAttr>()) |
3394 | result.attributes.push_back( |
3395 | builder.getNamedAttr(fir::StringLitOp::value(), v)); |
3396 | else if (auto v = val.dyn_cast<mlir::DenseElementsAttr>()) |
3397 | result.attributes.push_back( |
3398 | builder.getNamedAttr(fir::StringLitOp::xlist(), v)); |
3399 | else if (auto v = val.dyn_cast<mlir::ArrayAttr>()) |
3400 | result.attributes.push_back( |
3401 | builder.getNamedAttr(fir::StringLitOp::xlist(), v)); |
3402 | else |
3403 | return parser.emitError(parser.getCurrentLocation(), |
3404 | "found an invalid constant" ); |
3405 | mlir::IntegerAttr sz; |
3406 | mlir::Type type; |
3407 | if (parser.parseLParen() || |
3408 | parser.parseAttribute(sz, fir::StringLitOp::size(), result.attributes) || |
3409 | parser.parseRParen() || parser.getCurrentLocation(&trailingTypeLoc) || |
3410 | parser.parseColonType(type)) |
3411 | return mlir::failure(); |
3412 | auto charTy = type.dyn_cast<fir::CharacterType>(); |
3413 | if (!charTy) |
3414 | return parser.emitError(trailingTypeLoc, "must have character type" ); |
3415 | type = fir::CharacterType::get(builder.getContext(), charTy.getFKind(), |
3416 | sz.getInt()); |
3417 | if (!type || parser.addTypesToList(type, result.types)) |
3418 | return mlir::failure(); |
3419 | return mlir::success(); |
3420 | } |
3421 | |
3422 | void fir::StringLitOp::print(mlir::OpAsmPrinter &p) { |
3423 | p << ' ' << getValue() << '('; |
3424 | p << getSize().cast<mlir::IntegerAttr>().getValue() << ") : " ; |
3425 | p.printType(getType()); |
3426 | } |
3427 | |
3428 | mlir::LogicalResult fir::StringLitOp::verify() { |
3429 | if (getSize().cast<mlir::IntegerAttr>().getValue().isNegative()) |
3430 | return emitOpError("size must be non-negative" ); |
3431 | if (auto xl = getOperation()->getAttr(fir::StringLitOp::xlist())) { |
3432 | if (auto xList = xl.dyn_cast<mlir::ArrayAttr>()) { |
3433 | for (auto a : xList) |
3434 | if (!a.isa<mlir::IntegerAttr>()) |
3435 | return emitOpError("values in initializer must be integers" ); |
3436 | } else if (xl.isa<mlir::DenseElementsAttr>()) { |
3437 | // do nothing |
3438 | } else { |
3439 | return emitOpError("has unexpected attribute" ); |
3440 | } |
3441 | } |
3442 | return mlir::success(); |
3443 | } |
3444 | |
3445 | //===----------------------------------------------------------------------===// |
3446 | // UnboxProcOp |
3447 | //===----------------------------------------------------------------------===// |
3448 | |
3449 | mlir::LogicalResult fir::UnboxProcOp::verify() { |
3450 | if (auto eleTy = fir::dyn_cast_ptrEleTy(getRefTuple().getType())) |
3451 | if (eleTy.isa<mlir::TupleType>()) |
3452 | return mlir::success(); |
3453 | return emitOpError("second output argument has bad type" ); |
3454 | } |
3455 | |
3456 | //===----------------------------------------------------------------------===// |
3457 | // IfOp |
3458 | //===----------------------------------------------------------------------===// |
3459 | |
/// Convenience builder for a fir.if that produces no results. It forwards to
/// the builder taking a result-type list, passing `std::nullopt` as the empty
/// list (presumably selecting the TypeRange overload; confirm if new overloads
/// are added).
void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                      mlir::Value cond, bool withElseRegion) {
  build(builder, result, std::nullopt, cond, withElseRegion);
}
3464 | |
3465 | void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, |
3466 | mlir::TypeRange resultTypes, mlir::Value cond, |
3467 | bool withElseRegion) { |
3468 | result.addOperands(cond); |
3469 | result.addTypes(resultTypes); |
3470 | |
3471 | mlir::Region *thenRegion = result.addRegion(); |
3472 | thenRegion->push_back(new mlir::Block()); |
3473 | if (resultTypes.empty()) |
3474 | IfOp::ensureTerminator(*thenRegion, builder, result.location); |
3475 | |
3476 | mlir::Region *elseRegion = result.addRegion(); |
3477 | if (withElseRegion) { |
3478 | elseRegion->push_back(new mlir::Block()); |
3479 | if (resultTypes.empty()) |
3480 | IfOp::ensureTerminator(*elseRegion, builder, result.location); |
3481 | } |
3482 | } |
3483 | |
3484 | // These 3 functions copied from scf.if implementation. |
3485 | |
3486 | /// Given the region at `index`, or the parent operation if `index` is None, |
3487 | /// return the successor regions. These are the regions that may be selected |
3488 | /// during the flow of control. |
3489 | void fir::IfOp::getSuccessorRegions( |
3490 | mlir::RegionBranchPoint point, |
3491 | llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { |
3492 | // The `then` and the `else` region branch back to the parent operation. |
3493 | if (!point.isParent()) { |
3494 | regions.push_back(mlir::RegionSuccessor(getResults())); |
3495 | return; |
3496 | } |
3497 | |
3498 | // Don't consider the else region if it is empty. |
3499 | regions.push_back(mlir::RegionSuccessor(&getThenRegion())); |
3500 | |
3501 | // Don't consider the else region if it is empty. |
3502 | mlir::Region *elseRegion = &this->getElseRegion(); |
3503 | if (elseRegion->empty()) |
3504 | regions.push_back(mlir::RegionSuccessor()); |
3505 | else |
3506 | regions.push_back(mlir::RegionSuccessor(elseRegion)); |
3507 | } |
3508 | |
3509 | void fir::IfOp::getEntrySuccessorRegions( |
3510 | llvm::ArrayRef<mlir::Attribute> operands, |
3511 | llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { |
3512 | FoldAdaptor adaptor(operands); |
3513 | auto boolAttr = |
3514 | mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition()); |
3515 | if (!boolAttr || boolAttr.getValue()) |
3516 | regions.emplace_back(&getThenRegion()); |
3517 | |
3518 | // If the else region is empty, execution continues after the parent op. |
3519 | if (!boolAttr || !boolAttr.getValue()) { |
3520 | if (!getElseRegion().empty()) |
3521 | regions.emplace_back(&getElseRegion()); |
3522 | else |
3523 | regions.emplace_back(getResults()); |
3524 | } |
3525 | } |
3526 | |
3527 | void fir::IfOp::getRegionInvocationBounds( |
3528 | llvm::ArrayRef<mlir::Attribute> operands, |
3529 | llvm::SmallVectorImpl<mlir::InvocationBounds> &invocationBounds) { |
3530 | if (auto cond = operands[0].dyn_cast_or_null<mlir::BoolAttr>()) { |
3531 | // If the condition is known, then one region is known to be executed once |
3532 | // and the other zero times. |
3533 | invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); |
3534 | invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); |
3535 | } else { |
3536 | // Non-constant condition. Each region may be executed 0 or 1 times. |
3537 | invocationBounds.assign(2, {0, 1}); |
3538 | } |
3539 | } |
3540 | |
3541 | mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser, |
3542 | mlir::OperationState &result) { |
3543 | result.regions.reserve(2); |
3544 | mlir::Region *thenRegion = result.addRegion(); |
3545 | mlir::Region *elseRegion = result.addRegion(); |
3546 | |
3547 | auto &builder = parser.getBuilder(); |
3548 | mlir::OpAsmParser::UnresolvedOperand cond; |
3549 | mlir::Type i1Type = builder.getIntegerType(1); |
3550 | if (parser.parseOperand(cond) || |
3551 | parser.resolveOperand(cond, i1Type, result.operands)) |
3552 | return mlir::failure(); |
3553 | |
3554 | if (parser.parseOptionalArrowTypeList(result.types)) |
3555 | return mlir::failure(); |
3556 | |
3557 | if (parser.parseRegion(*thenRegion, {}, {})) |
3558 | return mlir::failure(); |
3559 | fir::IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), |
3560 | result.location); |
3561 | |
3562 | if (mlir::succeeded(parser.parseOptionalKeyword("else" ))) { |
3563 | if (parser.parseRegion(*elseRegion, {}, {})) |
3564 | return mlir::failure(); |
3565 | fir::IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), |
3566 | result.location); |
3567 | } |
3568 | |
3569 | // Parse the optional attribute list. |
3570 | if (parser.parseOptionalAttrDict(result.attributes)) |
3571 | return mlir::failure(); |
3572 | return mlir::success(); |
3573 | } |
3574 | |
3575 | mlir::LogicalResult fir::IfOp::verify() { |
3576 | if (getNumResults() != 0 && getElseRegion().empty()) |
3577 | return emitOpError("must have an else block if defining values" ); |
3578 | |
3579 | return mlir::success(); |
3580 | } |
3581 | |
3582 | void fir::IfOp::print(mlir::OpAsmPrinter &p) { |
3583 | bool printBlockTerminators = false; |
3584 | p << ' ' << getCondition(); |
3585 | if (!getResults().empty()) { |
3586 | p << " -> (" << getResultTypes() << ')'; |
3587 | printBlockTerminators = true; |
3588 | } |
3589 | p << ' '; |
3590 | p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, |
3591 | printBlockTerminators); |
3592 | |
3593 | // Print the 'else' regions if it exists and has a block. |
3594 | auto &otherReg = getElseRegion(); |
3595 | if (!otherReg.empty()) { |
3596 | p << " else " ; |
3597 | p.printRegion(otherReg, /*printEntryBlockArgs=*/false, |
3598 | printBlockTerminators); |
3599 | } |
3600 | p.printOptionalAttrDict((*this)->getAttrs()); |
3601 | } |
3602 | |
3603 | void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results, |
3604 | unsigned resultNum) { |
3605 | auto *term = getThenRegion().front().getTerminator(); |
3606 | if (resultNum < term->getNumOperands()) |
3607 | results.push_back(term->getOperand(resultNum)); |
3608 | term = getElseRegion().front().getTerminator(); |
3609 | if (resultNum < term->getNumOperands()) |
3610 | results.push_back(term->getOperand(resultNum)); |
3611 | } |
3612 | |
3613 | //===----------------------------------------------------------------------===// |
3614 | // BoxOffsetOp |
3615 | //===----------------------------------------------------------------------===// |
3616 | |
3617 | mlir::LogicalResult fir::BoxOffsetOp::verify() { |
3618 | auto boxType = mlir::dyn_cast_or_null<fir::BaseBoxType>( |
3619 | fir::dyn_cast_ptrEleTy(getBoxRef().getType())); |
3620 | if (!boxType) |
3621 | return emitOpError("box_ref operand must have !fir.ref<!fir.box<T>> type" ); |
3622 | if (getField() != fir::BoxFieldAttr::base_addr && |
3623 | getField() != fir::BoxFieldAttr::derived_type) |
3624 | return emitOpError("cannot address provided field" ); |
3625 | if (getField() == fir::BoxFieldAttr::derived_type) |
3626 | if (!fir::boxHasAddendum(boxType)) |
3627 | return emitOpError("can only address derived_type field of derived type " |
3628 | "or unlimited polymorphic fir.box" ); |
3629 | return mlir::success(); |
3630 | } |
3631 | |
3632 | void fir::BoxOffsetOp::build(mlir::OpBuilder &builder, |
3633 | mlir::OperationState &result, mlir::Value boxRef, |
3634 | fir::BoxFieldAttr field) { |
3635 | mlir::Type valueType = |
3636 | fir::unwrapPassByRefType(fir::unwrapRefType(boxRef.getType())); |
3637 | mlir::Type resultType = valueType; |
3638 | if (field == fir::BoxFieldAttr::base_addr) |
3639 | resultType = fir::LLVMPointerType::get(fir::ReferenceType::get(valueType)); |
3640 | else if (field == fir::BoxFieldAttr::derived_type) |
3641 | resultType = fir::LLVMPointerType::get( |
3642 | fir::TypeDescType::get(fir::unwrapSequenceType(valueType))); |
3643 | build(builder, result, {resultType}, boxRef, field); |
3644 | } |
3645 | |
3646 | //===----------------------------------------------------------------------===// |
3647 | |
3648 | mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { |
3649 | if (attr.isa<mlir::UnitAttr, fir::ClosedIntervalAttr, fir::PointIntervalAttr, |
3650 | fir::LowerBoundAttr, fir::UpperBoundAttr>()) |
3651 | return mlir::success(); |
3652 | return mlir::failure(); |
3653 | } |
3654 | |
3655 | unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, |
3656 | unsigned dest) { |
3657 | unsigned o = 0; |
3658 | for (unsigned i = 0; i < dest; ++i) { |
3659 | auto &attr = cases[i]; |
3660 | if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { |
3661 | ++o; |
3662 | if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) |
3663 | ++o; |
3664 | } |
3665 | } |
3666 | return o; |
3667 | } |
3668 | |
3669 | mlir::ParseResult |
3670 | fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result, |
3671 | mlir::OpAsmParser::UnresolvedOperand &selector, |
3672 | mlir::Type &type) { |
3673 | if (parser.parseOperand(selector) || parser.parseColonType(type) || |
3674 | parser.resolveOperand(selector, type, result.operands) || |
3675 | parser.parseLSquare()) |
3676 | return mlir::failure(); |
3677 | return mlir::success(); |
3678 | } |
3679 | |
3680 | mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, |
3681 | llvm::StringRef name, |
3682 | mlir::FunctionType type, |
3683 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
3684 | const mlir::SymbolTable *symbolTable) { |
3685 | if (symbolTable) |
3686 | if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name)) { |
3687 | #ifdef EXPENSIVE_CHECKS |
3688 | assert(f == module.lookupSymbol<mlir::func::FuncOp>(name) && |
3689 | "symbolTable and module out of sync" ); |
3690 | #endif |
3691 | return f; |
3692 | } |
3693 | if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name)) |
3694 | return f; |
3695 | mlir::OpBuilder modBuilder(module.getBodyRegion()); |
3696 | modBuilder.setInsertionPointToEnd(module.getBody()); |
3697 | auto result = modBuilder.create<mlir::func::FuncOp>(loc, name, type, attrs); |
3698 | result.setVisibility(mlir::SymbolTable::Visibility::Private); |
3699 | return result; |
3700 | } |
3701 | |
3702 | fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, |
3703 | llvm::StringRef name, mlir::Type type, |
3704 | llvm::ArrayRef<mlir::NamedAttribute> attrs, |
3705 | const mlir::SymbolTable *symbolTable) { |
3706 | if (symbolTable) |
3707 | if (auto g = symbolTable->lookup<fir::GlobalOp>(name)) { |
3708 | #ifdef EXPENSIVE_CHECKS |
3709 | assert(g == module.lookupSymbol<fir::GlobalOp>(name) && |
3710 | "symbolTable and module out of sync" ); |
3711 | #endif |
3712 | return g; |
3713 | } |
3714 | if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) |
3715 | return g; |
3716 | mlir::OpBuilder modBuilder(module.getBodyRegion()); |
3717 | auto result = modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); |
3718 | result.setVisibility(mlir::SymbolTable::Visibility::Private); |
3719 | return result; |
3720 | } |
3721 | |
3722 | bool fir::hasHostAssociationArgument(mlir::func::FuncOp func) { |
3723 | if (auto allArgAttrs = func.getAllArgAttrs()) |
3724 | for (auto attr : allArgAttrs) |
3725 | if (auto dict = attr.template dyn_cast_or_null<mlir::DictionaryAttr>()) |
3726 | if (dict.get(fir::getHostAssocAttrName())) |
3727 | return true; |
3728 | return false; |
3729 | } |
3730 | |
3731 | // Test if value's definition has the specified set of |
3732 | // attributeNames. The value's definition is one of the operations |
3733 | // that are able to carry the Fortran variable attributes, e.g. |
3734 | // fir.alloca or fir.allocmem. Function arguments may also represent |
3735 | // value definitions and carry relevant attributes. |
3736 | // |
3737 | // If it is not possible to reach the limited set of definition |
3738 | // entities from the given value, then the function will return |
3739 | // std::nullopt. Otherwise, the definition is known and the return |
3740 | // value is computed as: |
3741 | // * if checkAny is true, then the function will return true |
3742 | // iff any of the attributeNames attributes is set on the definition. |
3743 | // * if checkAny is false, then the function will return true |
3744 | // iff all of the attributeNames attributes are set on the definition. |
3745 | static std::optional<bool> |
3746 | valueCheckFirAttributes(mlir::Value value, |
3747 | llvm::ArrayRef<llvm::StringRef> attributeNames, |
3748 | bool checkAny) { |
3749 | auto testAttributeSets = [&](llvm::ArrayRef<mlir::NamedAttribute> setAttrs, |
3750 | llvm::ArrayRef<llvm::StringRef> checkAttrs) { |
3751 | if (checkAny) { |
3752 | // Return true iff any of checkAttrs attributes is present |
3753 | // in setAttrs set. |
3754 | for (llvm::StringRef checkAttrName : checkAttrs) |
3755 | if (llvm::any_of(Range&: setAttrs, P: [&](mlir::NamedAttribute setAttr) { |
3756 | return setAttr.getName() == checkAttrName; |
3757 | })) |
3758 | return true; |
3759 | |
3760 | return false; |
3761 | } |
3762 | |
3763 | // Return true iff all attributes from checkAttrs are present |
3764 | // in setAttrs set. |
3765 | for (mlir::StringRef checkAttrName : checkAttrs) |
3766 | if (llvm::none_of(Range&: setAttrs, P: [&](mlir::NamedAttribute setAttr) { |
3767 | return setAttr.getName() == checkAttrName; |
3768 | })) |
3769 | return false; |
3770 | |
3771 | return true; |
3772 | }; |
3773 | // If this is a fir.box that was loaded, the fir attributes will be on the |
3774 | // related fir.ref<fir.box> creation. |
3775 | if (value.getType().isa<fir::BoxType>()) |
3776 | if (auto definingOp = value.getDefiningOp()) |
3777 | if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(definingOp)) |
3778 | value = loadOp.getMemref(); |
3779 | // If this is a function argument, look in the argument attributes. |
3780 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(Val&: value)) { |
3781 | if (blockArg.getOwner() && blockArg.getOwner()->isEntryBlock()) |
3782 | if (auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>( |
3783 | blockArg.getOwner()->getParentOp())) |
3784 | return testAttributeSets( |
3785 | mlir::cast<mlir::FunctionOpInterface>(*funcOp).getArgAttrs( |
3786 | blockArg.getArgNumber()), |
3787 | attributeNames); |
3788 | |
3789 | // If it is not a function argument, the attributes are unknown. |
3790 | return std::nullopt; |
3791 | } |
3792 | |
3793 | if (auto definingOp = value.getDefiningOp()) { |
3794 | // If this is an allocated value, look at the allocation attributes. |
3795 | if (mlir::isa<fir::AllocMemOp>(definingOp) || |
3796 | mlir::isa<fir::AllocaOp>(definingOp)) |
3797 | return testAttributeSets(definingOp->getAttrs(), attributeNames); |
3798 | // If this is an imported global, look at AddrOfOp and GlobalOp attributes. |
3799 | // Both operations are looked at because use/host associated variable (the |
3800 | // AddrOfOp) can have ASYNCHRONOUS/VOLATILE attributes even if the ultimate |
3801 | // entity (the globalOp) does not have them. |
3802 | if (auto addressOfOp = mlir::dyn_cast<fir::AddrOfOp>(definingOp)) { |
3803 | if (testAttributeSets(addressOfOp->getAttrs(), attributeNames)) |
3804 | return true; |
3805 | if (auto module = definingOp->getParentOfType<mlir::ModuleOp>()) |
3806 | if (auto globalOp = |
3807 | module.lookupSymbol<fir::GlobalOp>(addressOfOp.getSymbol())) |
3808 | return testAttributeSets(globalOp->getAttrs(), attributeNames); |
3809 | } |
3810 | } |
3811 | // TODO: Construct associated entities attributes. Decide where the fir |
3812 | // attributes must be placed/looked for in this case. |
3813 | return std::nullopt; |
3814 | } |
3815 | |
3816 | bool fir::valueMayHaveFirAttributes( |
3817 | mlir::Value value, llvm::ArrayRef<llvm::StringRef> attributeNames) { |
3818 | std::optional<bool> mayHaveAttr = |
3819 | valueCheckFirAttributes(value, attributeNames, /*checkAny=*/true); |
3820 | return mayHaveAttr.value_or(true); |
3821 | } |
3822 | |
3823 | bool fir::valueHasFirAttribute(mlir::Value value, |
3824 | llvm::StringRef attributeName) { |
3825 | std::optional<bool> mayHaveAttr = |
3826 | valueCheckFirAttributes(value, {attributeName}, /*checkAny=*/false); |
3827 | return mayHaveAttr.value_or(false); |
3828 | } |
3829 | |
3830 | bool fir::anyFuncArgsHaveAttr(mlir::func::FuncOp func, llvm::StringRef attr) { |
3831 | for (unsigned i = 0, end = func.getNumArguments(); i < end; ++i) |
3832 | if (func.getArgAttr(i, attr)) |
3833 | return true; |
3834 | return false; |
3835 | } |
3836 | |
3837 | std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) { |
3838 | if (auto *definingOp = value.getDefiningOp()) { |
3839 | if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp)) |
3840 | if (auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>()) |
3841 | return intAttr.getInt(); |
3842 | if (auto llConstOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(definingOp)) |
3843 | if (auto attr = llConstOp.getValue().dyn_cast<mlir::IntegerAttr>()) |
3844 | return attr.getValue().getSExtValue(); |
3845 | } |
3846 | return {}; |
3847 | } |
3848 | |
/// Return the type reached by applying the coordinate \p path to the type
/// \p eleTy, or a null type when the path does not fit the type. At each
/// step, path values are consumed according to the current type:
///  - fir::RecordType: one value, which must be defined by a fir.field_index
///    (component looked up by field name) or by an arith.constant (component
///    looked up by position);
///  - fir::SequenceType: one integer-typed value per dimension, yielding the
///    element type;
///  - mlir::TupleType: one value defined by an arith.constant (position);
///  - fir::ComplexType: one integer-typed value that must have a defining op,
///    because that op's enclosing module supplies the kind mapping used to
///    get the element type;
///  - mlir::ComplexType: one integer-typed value.
/// Any other type yields a null type, which ends the walk. When the path is
/// exhausted, the type reached so far is returned.
mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
  // `i` is advanced inside the TypeSwitch cases; each case consumes the path
  // values it needs.
  for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
    eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
                .Case<fir::RecordType>([&](fir::RecordType ty) {
                  if (auto *op = (*i++).getDefiningOp()) {
                    if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op))
                      return ty.getType(off.getFieldName());
                    if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op))
                      return ty.getType(fir::toInt(off));
                  }
                  return mlir::Type{};
                })
                .Case<fir::SequenceType>([&](fir::SequenceType ty) {
                  // One integer index per dimension; invalid if the path runs
                  // out early or an index is not an integer.
                  bool valid = true;
                  const auto rank = ty.getDimension();
                  for (std::remove_const_t<decltype(rank)> ii = 0;
                       valid && ii < rank; ++ii)
                    valid = i < end && fir::isa_integer((*i++).getType());
                  return valid ? ty.getEleTy() : mlir::Type{};
                })
                .Case<mlir::TupleType>([&](mlir::TupleType ty) {
                  if (auto *op = (*i++).getDefiningOp())
                    if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op))
                      return ty.getType(fir::toInt(off));
                  return mlir::Type{};
                })
                .Case<fir::ComplexType>([&](fir::ComplexType ty) {
                  auto x = *i;
                  // The defining op is used only to find the enclosing module
                  // and its kind mapping.
                  if (auto *op = (*i++).getDefiningOp())
                    if (fir::isa_integer(x.getType()))
                      return ty.getEleType(fir::getKindMapping(
                          op->getParentOfType<mlir::ModuleOp>()));
                  return mlir::Type{};
                })
                .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
                  if (fir::isa_integer((*i++).getType()))
                    return ty.getElementType();
                  return mlir::Type{};
                })
                // Unsupported type: null result terminates the loop.
                .Default([&](const auto &) { return mlir::Type{}; });
  }
  return eleTy;
}
3892 | |
3893 | mlir::LogicalResult fir::DeclareOp::verify() { |
3894 | auto fortranVar = |
3895 | mlir::cast<fir::FortranVariableOpInterface>(this->getOperation()); |
3896 | return fortranVar.verifyDeclareLikeOpImpl(getMemref()); |
3897 | } |
3898 | |
3899 | llvm::SmallVector<mlir::Region *> fir::CUDAKernelOp::getLoopRegions() { |
3900 | return {&getRegion()}; |
3901 | } |
3902 | |
3903 | mlir::ParseResult parseCUFKernelValues( |
3904 | mlir::OpAsmParser &parser, |
3905 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values, |
3906 | llvm::SmallVectorImpl<mlir::Type> &types) { |
3907 | if (mlir::succeeded(result: parser.parseOptionalStar())) |
3908 | return mlir::success(); |
3909 | |
3910 | if (mlir::succeeded(result: parser.parseOptionalLParen())) { |
3911 | if (mlir::failed(result: parser.parseCommaSeparatedList( |
3912 | delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() { |
3913 | if (parser.parseOperand(result&: values.emplace_back())) |
3914 | return mlir::failure(); |
3915 | return mlir::success(); |
3916 | }))) |
3917 | return mlir::failure(); |
3918 | auto builder = parser.getBuilder(); |
3919 | for (size_t i = 0; i < values.size(); i++) { |
3920 | types.emplace_back(Args: builder.getI32Type()); |
3921 | } |
3922 | if (parser.parseRParen()) |
3923 | return mlir::failure(); |
3924 | } else { |
3925 | if (parser.parseOperand(result&: values.emplace_back())) |
3926 | return mlir::failure(); |
3927 | auto builder = parser.getBuilder(); |
3928 | types.emplace_back(Args: builder.getI32Type()); |
3929 | return mlir::success(); |
3930 | } |
3931 | return mlir::success(); |
3932 | } |
3933 | |
3934 | void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op, |
3935 | mlir::ValueRange values, mlir::TypeRange types) { |
3936 | if (values.empty()) |
3937 | p << "*" ; |
3938 | |
3939 | if (values.size() > 1) |
3940 | p << "(" ; |
3941 | llvm::interleaveComma(c: values, os&: p, each_fn: [&p](mlir::Value v) { p << v; }); |
3942 | if (values.size() > 1) |
3943 | p << ")" ; |
3944 | } |
3945 | |
3946 | mlir::ParseResult parseCUFKernelLoopControl( |
3947 | mlir::OpAsmParser &parser, mlir::Region ®ion, |
3948 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound, |
3949 | llvm::SmallVectorImpl<mlir::Type> &lowerboundType, |
3950 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound, |
3951 | llvm::SmallVectorImpl<mlir::Type> &upperboundType, |
3952 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step, |
3953 | llvm::SmallVectorImpl<mlir::Type> &stepType) { |
3954 | |
3955 | llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars; |
3956 | if (parser.parseLParen() || |
3957 | parser.parseArgumentList(result&: inductionVars, |
3958 | delimiter: mlir::OpAsmParser::Delimiter::None, |
3959 | /*allowType=*/true) || |
3960 | parser.parseRParen() || parser.parseEqual() || parser.parseLParen() || |
3961 | parser.parseOperandList(result&: lowerbound, requiredOperandCount: inductionVars.size(), |
3962 | delimiter: mlir::OpAsmParser::Delimiter::None) || |
3963 | parser.parseColonTypeList(result&: lowerboundType) || parser.parseRParen() || |
3964 | parser.parseKeyword(keyword: "to" ) || parser.parseLParen() || |
3965 | parser.parseOperandList(result&: upperbound, requiredOperandCount: inductionVars.size(), |
3966 | delimiter: mlir::OpAsmParser::Delimiter::None) || |
3967 | parser.parseColonTypeList(result&: upperboundType) || parser.parseRParen() || |
3968 | parser.parseKeyword(keyword: "step" ) || parser.parseLParen() || |
3969 | parser.parseOperandList(result&: step, requiredOperandCount: inductionVars.size(), |
3970 | delimiter: mlir::OpAsmParser::Delimiter::None) || |
3971 | parser.parseColonTypeList(result&: stepType) || parser.parseRParen()) |
3972 | return mlir::failure(); |
3973 | return parser.parseRegion(region, arguments: inductionVars); |
3974 | } |
3975 | |
3976 | void printCUFKernelLoopControl( |
3977 | mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region ®ion, |
3978 | mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType, |
3979 | mlir::ValueRange upperbound, mlir::TypeRange upperboundType, |
3980 | mlir::ValueRange steps, mlir::TypeRange stepType) { |
3981 | mlir::ValueRange regionArgs = region.front().getArguments(); |
3982 | if (!regionArgs.empty()) { |
3983 | p << "(" ; |
3984 | llvm::interleaveComma( |
3985 | c: regionArgs, os&: p, each_fn: [&p](mlir::Value v) { p << v << " : " << v.getType(); }); |
3986 | p << ") = (" << lowerbound << " : " << lowerboundType << ") to (" |
3987 | << upperbound << " : " << upperboundType << ") " |
3988 | << " step (" << steps << " : " << stepType << ") " ; |
3989 | } |
3990 | p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false); |
3991 | } |
3992 | |
3993 | mlir::LogicalResult fir::CUDAKernelOp::verify() { |
3994 | if (getLowerbound().size() != getUpperbound().size() || |
3995 | getLowerbound().size() != getStep().size()) |
3996 | return emitOpError( |
3997 | "expect same number of values in lowerbound, upperbound and step" ); |
3998 | |
3999 | return mlir::success(); |
4000 | } |
4001 | |
4002 | mlir::LogicalResult fir::CUDAAllocateOp::verify() { |
4003 | if (getPinned() && getStream()) |
4004 | return emitOpError("pinned and stream cannot appears at the same time" ); |
4005 | if (!fir::unwrapRefType(getBox().getType()).isa<fir::BaseBoxType>()) |
4006 | return emitOpError( |
4007 | "expect box to be a reference to a class or box type value" ); |
4008 | if (getSource() && |
4009 | !fir::unwrapRefType(getSource().getType()).isa<fir::BaseBoxType>()) |
4010 | return emitOpError( |
4011 | "expect source to be a reference to/or a class or box type value" ); |
4012 | if (getErrmsg() && |
4013 | !fir::unwrapRefType(getErrmsg().getType()).isa<fir::BoxType>()) |
4014 | return emitOpError( |
4015 | "expect errmsg to be a reference to/or a box type value" ); |
4016 | if (getErrmsg() && !getHasStat()) |
4017 | return emitOpError("expect stat attribute when errmsg is provided" ); |
4018 | return mlir::success(); |
4019 | } |
4020 | |
4021 | mlir::LogicalResult fir::CUDADeallocateOp::verify() { |
4022 | if (!fir::unwrapRefType(getBox().getType()).isa<fir::BaseBoxType>()) |
4023 | return emitOpError( |
4024 | "expect box to be a reference to class or box type value" ); |
4025 | if (getErrmsg() && |
4026 | !fir::unwrapRefType(getErrmsg().getType()).isa<fir::BoxType>()) |
4027 | return emitOpError( |
4028 | "expect errmsg to be a reference to/or a box type value" ); |
4029 | if (getErrmsg() && !getHasStat()) |
4030 | return emitOpError("expect stat attribute when errmsg is provided" ); |
4031 | return mlir::success(); |
4032 | } |
4033 | |
4034 | //===----------------------------------------------------------------------===// |
4035 | // FIROpsDialect |
4036 | //===----------------------------------------------------------------------===// |
4037 | |
4038 | void fir::FIROpsDialect::registerOpExternalInterfaces() { |
4039 | // Attach default declare target interfaces to operations which can be marked |
4040 | // as declare target. |
4041 | fir::GlobalOp::attachInterface< |
4042 | mlir::omp::DeclareTargetDefaultModel<fir::GlobalOp>>(*getContext()); |
4043 | } |
4044 | |
4045 | // Tablegen operators |
4046 | |
4047 | #define GET_OP_CLASSES |
4048 | #include "flang/Optimizer/Dialect/FIROps.cpp.inc" |
4049 | |