1//===-- HLFIROps.cpp ------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/HLFIR/HLFIROps.h"
14
15#include "flang/Optimizer/Dialect/FIROpsSupport.h"
16#include "flang/Optimizer/Dialect/FIRType.h"
17#include "flang/Optimizer/Dialect/Support/FIRContext.h"
18#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "mlir/IR/Matchers.h"
24#include "mlir/IR/OpImplementation.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/CommandLine.h"
28#include <iterator>
29#include <mlir/Interfaces/SideEffectInterfaces.h>
30#include <optional>
31#include <tuple>
32
33static llvm::cl::opt<bool> useStrictIntrinsicVerifier(
34 "strict-intrinsic-verifier", llvm::cl::init(Val: false),
35 llvm::cl::desc("use stricter verifier for HLFIR intrinsic operations"));
36
37/// generic implementation of the memory side effects interface for hlfir
38/// transformational intrinsic operations
39static void
40getIntrinsicEffects(mlir::Operation *self,
41 llvm::SmallVectorImpl<mlir::SideEffects::EffectInstance<
42 mlir::MemoryEffects::Effect>> &effects) {
43 // allocation effect if we return an expr
44 assert(self->getNumResults() == 1 &&
45 "hlfir intrinsic ops only produce 1 result");
46 if (mlir::isa<hlfir::ExprType>(self->getResult(0).getType()))
47 effects.emplace_back(mlir::MemoryEffects::Allocate::get(),
48 self->getResult(0),
49 mlir::SideEffects::DefaultResource::get());
50
51 // read effect if we read from a pointer or refference type
52 // or a box who'se pointer is read from inside of the intrinsic so that
53 // loop conflicts can be detected in code like
54 // hlfir.region_assign {
55 // %2 = hlfir.transpose %0#0 : (!fir.box<!fir.array<?x?xf32>>) ->
56 // !hlfir.expr<?x?xf32> hlfir.yield %2 : !hlfir.expr<?x?xf32> cleanup {
57 // hlfir.destroy %2 : !hlfir.expr<?x?xf32>
58 // }
59 // } to {
60 // hlfir.yield %0#0 : !fir.box<!fir.array<?x?xf32>>
61 // }
62 for (mlir::Value operand : self->getOperands()) {
63 mlir::Type opTy = operand.getType();
64 if (fir::isa_ref_type(opTy) || fir::isa_box_type(opTy))
65 effects.emplace_back(mlir::MemoryEffects::Read::get(), operand,
66 mlir::SideEffects::DefaultResource::get());
67 }
68}
69
70//===----------------------------------------------------------------------===//
// AssignOp
72//===----------------------------------------------------------------------===//
73
74/// Is this a fir.[ref/ptr/heap]<fir.[box/class]<fir.heap<T>>> type?
75static bool isAllocatableBoxRef(mlir::Type type) {
76 fir::BaseBoxType boxType =
77 fir::dyn_cast_ptrEleTy(type).dyn_cast_or_null<fir::BaseBoxType>();
78 return boxType && boxType.getEleTy().isa<fir::HeapType>();
79}
80
81mlir::LogicalResult hlfir::AssignOp::verify() {
82 mlir::Type lhsType = getLhs().getType();
83 if (isAllocatableAssignment() && !isAllocatableBoxRef(lhsType))
84 return emitOpError("lhs must be an allocatable when `realloc` is set");
85 if (mustKeepLhsLengthInAllocatableAssignment() &&
86 !(isAllocatableAssignment() &&
87 hlfir::getFortranElementType(lhsType).isa<fir::CharacterType>()))
88 return emitOpError("`realloc` must be set and lhs must be a character "
89 "allocatable when `keep_lhs_length_if_realloc` is set");
90 return mlir::success();
91}
92
93//===----------------------------------------------------------------------===//
94// DeclareOp
95//===----------------------------------------------------------------------===//
96
97/// Given a FIR memory type, and information about non default lower bounds, get
98/// the related HLFIR variable type.
99mlir::Type hlfir::DeclareOp::getHLFIRVariableType(mlir::Type inputType,
100 bool hasExplicitLowerBounds) {
101 mlir::Type type = fir::unwrapRefType(inputType);
102 if (type.isa<fir::BaseBoxType>())
103 return inputType;
104 if (auto charType = type.dyn_cast<fir::CharacterType>())
105 if (charType.hasDynamicLen())
106 return fir::BoxCharType::get(charType.getContext(), charType.getFKind());
107
108 auto seqType = type.dyn_cast<fir::SequenceType>();
109 bool hasDynamicExtents =
110 seqType && fir::sequenceWithNonConstantShape(seqType);
111 mlir::Type eleType = seqType ? seqType.getEleTy() : type;
112 bool hasDynamicLengthParams = fir::characterWithDynamicLen(eleType) ||
113 fir::isRecordWithTypeParameters(eleType);
114 if (hasExplicitLowerBounds || hasDynamicExtents || hasDynamicLengthParams)
115 return fir::BoxType::get(type);
116 return inputType;
117}
118
119static bool hasExplicitLowerBounds(mlir::Value shape) {
120 return shape && shape.getType().isa<fir::ShapeShiftType, fir::ShiftType>();
121}
122
123void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
124 mlir::OperationState &result, mlir::Value memref,
125 llvm::StringRef uniq_name, mlir::Value shape,
126 mlir::ValueRange typeparams,
127 fir::FortranVariableFlagsAttr fortran_attrs,
128 fir::CUDADataAttributeAttr cuda_attr) {
129 auto nameAttr = builder.getStringAttr(uniq_name);
130 mlir::Type inputType = memref.getType();
131 bool hasExplicitLbs = hasExplicitLowerBounds(shape);
132 mlir::Type hlfirVariableType =
133 getHLFIRVariableType(inputType, hasExplicitLbs);
134 build(builder, result, {hlfirVariableType, inputType}, memref, shape,
135 typeparams, nameAttr, fortran_attrs, cuda_attr);
136}
137
138mlir::LogicalResult hlfir::DeclareOp::verify() {
139 if (getMemref().getType() != getResult(1).getType())
140 return emitOpError("second result type must match input memref type");
141 mlir::Type hlfirVariableType = getHLFIRVariableType(
142 getMemref().getType(), hasExplicitLowerBounds(getShape()));
143 if (hlfirVariableType != getResult(0).getType())
144 return emitOpError("first result type is inconsistent with variable "
145 "properties: expected ")
146 << hlfirVariableType;
147 // The rest of the argument verification is done by the
148 // FortranVariableInterface verifier.
149 auto fortranVar =
150 mlir::cast<fir::FortranVariableOpInterface>(this->getOperation());
151 return fortranVar.verifyDeclareLikeOpImpl(getMemref());
152}
153
154//===----------------------------------------------------------------------===//
155// DesignateOp
156//===----------------------------------------------------------------------===//
157
158void hlfir::DesignateOp::build(
159 mlir::OpBuilder &builder, mlir::OperationState &result,
160 mlir::Type result_type, mlir::Value memref, llvm::StringRef component,
161 mlir::Value component_shape, llvm::ArrayRef<Subscript> subscripts,
162 mlir::ValueRange substring, std::optional<bool> complex_part,
163 mlir::Value shape, mlir::ValueRange typeparams,
164 fir::FortranVariableFlagsAttr fortran_attrs) {
165 auto componentAttr =
166 component.empty() ? mlir::StringAttr{} : builder.getStringAttr(component);
167 llvm::SmallVector<mlir::Value> indices;
168 llvm::SmallVector<bool> isTriplet;
169 for (auto subscript : subscripts) {
170 if (auto *triplet = std::get_if<Triplet>(&subscript)) {
171 isTriplet.push_back(true);
172 indices.push_back(std::get<0>(*triplet));
173 indices.push_back(std::get<1>(*triplet));
174 indices.push_back(std::get<2>(*triplet));
175 } else {
176 isTriplet.push_back(false);
177 indices.push_back(std::get<mlir::Value>(subscript));
178 }
179 }
180 auto isTripletAttr =
181 mlir::DenseBoolArrayAttr::get(builder.getContext(), isTriplet);
182 auto complexPartAttr =
183 complex_part.has_value()
184 ? mlir::BoolAttr::get(builder.getContext(), *complex_part)
185 : mlir::BoolAttr{};
186 build(builder, result, result_type, memref, componentAttr, component_shape,
187 indices, isTripletAttr, substring, complexPartAttr, shape, typeparams,
188 fortran_attrs);
189}
190
191void hlfir::DesignateOp::build(mlir::OpBuilder &builder,
192 mlir::OperationState &result,
193 mlir::Type result_type, mlir::Value memref,
194 mlir::ValueRange indices,
195 mlir::ValueRange typeparams,
196 fir::FortranVariableFlagsAttr fortran_attrs) {
197 llvm::SmallVector<bool> isTriplet(indices.size(), false);
198 auto isTripletAttr =
199 mlir::DenseBoolArrayAttr::get(builder.getContext(), isTriplet);
200 build(builder, result, result_type, memref,
201 /*componentAttr=*/mlir::StringAttr{}, /*component_shape=*/mlir::Value{},
202 indices, isTripletAttr, /*substring*/ mlir::ValueRange{},
203 /*complexPartAttr=*/mlir::BoolAttr{}, /*shape=*/mlir::Value{},
204 typeparams, fortran_attrs);
205}
206
207static mlir::ParseResult parseDesignatorIndices(
208 mlir::OpAsmParser &parser,
209 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &indices,
210 mlir::DenseBoolArrayAttr &isTripletAttr) {
211 llvm::SmallVector<bool> isTriplet;
212 if (mlir::succeeded(parser.parseOptionalLParen())) {
213 do {
214 mlir::OpAsmParser::UnresolvedOperand i1, i2, i3;
215 if (parser.parseOperand(i1))
216 return mlir::failure();
217 indices.push_back(i1);
218 if (mlir::succeeded(parser.parseOptionalColon())) {
219 if (parser.parseOperand(i2) || parser.parseColon() ||
220 parser.parseOperand(i3))
221 return mlir::failure();
222 indices.push_back(i2);
223 indices.push_back(i3);
224 isTriplet.push_back(Elt: true);
225 } else {
226 isTriplet.push_back(Elt: false);
227 }
228 } while (mlir::succeeded(parser.parseOptionalComma()));
229 if (parser.parseRParen())
230 return mlir::failure();
231 }
232 isTripletAttr = mlir::DenseBoolArrayAttr::get(parser.getContext(), isTriplet);
233 return mlir::success();
234}
235
236static void
237printDesignatorIndices(mlir::OpAsmPrinter &p, hlfir::DesignateOp designateOp,
238 mlir::OperandRange indices,
239 const mlir::DenseBoolArrayAttr &isTripletAttr) {
240 if (!indices.empty()) {
241 p << '(';
242 unsigned i = 0;
243 for (auto isTriplet : isTripletAttr.asArrayRef()) {
244 if (isTriplet) {
245 assert(i + 2 < indices.size() && "ill-formed indices");
246 p << indices[i] << ":" << indices[i + 1] << ":" << indices[i + 2];
247 i += 3;
248 } else {
249 p << indices[i++];
250 }
251 if (i != indices.size())
252 p << ", ";
253 }
254 p << ')';
255 }
256}
257
258static mlir::ParseResult
259parseDesignatorComplexPart(mlir::OpAsmParser &parser,
260 mlir::BoolAttr &complexPart) {
261 if (mlir::succeeded(parser.parseOptionalKeyword("imag")))
262 complexPart = mlir::BoolAttr::get(parser.getContext(), true);
263 else if (mlir::succeeded(parser.parseOptionalKeyword("real")))
264 complexPart = mlir::BoolAttr::get(parser.getContext(), false);
265 return mlir::success();
266}
267
268static void printDesignatorComplexPart(mlir::OpAsmPrinter &p,
269 hlfir::DesignateOp designateOp,
270 mlir::BoolAttr complexPartAttr) {
271 if (complexPartAttr) {
272 if (complexPartAttr.getValue())
273 p << "imag";
274 else
275 p << "real";
276 }
277}
278
279mlir::LogicalResult hlfir::DesignateOp::verify() {
280 mlir::Type memrefType = getMemref().getType();
281 mlir::Type baseType = getFortranElementOrSequenceType(memrefType);
282 mlir::Type baseElementType = fir::unwrapSequenceType(baseType);
283 unsigned numSubscripts = getIsTriplet().size();
284 unsigned subscriptsRank =
285 llvm::count_if(getIsTriplet(), [](bool isTriplet) { return isTriplet; });
286 unsigned outputRank = 0;
287 mlir::Type outputElementType;
288 bool hasBoxComponent;
289 if (getComponent()) {
290 auto component = getComponent().value();
291 auto recType = baseElementType.dyn_cast<fir::RecordType>();
292 if (!recType)
293 return emitOpError(
294 "component must be provided only when the memref is a derived type");
295 unsigned fieldIdx = recType.getFieldIndex(component);
296 if (fieldIdx > recType.getNumFields()) {
297 return emitOpError("component ")
298 << component << " is not a component of memref element type "
299 << recType;
300 }
301 mlir::Type fieldType = recType.getType(fieldIdx);
302 mlir::Type componentBaseType = getFortranElementOrSequenceType(fieldType);
303 hasBoxComponent = fieldType.isa<fir::BaseBoxType>();
304 if (componentBaseType.isa<fir::SequenceType>() &&
305 baseType.isa<fir::SequenceType>() &&
306 (numSubscripts == 0 || subscriptsRank > 0))
307 return emitOpError("indices must be provided and must not contain "
308 "triplets when both memref and component are arrays");
309 if (numSubscripts != 0) {
310 if (!componentBaseType.isa<fir::SequenceType>())
311 return emitOpError("indices must not be provided if component appears "
312 "and is not an array component");
313 if (!getComponentShape())
314 return emitOpError(
315 "component_shape must be provided when indexing a component");
316 mlir::Type compShapeType = getComponentShape().getType();
317 unsigned componentRank =
318 componentBaseType.cast<fir::SequenceType>().getDimension();
319 auto shapeType = compShapeType.dyn_cast<fir::ShapeType>();
320 auto shapeShiftType = compShapeType.dyn_cast<fir::ShapeShiftType>();
321 if (!((shapeType && shapeType.getRank() == componentRank) ||
322 (shapeShiftType && shapeShiftType.getRank() == componentRank)))
323 return emitOpError("component_shape must be a fir.shape or "
324 "fir.shapeshift with the rank of the component");
325 if (numSubscripts > componentRank)
326 return emitOpError("indices number must match array component rank");
327 }
328 if (auto baseSeqType = baseType.dyn_cast<fir::SequenceType>())
329 // This case must come first to cover "array%array_comp(i, j)" that has
330 // subscripts for the component but whose rank come from the base.
331 outputRank = baseSeqType.getDimension();
332 else if (numSubscripts != 0)
333 outputRank = subscriptsRank;
334 else if (auto componentSeqType =
335 componentBaseType.dyn_cast<fir::SequenceType>())
336 outputRank = componentSeqType.getDimension();
337 outputElementType = fir::unwrapSequenceType(componentBaseType);
338 } else {
339 outputElementType = baseElementType;
340 unsigned baseTypeRank =
341 baseType.isa<fir::SequenceType>()
342 ? baseType.cast<fir::SequenceType>().getDimension()
343 : 0;
344 if (numSubscripts != 0) {
345 if (baseTypeRank != numSubscripts)
346 return emitOpError("indices number must match memref rank");
347 outputRank = subscriptsRank;
348 } else if (auto baseSeqType = baseType.dyn_cast<fir::SequenceType>()) {
349 outputRank = baseSeqType.getDimension();
350 }
351 }
352
353 if (!getSubstring().empty()) {
354 if (!outputElementType.isa<fir::CharacterType>())
355 return emitOpError("memref or component must have character type if "
356 "substring indices are provided");
357 if (getSubstring().size() != 2)
358 return emitOpError("substring must contain 2 indices when provided");
359 }
360 if (getComplexPart()) {
361 if (!fir::isa_complex(outputElementType))
362 return emitOpError("memref or component must have complex type if "
363 "complex_part is provided");
364 if (auto firCplx = outputElementType.dyn_cast<fir::ComplexType>())
365 outputElementType = firCplx.getElementType();
366 else
367 outputElementType =
368 outputElementType.cast<mlir::ComplexType>().getElementType();
369 }
370 mlir::Type resultBaseType =
371 getFortranElementOrSequenceType(getResult().getType());
372 unsigned resultRank = 0;
373 if (auto resultSeqType = resultBaseType.dyn_cast<fir::SequenceType>())
374 resultRank = resultSeqType.getDimension();
375 if (resultRank != outputRank)
376 return emitOpError("result type rank is not consistent with operands, "
377 "expected rank ")
378 << outputRank;
379 mlir::Type resultElementType = fir::unwrapSequenceType(resultBaseType);
380 // result type must match the one that was inferred here, except the character
381 // length may differ because of substrings.
382 if (resultElementType != outputElementType &&
383 !(resultElementType.isa<fir::CharacterType>() &&
384 outputElementType.isa<fir::CharacterType>()) &&
385 !(resultElementType.isa<mlir::FloatType>() &&
386 outputElementType.isa<fir::RealType>()))
387 return emitOpError(
388 "result element type is not consistent with operands, expected ")
389 << outputElementType;
390
391 if (isBoxAddressType(getResult().getType())) {
392 if (!hasBoxComponent || numSubscripts != 0 || !getSubstring().empty() ||
393 getComplexPart())
394 return emitOpError(
395 "result type must only be a box address type if it designates a "
396 "component that is a fir.box or fir.class and if there are no "
397 "indices, substrings, and complex part");
398
399 } else {
400 if ((resultRank == 0) != !getShape())
401 return emitOpError("shape must be provided if and only if the result is "
402 "an array that is not a box address");
403 if (resultRank != 0) {
404 auto shapeType = getShape().getType().dyn_cast<fir::ShapeType>();
405 auto shapeShiftType =
406 getShape().getType().dyn_cast<fir::ShapeShiftType>();
407 if (!((shapeType && shapeType.getRank() == resultRank) ||
408 (shapeShiftType && shapeShiftType.getRank() == resultRank)))
409 return emitOpError("shape must be a fir.shape or fir.shapeshift with "
410 "the rank of the result");
411 }
412 auto numLenParam = getTypeparams().size();
413 if (outputElementType.isa<fir::CharacterType>()) {
414 if (numLenParam != 1)
415 return emitOpError("must be provided one length parameter when the "
416 "result is a character");
417 } else if (fir::isRecordWithTypeParameters(outputElementType)) {
418 if (numLenParam !=
419 outputElementType.cast<fir::RecordType>().getNumLenParams())
420 return emitOpError("must be provided the same number of length "
421 "parameters as in the result derived type");
422 } else if (numLenParam != 0) {
423 return emitOpError("must not be provided length parameters if the result "
424 "type does not have length parameters");
425 }
426 }
427 return mlir::success();
428}
429
430//===----------------------------------------------------------------------===//
431// ParentComponentOp
432//===----------------------------------------------------------------------===//
433
434mlir::LogicalResult hlfir::ParentComponentOp::verify() {
435 mlir::Type baseType =
436 hlfir::getFortranElementOrSequenceType(getMemref().getType());
437 auto maybeInputSeqType = baseType.dyn_cast<fir::SequenceType>();
438 unsigned inputTypeRank =
439 maybeInputSeqType ? maybeInputSeqType.getDimension() : 0;
440 unsigned shapeRank = 0;
441 if (mlir::Value shape = getShape())
442 if (auto shapeType = shape.getType().dyn_cast<fir::ShapeType>())
443 shapeRank = shapeType.getRank();
444 if (inputTypeRank != shapeRank)
445 return emitOpError(
446 "must be provided a shape if and only if the base is an array");
447 mlir::Type outputBaseType = hlfir::getFortranElementOrSequenceType(getType());
448 auto maybeOutputSeqType = outputBaseType.dyn_cast<fir::SequenceType>();
449 unsigned outputTypeRank =
450 maybeOutputSeqType ? maybeOutputSeqType.getDimension() : 0;
451 if (inputTypeRank != outputTypeRank)
452 return emitOpError("result type rank must match input type rank");
453 if (maybeOutputSeqType && maybeInputSeqType)
454 for (auto [inputDim, outputDim] :
455 llvm::zip(maybeInputSeqType.getShape(), maybeOutputSeqType.getShape()))
456 if (inputDim != fir::SequenceType::getUnknownExtent() &&
457 outputDim != fir::SequenceType::getUnknownExtent())
458 if (inputDim != outputDim)
459 return emitOpError(
460 "result type extents are inconsistent with memref type");
461 fir::RecordType baseRecType =
462 hlfir::getFortranElementType(baseType).dyn_cast<fir::RecordType>();
463 fir::RecordType outRecType =
464 hlfir::getFortranElementType(outputBaseType).dyn_cast<fir::RecordType>();
465 if (!baseRecType || !outRecType)
466 return emitOpError("result type and input type must be derived types");
467
468 // Note: result should not be a fir.class: its dynamic type is being set to
469 // the parent type and allowing fir.class would break the operation codegen:
470 // it would keep the input dynamic type.
471 if (getType().isa<fir::ClassType>())
472 return emitOpError("result type must not be polymorphic");
473
474 // The array results are known to not be dis-contiguous in most cases (the
475 // exception being if the parent type was extended by a type without any
476 // components): require a fir.box to be used for the result to carry the
477 // strides.
478 if (!getType().isa<fir::BoxType>() &&
479 (outputTypeRank != 0 || fir::isRecordWithTypeParameters(outRecType)))
480 return emitOpError("result type must be a fir.box if the result is an "
481 "array or has length parameters");
482 return mlir::success();
483}
484
485//===----------------------------------------------------------------------===//
486// LogicalReductionOp
487//===----------------------------------------------------------------------===//
488template <typename LogicalReductionOp>
489static mlir::LogicalResult
490verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
491 mlir::Operation *op = reductionOp->getOperation();
492
493 auto results = op->getResultTypes();
494 assert(results.size() == 1);
495
496 mlir::Value mask = reductionOp->getMask();
497 mlir::Value dim = reductionOp->getDim();
498
499 fir::SequenceType maskTy =
500 hlfir::getFortranElementOrSequenceType(mask.getType())
501 .cast<fir::SequenceType>();
502 mlir::Type logicalTy = maskTy.getEleTy();
503 llvm::ArrayRef<int64_t> maskShape = maskTy.getShape();
504
505 mlir::Type resultType = results[0];
506 if (mlir::isa<fir::LogicalType>(resultType)) {
507 // Result is of the same type as MASK
508 if ((resultType != logicalTy) && useStrictIntrinsicVerifier)
509 return reductionOp->emitOpError(
510 "result must have the same element type as MASK argument");
511
512 } else if (auto resultExpr =
513 mlir::dyn_cast_or_null<hlfir::ExprType>(resultType)) {
514 // Result should only be in hlfir.expr form if it is an array
515 if (maskShape.size() > 1 && dim != nullptr) {
516 if (!resultExpr.isArray())
517 return reductionOp->emitOpError("result must be an array");
518
519 if ((resultExpr.getEleTy() != logicalTy) && useStrictIntrinsicVerifier)
520 return reductionOp->emitOpError(
521 "result must have the same element type as MASK argument");
522
523 llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
524 // Result has rank n-1
525 if (resultShape.size() != (maskShape.size() - 1))
526 return reductionOp->emitOpError(
527 "result rank must be one less than MASK");
528 } else {
529 return reductionOp->emitOpError("result must be of logical type");
530 }
531 } else {
532 return reductionOp->emitOpError("result must be of logical type");
533 }
534 return mlir::success();
535}
536
537//===----------------------------------------------------------------------===//
538// AllOp
539//===----------------------------------------------------------------------===//
540
541mlir::LogicalResult hlfir::AllOp::verify() {
542 return verifyLogicalReductionOp<hlfir::AllOp *>(this);
543}
544
545void hlfir::AllOp::getEffects(
546 llvm::SmallVectorImpl<
547 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
548 &effects) {
549 getIntrinsicEffects(getOperation(), effects);
550}
551
552//===----------------------------------------------------------------------===//
553// AnyOp
554//===----------------------------------------------------------------------===//
555
556mlir::LogicalResult hlfir::AnyOp::verify() {
557 return verifyLogicalReductionOp<hlfir::AnyOp *>(this);
558}
559
560void hlfir::AnyOp::getEffects(
561 llvm::SmallVectorImpl<
562 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
563 &effects) {
564 getIntrinsicEffects(getOperation(), effects);
565}
566
567//===----------------------------------------------------------------------===//
568// CountOp
569//===----------------------------------------------------------------------===//
570
571mlir::LogicalResult hlfir::CountOp::verify() {
572 mlir::Operation *op = getOperation();
573
574 auto results = op->getResultTypes();
575 assert(results.size() == 1);
576 mlir::Value mask = getMask();
577 mlir::Value dim = getDim();
578
579 fir::SequenceType maskTy =
580 hlfir::getFortranElementOrSequenceType(mask.getType())
581 .cast<fir::SequenceType>();
582 llvm::ArrayRef<int64_t> maskShape = maskTy.getShape();
583
584 mlir::Type resultType = results[0];
585 if (auto resultExpr = mlir::dyn_cast_or_null<hlfir::ExprType>(resultType)) {
586 if (maskShape.size() > 1 && dim != nullptr) {
587 if (!resultExpr.isArray())
588 return emitOpError("result must be an array");
589
590 llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
591 // Result has rank n-1
592 if (resultShape.size() != (maskShape.size() - 1))
593 return emitOpError("result rank must be one less than MASK");
594 } else {
595 return emitOpError("result must be of numerical array type");
596 }
597 } else if (!hlfir::isFortranScalarNumericalType(resultType)) {
598 return emitOpError("result must be of numerical scalar type");
599 }
600
601 return mlir::success();
602}
603
604void hlfir::CountOp::getEffects(
605 llvm::SmallVectorImpl<
606 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
607 &effects) {
608 getIntrinsicEffects(getOperation(), effects);
609}
610
611//===----------------------------------------------------------------------===//
612// ConcatOp
613//===----------------------------------------------------------------------===//
614
615static unsigned getCharacterKind(mlir::Type t) {
616 return hlfir::getFortranElementType(t).cast<fir::CharacterType>().getFKind();
617}
618
619static std::optional<fir::CharacterType::LenType>
620getCharacterLengthIfStatic(mlir::Type t) {
621 if (auto charType =
622 hlfir::getFortranElementType(t).dyn_cast<fir::CharacterType>())
623 if (charType.hasConstantLen())
624 return charType.getLen();
625 return std::nullopt;
626}
627
628mlir::LogicalResult hlfir::ConcatOp::verify() {
629 if (getStrings().size() < 2)
630 return emitOpError("must be provided at least two string operands");
631 unsigned kind = getCharacterKind(getResult().getType());
632 for (auto string : getStrings())
633 if (kind != getCharacterKind(string.getType()))
634 return emitOpError("strings must have the same KIND as the result type");
635 return mlir::success();
636}
637
638void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
639 mlir::OperationState &result,
640 mlir::ValueRange strings, mlir::Value len) {
641 fir::CharacterType::LenType resultTypeLen = 0;
642 assert(!strings.empty() && "must contain operands");
643 unsigned kind = getCharacterKind(strings[0].getType());
644 for (auto string : strings)
645 if (auto cstLen = getCharacterLengthIfStatic(string.getType())) {
646 resultTypeLen += *cstLen;
647 } else {
648 resultTypeLen = fir::CharacterType::unknownLen();
649 break;
650 }
651 auto resultType = hlfir::ExprType::get(
652 builder.getContext(), hlfir::ExprType::Shape{},
653 fir::CharacterType::get(builder.getContext(), kind, resultTypeLen),
654 false);
655 build(builder, result, resultType, strings, len);
656}
657
658void hlfir::ConcatOp::getEffects(
659 llvm::SmallVectorImpl<
660 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
661 &effects) {
662 getIntrinsicEffects(getOperation(), effects);
663}
664
665//===----------------------------------------------------------------------===//
666// NumericalReductionOp
667//===----------------------------------------------------------------------===//
668
669template <typename NumericalReductionOp>
670static mlir::LogicalResult
671verifyArrayAndMaskForReductionOp(NumericalReductionOp reductionOp) {
672 mlir::Value array = reductionOp->getArray();
673 mlir::Value mask = reductionOp->getMask();
674
675 fir::SequenceType arrayTy =
676 hlfir::getFortranElementOrSequenceType(array.getType())
677 .cast<fir::SequenceType>();
678 llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
679
680 if (mask) {
681 fir::SequenceType maskSeq =
682 hlfir::getFortranElementOrSequenceType(mask.getType())
683 .dyn_cast<fir::SequenceType>();
684 llvm::ArrayRef<int64_t> maskShape;
685
686 if (maskSeq)
687 maskShape = maskSeq.getShape();
688
689 if (!maskShape.empty()) {
690 if (maskShape.size() != arrayShape.size())
691 return reductionOp->emitWarning("MASK must be conformable to ARRAY");
692 if (useStrictIntrinsicVerifier) {
693 static_assert(fir::SequenceType::getUnknownExtent() ==
694 hlfir::ExprType::getUnknownExtent());
695 constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
696 for (std::size_t i = 0; i < arrayShape.size(); ++i) {
697 int64_t arrayExtent = arrayShape[i];
698 int64_t maskExtent = maskShape[i];
699 if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
700 (maskExtent != unknownExtent))
701 return reductionOp->emitWarning(
702 "MASK must be conformable to ARRAY");
703 }
704 }
705 }
706 }
707 return mlir::success();
708}
709
710template <typename NumericalReductionOp>
711static mlir::LogicalResult
712verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
713 mlir::Operation *op = reductionOp->getOperation();
714 auto results = op->getResultTypes();
715 assert(results.size() == 1);
716
717 auto res = verifyArrayAndMaskForReductionOp(reductionOp);
718 if (failed(res))
719 return res;
720
721 mlir::Value array = reductionOp->getArray();
722 mlir::Value dim = reductionOp->getDim();
723 fir::SequenceType arrayTy =
724 hlfir::getFortranElementOrSequenceType(array.getType())
725 .cast<fir::SequenceType>();
726 mlir::Type numTy = arrayTy.getEleTy();
727 llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
728
729 mlir::Type resultType = results[0];
730 if (hlfir::isFortranScalarNumericalType(resultType)) {
731 // Result is of the same type as ARRAY
732 if ((resultType != numTy) && useStrictIntrinsicVerifier)
733 return reductionOp->emitOpError(
734 "result must have the same element type as ARRAY argument");
735
736 } else if (auto resultExpr =
737 mlir::dyn_cast_or_null<hlfir::ExprType>(resultType)) {
738 if (arrayShape.size() > 1 && dim != nullptr) {
739 if (!resultExpr.isArray())
740 return reductionOp->emitOpError("result must be an array");
741
742 if ((resultExpr.getEleTy() != numTy) && useStrictIntrinsicVerifier)
743 return reductionOp->emitOpError(
744 "result must have the same element type as ARRAY argument");
745
746 llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
747 // Result has rank n-1
748 if (resultShape.size() != (arrayShape.size() - 1))
749 return reductionOp->emitOpError(
750 "result rank must be one less than ARRAY");
751 } else {
752 return reductionOp->emitOpError(
753 "result must be of numerical scalar type");
754 }
755 } else {
756 return reductionOp->emitOpError("result must be of numerical scalar type");
757 }
758 return mlir::success();
759}
760
761//===----------------------------------------------------------------------===//
762// ProductOp
763//===----------------------------------------------------------------------===//
764
765mlir::LogicalResult hlfir::ProductOp::verify() {
766 return verifyNumericalReductionOp<hlfir::ProductOp *>(this);
767}
768
769void hlfir::ProductOp::getEffects(
770 llvm::SmallVectorImpl<
771 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
772 &effects) {
773 getIntrinsicEffects(getOperation(), effects);
774}
775
776//===----------------------------------------------------------------------===//
777// CharacterReductionOp
778//===----------------------------------------------------------------------===//
779
780template <typename CharacterReductionOp>
781static mlir::LogicalResult
782verifyCharacterReductionOp(CharacterReductionOp reductionOp) {
783 mlir::Operation *op = reductionOp->getOperation();
784 auto results = op->getResultTypes();
785 assert(results.size() == 1);
786
787 auto res = verifyArrayAndMaskForReductionOp(reductionOp);
788 if (failed(res))
789 return res;
790
791 mlir::Value array = reductionOp->getArray();
792 mlir::Value dim = reductionOp->getDim();
793 fir::SequenceType arrayTy =
794 hlfir::getFortranElementOrSequenceType(array.getType())
795 .cast<fir::SequenceType>();
796 mlir::Type numTy = arrayTy.getEleTy();
797 llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
798
799 auto resultExpr = results[0].cast<hlfir::ExprType>();
800 mlir::Type resultType = resultExpr.getEleTy();
801 assert(mlir::isa<fir::CharacterType>(resultType) &&
802 "result must be character");
803
804 // Result is of the same type as ARRAY
805 if ((resultType != numTy) && useStrictIntrinsicVerifier)
806 return reductionOp->emitOpError(
807 "result must have the same element type as ARRAY argument");
808
809 if (arrayShape.size() > 1 && dim != nullptr) {
810 if (!resultExpr.isArray())
811 return reductionOp->emitOpError("result must be an array");
812 llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
813 // Result has rank n-1
814 if (resultShape.size() != (arrayShape.size() - 1))
815 return reductionOp->emitOpError(
816 "result rank must be one less than ARRAY");
817 } else if (!resultExpr.isScalar()) {
818 return reductionOp->emitOpError("result must be scalar character");
819 }
820 return mlir::success();
821}
822
823//===----------------------------------------------------------------------===//
824// MaxvalOp
825//===----------------------------------------------------------------------===//
826
827mlir::LogicalResult hlfir::MaxvalOp::verify() {
828 mlir::Operation *op = getOperation();
829
830 auto results = op->getResultTypes();
831 assert(results.size() == 1);
832
833 auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
834 if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
835 return verifyCharacterReductionOp<hlfir::MaxvalOp *>(this);
836 }
837 return verifyNumericalReductionOp<hlfir::MaxvalOp *>(this);
838}
839
840void hlfir::MaxvalOp::getEffects(
841 llvm::SmallVectorImpl<
842 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
843 &effects) {
844 getIntrinsicEffects(getOperation(), effects);
845}
846
847//===----------------------------------------------------------------------===//
848// MinvalOp
849//===----------------------------------------------------------------------===//
850
851mlir::LogicalResult hlfir::MinvalOp::verify() {
852 mlir::Operation *op = getOperation();
853
854 auto results = op->getResultTypes();
855 assert(results.size() == 1);
856
857 auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
858 if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
859 return verifyCharacterReductionOp<hlfir::MinvalOp *>(this);
860 }
861 return verifyNumericalReductionOp<hlfir::MinvalOp *>(this);
862}
863
864void hlfir::MinvalOp::getEffects(
865 llvm::SmallVectorImpl<
866 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
867 &effects) {
868 getIntrinsicEffects(getOperation(), effects);
869}
870
871//===----------------------------------------------------------------------===//
872// MinlocOp
873//===----------------------------------------------------------------------===//
874
875template <typename NumericalReductionOp>
876static mlir::LogicalResult
877verifyResultForMinMaxLoc(NumericalReductionOp reductionOp) {
878 mlir::Operation *op = reductionOp->getOperation();
879 auto results = op->getResultTypes();
880 assert(results.size() == 1);
881
882 mlir::Value array = reductionOp->getArray();
883 mlir::Value dim = reductionOp->getDim();
884 fir::SequenceType arrayTy =
885 hlfir::getFortranElementOrSequenceType(array.getType())
886 .cast<fir::SequenceType>();
887 llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
888
889 mlir::Type resultType = results[0];
890 if (dim && arrayShape.size() == 1) {
891 if (!fir::isa_integer(resultType))
892 return reductionOp->emitOpError("result must be scalar integer");
893 } else if (auto resultExpr =
894 mlir::dyn_cast_or_null<hlfir::ExprType>(resultType)) {
895 if (!resultExpr.isArray())
896 return reductionOp->emitOpError("result must be an array");
897
898 if (!fir::isa_integer(resultExpr.getEleTy()))
899 return reductionOp->emitOpError("result must have integer elements");
900
901 llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
902 // With dim the result has rank n-1
903 if (dim && resultShape.size() != (arrayShape.size() - 1))
904 return reductionOp->emitOpError(
905 "result rank must be one less than ARRAY");
906 // With dim the result has rank n
907 if (!dim && resultShape.size() != 1)
908 return reductionOp->emitOpError("result rank must be 1");
909 } else {
910 return reductionOp->emitOpError("result must be of numerical expr type");
911 }
912 return mlir::success();
913}
914
915mlir::LogicalResult hlfir::MinlocOp::verify() {
916 auto res = verifyArrayAndMaskForReductionOp(this);
917 if (failed(res))
918 return res;
919
920 return verifyResultForMinMaxLoc(this);
921}
922
923void hlfir::MinlocOp::getEffects(
924 llvm::SmallVectorImpl<
925 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
926 &effects) {
927 getIntrinsicEffects(getOperation(), effects);
928}
929
930//===----------------------------------------------------------------------===//
931// MaxlocOp
932//===----------------------------------------------------------------------===//
933
934mlir::LogicalResult hlfir::MaxlocOp::verify() {
935 auto res = verifyArrayAndMaskForReductionOp(this);
936 if (failed(res))
937 return res;
938
939 return verifyResultForMinMaxLoc(this);
940}
941
942void hlfir::MaxlocOp::getEffects(
943 llvm::SmallVectorImpl<
944 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
945 &effects) {
946 getIntrinsicEffects(getOperation(), effects);
947}
948
949//===----------------------------------------------------------------------===//
950// SetLengthOp
951//===----------------------------------------------------------------------===//
952
953void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
954 mlir::OperationState &result, mlir::Value string,
955 mlir::Value len) {
956 fir::CharacterType::LenType resultTypeLen = fir::CharacterType::unknownLen();
957 if (auto cstLen = fir::getIntIfConstant(len))
958 resultTypeLen = *cstLen;
959 unsigned kind = getCharacterKind(string.getType());
960 auto resultType = hlfir::ExprType::get(
961 builder.getContext(), hlfir::ExprType::Shape{},
962 fir::CharacterType::get(builder.getContext(), kind, resultTypeLen),
963 false);
964 build(builder, result, resultType, string, len);
965}
966
967void hlfir::SetLengthOp::getEffects(
968 llvm::SmallVectorImpl<
969 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
970 &effects) {
971 getIntrinsicEffects(getOperation(), effects);
972}
973
974//===----------------------------------------------------------------------===//
975// SumOp
976//===----------------------------------------------------------------------===//
977
978mlir::LogicalResult hlfir::SumOp::verify() {
979 return verifyNumericalReductionOp<hlfir::SumOp *>(this);
980}
981
982void hlfir::SumOp::getEffects(
983 llvm::SmallVectorImpl<
984 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
985 &effects) {
986 getIntrinsicEffects(getOperation(), effects);
987}
988
989//===----------------------------------------------------------------------===//
990// DotProductOp
991//===----------------------------------------------------------------------===//
992
993mlir::LogicalResult hlfir::DotProductOp::verify() {
994 mlir::Value lhs = getLhs();
995 mlir::Value rhs = getRhs();
996 fir::SequenceType lhsTy =
997 hlfir::getFortranElementOrSequenceType(lhs.getType())
998 .cast<fir::SequenceType>();
999 fir::SequenceType rhsTy =
1000 hlfir::getFortranElementOrSequenceType(rhs.getType())
1001 .cast<fir::SequenceType>();
1002 llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
1003 llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
1004 std::size_t lhsRank = lhsShape.size();
1005 std::size_t rhsRank = rhsShape.size();
1006 mlir::Type lhsEleTy = lhsTy.getEleTy();
1007 mlir::Type rhsEleTy = rhsTy.getEleTy();
1008 mlir::Type resultTy = getResult().getType();
1009
1010 if ((lhsRank != 1) || (rhsRank != 1))
1011 return emitOpError("both arrays must have rank 1");
1012
1013 int64_t lhsSize = lhsShape[0];
1014 int64_t rhsSize = rhsShape[0];
1015
1016 constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
1017 if ((lhsSize != unknownExtent) && (rhsSize != unknownExtent) &&
1018 (lhsSize != rhsSize) && useStrictIntrinsicVerifier)
1019 return emitOpError("both arrays must have the same size");
1020
1021 if (useStrictIntrinsicVerifier) {
1022 if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
1023 mlir::isa<fir::LogicalType>(rhsEleTy))
1024 return emitOpError("if one array is logical, so should the other be");
1025
1026 if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
1027 mlir::isa<fir::LogicalType>(resultTy))
1028 return emitOpError("the result type should be a logical only if the "
1029 "argument types are logical");
1030 }
1031
1032 if (!hlfir::isFortranScalarNumericalType(resultTy) &&
1033 !mlir::isa<fir::LogicalType>(resultTy))
1034 return emitOpError(
1035 "the result must be of scalar numerical or logical type");
1036
1037 return mlir::success();
1038}
1039
1040void hlfir::DotProductOp::getEffects(
1041 llvm::SmallVectorImpl<
1042 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1043 &effects) {
1044 getIntrinsicEffects(getOperation(), effects);
1045}
1046
1047//===----------------------------------------------------------------------===//
1048// MatmulOp
1049//===----------------------------------------------------------------------===//
1050
1051mlir::LogicalResult hlfir::MatmulOp::verify() {
1052 mlir::Value lhs = getLhs();
1053 mlir::Value rhs = getRhs();
1054 fir::SequenceType lhsTy =
1055 hlfir::getFortranElementOrSequenceType(lhs.getType())
1056 .cast<fir::SequenceType>();
1057 fir::SequenceType rhsTy =
1058 hlfir::getFortranElementOrSequenceType(rhs.getType())
1059 .cast<fir::SequenceType>();
1060 llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
1061 llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
1062 std::size_t lhsRank = lhsShape.size();
1063 std::size_t rhsRank = rhsShape.size();
1064 mlir::Type lhsEleTy = lhsTy.getEleTy();
1065 mlir::Type rhsEleTy = rhsTy.getEleTy();
1066 hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
1067 llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
1068 mlir::Type resultEleTy = resultTy.getEleTy();
1069
1070 if (((lhsRank != 1) && (lhsRank != 2)) || ((rhsRank != 1) && (rhsRank != 2)))
1071 return emitOpError("array must have either rank 1 or rank 2");
1072
1073 if ((lhsRank == 1) && (rhsRank == 1))
1074 return emitOpError("at least one array must have rank 2");
1075
1076 if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
1077 mlir::isa<fir::LogicalType>(rhsEleTy))
1078 return emitOpError("if one array is logical, so should the other be");
1079
1080 if (!useStrictIntrinsicVerifier)
1081 return mlir::success();
1082
1083 int64_t lastLhsDim = lhsShape[lhsRank - 1];
1084 int64_t firstRhsDim = rhsShape[0];
1085 constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
1086 if (lastLhsDim != firstRhsDim)
1087 if ((lastLhsDim != unknownExtent) && (firstRhsDim != unknownExtent))
1088 return emitOpError(
1089 "the last dimension of LHS should match the first dimension of RHS");
1090
1091 if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
1092 mlir::isa<fir::LogicalType>(resultEleTy))
1093 return emitOpError("the result type should be a logical only if the "
1094 "argument types are logical");
1095
1096 llvm::SmallVector<int64_t, 2> expectedResultShape;
1097 if (lhsRank == 2) {
1098 if (rhsRank == 2) {
1099 expectedResultShape.push_back(lhsShape[0]);
1100 expectedResultShape.push_back(rhsShape[1]);
1101 } else {
1102 // rhsRank == 1
1103 expectedResultShape.push_back(lhsShape[0]);
1104 }
1105 } else {
1106 // lhsRank == 1
1107 // rhsRank == 2
1108 expectedResultShape.push_back(rhsShape[1]);
1109 }
1110 if (resultShape.size() != expectedResultShape.size())
1111 return emitOpError("incorrect result shape");
1112 if (resultShape[0] != expectedResultShape[0] &&
1113 expectedResultShape[0] != unknownExtent)
1114 return emitOpError("incorrect result shape");
1115 if (resultShape.size() == 2 && resultShape[1] != expectedResultShape[1] &&
1116 expectedResultShape[1] != unknownExtent)
1117 return emitOpError("incorrect result shape");
1118
1119 return mlir::success();
1120}
1121
1122mlir::LogicalResult
1123hlfir::MatmulOp::canonicalize(MatmulOp matmulOp,
1124 mlir::PatternRewriter &rewriter) {
1125 // the only two uses of the transposed matrix should be for the hlfir.matmul
1126 // and hlfir.destory
1127 auto isOtherwiseUnused = [&](hlfir::TransposeOp transposeOp) -> bool {
1128 std::size_t numUses = 0;
1129 for (mlir::Operation *user : transposeOp.getResult().getUsers()) {
1130 ++numUses;
1131 if (user == matmulOp)
1132 continue;
1133 if (mlir::dyn_cast_or_null<hlfir::DestroyOp>(user))
1134 continue;
1135 // some other use!
1136 return false;
1137 }
1138 return numUses <= 2;
1139 };
1140
1141 mlir::Value lhs = matmulOp.getLhs();
1142 // Rewrite MATMUL(TRANSPOSE(lhs), rhs) => hlfir.matmul_transpose lhs, rhs
1143 if (auto transposeOp = lhs.getDefiningOp<hlfir::TransposeOp>()) {
1144 if (isOtherwiseUnused(transposeOp)) {
1145 mlir::Location loc = matmulOp.getLoc();
1146 mlir::Type resultTy = matmulOp.getResult().getType();
1147 auto matmulTransposeOp = rewriter.create<hlfir::MatmulTransposeOp>(
1148 loc, resultTy, transposeOp.getArray(), matmulOp.getRhs());
1149
1150 // we don't need to remove any hlfir.destroy because it will be needed for
1151 // the new intrinsic result anyway
1152 rewriter.replaceOp(matmulOp, matmulTransposeOp.getResult());
1153
1154 // but we do need to get rid of the hlfir.destroy for the hlfir.transpose
1155 // result (which is entirely removed)
1156 llvm::SmallVector<mlir::Operation *> users(
1157 transposeOp->getResult(0).getUsers());
1158 for (mlir::Operation *user : users)
1159 if (auto destroyOp = mlir::dyn_cast_or_null<hlfir::DestroyOp>(user))
1160 rewriter.eraseOp(destroyOp);
1161 rewriter.eraseOp(transposeOp);
1162
1163 return mlir::success();
1164 }
1165 }
1166
1167 return mlir::failure();
1168}
1169
1170void hlfir::MatmulOp::getEffects(
1171 llvm::SmallVectorImpl<
1172 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1173 &effects) {
1174 getIntrinsicEffects(getOperation(), effects);
1175}
1176
1177//===----------------------------------------------------------------------===//
1178// TransposeOp
1179//===----------------------------------------------------------------------===//
1180
1181mlir::LogicalResult hlfir::TransposeOp::verify() {
1182 mlir::Value array = getArray();
1183 fir::SequenceType arrayTy =
1184 hlfir::getFortranElementOrSequenceType(array.getType())
1185 .cast<fir::SequenceType>();
1186 llvm::ArrayRef<int64_t> inShape = arrayTy.getShape();
1187 std::size_t rank = inShape.size();
1188 mlir::Type eleTy = arrayTy.getEleTy();
1189 hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
1190 llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
1191 std::size_t resultRank = resultShape.size();
1192 mlir::Type resultEleTy = resultTy.getEleTy();
1193
1194 if (rank != 2 || resultRank != 2)
1195 return emitOpError("input and output arrays should have rank 2");
1196
1197 if (!useStrictIntrinsicVerifier)
1198 return mlir::success();
1199
1200 constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
1201 if ((inShape[0] != resultShape[1]) && (inShape[0] != unknownExtent))
1202 return emitOpError("output shape does not match input array");
1203 if ((inShape[1] != resultShape[0]) && (inShape[1] != unknownExtent))
1204 return emitOpError("output shape does not match input array");
1205
1206 if (eleTy != resultEleTy)
1207 return emitOpError(
1208 "input and output arrays should have the same element type");
1209
1210 return mlir::success();
1211}
1212
1213void hlfir::TransposeOp::getEffects(
1214 llvm::SmallVectorImpl<
1215 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1216 &effects) {
1217 getIntrinsicEffects(getOperation(), effects);
1218}
1219
1220//===----------------------------------------------------------------------===//
1221// MatmulTransposeOp
1222//===----------------------------------------------------------------------===//
1223
1224mlir::LogicalResult hlfir::MatmulTransposeOp::verify() {
1225 mlir::Value lhs = getLhs();
1226 mlir::Value rhs = getRhs();
1227 fir::SequenceType lhsTy =
1228 hlfir::getFortranElementOrSequenceType(lhs.getType())
1229 .cast<fir::SequenceType>();
1230 fir::SequenceType rhsTy =
1231 hlfir::getFortranElementOrSequenceType(rhs.getType())
1232 .cast<fir::SequenceType>();
1233 llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
1234 llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
1235 std::size_t lhsRank = lhsShape.size();
1236 std::size_t rhsRank = rhsShape.size();
1237 mlir::Type lhsEleTy = lhsTy.getEleTy();
1238 mlir::Type rhsEleTy = rhsTy.getEleTy();
1239 hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
1240 llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
1241 mlir::Type resultEleTy = resultTy.getEleTy();
1242
1243 // lhs must have rank 2 for the transpose to be valid
1244 if ((lhsRank != 2) || ((rhsRank != 1) && (rhsRank != 2)))
1245 return emitOpError("array must have either rank 1 or rank 2");
1246
1247 if (!useStrictIntrinsicVerifier)
1248 return mlir::success();
1249
1250 if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
1251 mlir::isa<fir::LogicalType>(rhsEleTy))
1252 return emitOpError("if one array is logical, so should the other be");
1253
1254 // for matmul we compare the last dimension of lhs with the first dimension of
1255 // rhs, but for MatmulTranspose, dimensions of lhs are inverted by the
1256 // transpose
1257 int64_t firstLhsDim = lhsShape[0];
1258 int64_t firstRhsDim = rhsShape[0];
1259 constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
1260 if (firstLhsDim != firstRhsDim)
1261 if ((firstLhsDim != unknownExtent) && (firstRhsDim != unknownExtent))
1262 return emitOpError(
1263 "the first dimension of LHS should match the first dimension of RHS");
1264
1265 if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
1266 mlir::isa<fir::LogicalType>(resultEleTy))
1267 return emitOpError("the result type should be a logical only if the "
1268 "argument types are logical");
1269
1270 llvm::SmallVector<int64_t, 2> expectedResultShape;
1271 if (rhsRank == 2) {
1272 expectedResultShape.push_back(lhsShape[1]);
1273 expectedResultShape.push_back(rhsShape[1]);
1274 } else {
1275 // rhsRank == 1
1276 expectedResultShape.push_back(lhsShape[1]);
1277 }
1278 if (resultShape.size() != expectedResultShape.size())
1279 return emitOpError("incorrect result shape");
1280 if (resultShape[0] != expectedResultShape[0])
1281 return emitOpError("incorrect result shape");
1282 if (resultShape.size() == 2 && resultShape[1] != expectedResultShape[1])
1283 return emitOpError("incorrect result shape");
1284
1285 return mlir::success();
1286}
1287
1288void hlfir::MatmulTransposeOp::getEffects(
1289 llvm::SmallVectorImpl<
1290 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1291 &effects) {
1292 getIntrinsicEffects(getOperation(), effects);
1293}
1294
1295//===----------------------------------------------------------------------===//
1296// AssociateOp
1297//===----------------------------------------------------------------------===//
1298
1299void hlfir::AssociateOp::build(mlir::OpBuilder &builder,
1300 mlir::OperationState &result, mlir::Value source,
1301 llvm::StringRef uniq_name, mlir::Value shape,
1302 mlir::ValueRange typeparams,
1303 fir::FortranVariableFlagsAttr fortran_attrs) {
1304 auto nameAttr = builder.getStringAttr(uniq_name);
1305 mlir::Type dataType = getFortranElementOrSequenceType(source.getType());
1306
1307 // Preserve polymorphism of polymorphic expr.
1308 mlir::Type firVarType;
1309 auto sourceExprType = mlir::dyn_cast<hlfir::ExprType>(source.getType());
1310 if (sourceExprType && sourceExprType.isPolymorphic())
1311 firVarType = fir::ClassType::get(fir::HeapType::get(dataType));
1312 else
1313 firVarType = fir::ReferenceType::get(dataType);
1314
1315 mlir::Type hlfirVariableType =
1316 DeclareOp::getHLFIRVariableType(firVarType, /*hasExplicitLbs=*/false);
1317 mlir::Type i1Type = builder.getI1Type();
1318 build(builder, result, {hlfirVariableType, firVarType, i1Type}, source, shape,
1319 typeparams, nameAttr, fortran_attrs);
1320}
1321
1322void hlfir::AssociateOp::build(
1323 mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value source,
1324 mlir::Value shape, mlir::ValueRange typeparams,
1325 fir::FortranVariableFlagsAttr fortran_attrs,
1326 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
1327 mlir::Type dataType = getFortranElementOrSequenceType(source.getType());
1328
1329 // Preserve polymorphism of polymorphic expr.
1330 mlir::Type firVarType;
1331 auto sourceExprType = mlir::dyn_cast<hlfir::ExprType>(source.getType());
1332 if (sourceExprType && sourceExprType.isPolymorphic())
1333 firVarType = fir::ClassType::get(fir::HeapType::get(dataType));
1334 else
1335 firVarType = fir::ReferenceType::get(dataType);
1336
1337 mlir::Type hlfirVariableType =
1338 DeclareOp::getHLFIRVariableType(firVarType, /*hasExplicitLbs=*/false);
1339 mlir::Type i1Type = builder.getI1Type();
1340 build(builder, result, {hlfirVariableType, firVarType, i1Type}, source, shape,
1341 typeparams, {}, fortran_attrs);
1342 result.addAttributes(attributes);
1343}
1344
1345//===----------------------------------------------------------------------===//
1346// EndAssociateOp
1347//===----------------------------------------------------------------------===//
1348
1349void hlfir::EndAssociateOp::build(mlir::OpBuilder &builder,
1350 mlir::OperationState &result,
1351 hlfir::AssociateOp associate) {
1352 mlir::Value hlfirBase = associate.getBase();
1353 mlir::Value firBase = associate.getFirBase();
1354 // If EndAssociateOp may need to initiate the deallocation
1355 // of allocatable components, it has to have access to the variable
1356 // definition, so we cannot use the FIR base as the operand.
1357 return build(builder, result,
1358 hlfir::mayHaveAllocatableComponent(hlfirBase.getType())
1359 ? hlfirBase
1360 : firBase,
1361 associate.getMustFreeStrorageFlag());
1362}
1363
1364mlir::LogicalResult hlfir::EndAssociateOp::verify() {
1365 mlir::Value var = getVar();
1366 if (hlfir::mayHaveAllocatableComponent(var.getType()) &&
1367 !hlfir::isFortranEntity(var))
1368 return emitOpError("that requires components deallocation must have var "
1369 "operand that is a Fortran entity");
1370
1371 return mlir::success();
1372}
1373
1374//===----------------------------------------------------------------------===//
1375// AsExprOp
1376//===----------------------------------------------------------------------===//
1377
1378void hlfir::AsExprOp::build(mlir::OpBuilder &builder,
1379 mlir::OperationState &result, mlir::Value var,
1380 mlir::Value mustFree) {
1381 hlfir::ExprType::Shape typeShape;
1382 bool isPolymorphic = fir::isPolymorphicType(var.getType());
1383 mlir::Type type = getFortranElementOrSequenceType(var.getType());
1384 if (auto seqType = type.dyn_cast<fir::SequenceType>()) {
1385 typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
1386 type = seqType.getEleTy();
1387 }
1388
1389 auto resultType = hlfir::ExprType::get(builder.getContext(), typeShape, type,
1390 isPolymorphic);
1391 return build(builder, result, resultType, var, mustFree);
1392}
1393
1394void hlfir::AsExprOp::getEffects(
1395 llvm::SmallVectorImpl<
1396 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1397 &effects) {
1398 // this isn't a transformational intrinsic but follows the same pattern: it
1399 // creates a hlfir.expr and so needs to have an allocation effect, plus it
1400 // might have a pointer-like argument, in which case it has a read effect
1401 // upon those
1402 getIntrinsicEffects(getOperation(), effects);
1403}
1404
1405//===----------------------------------------------------------------------===//
1406// ElementalOp
1407//===----------------------------------------------------------------------===//
1408
1409/// Common builder for ElementalOp and ElementalAddrOp to add the arguments and
1410/// create the elemental body. Result and clean-up body must be handled in
1411/// specific builders.
1412template <typename Op>
1413static void buildElemental(mlir::OpBuilder &builder,
1414 mlir::OperationState &odsState, mlir::Value shape,
1415 mlir::Value mold, mlir::ValueRange typeparams,
1416 bool isUnordered) {
1417 odsState.addOperands(shape);
1418 if (mold)
1419 odsState.addOperands(mold);
1420 odsState.addOperands(typeparams);
1421 odsState.addAttribute(
1422 Op::getOperandSegmentSizesAttrName(odsState.name),
1423 builder.getDenseI32ArrayAttr({/*shape=*/1, (mold ? 1 : 0),
1424 static_cast<int32_t>(typeparams.size())}));
1425 if (isUnordered)
1426 odsState.addAttribute(Op::getUnorderedAttrName(odsState.name),
1427 isUnordered ? builder.getUnitAttr() : nullptr);
1428 mlir::Region *bodyRegion = odsState.addRegion();
1429 bodyRegion->push_back(new mlir::Block{});
1430 if (auto shapeType = shape.getType().dyn_cast<fir::ShapeType>()) {
1431 unsigned dim = shapeType.getRank();
1432 mlir::Type indexType = builder.getIndexType();
1433 for (unsigned d = 0; d < dim; ++d)
1434 bodyRegion->front().addArgument(indexType, odsState.location);
1435 }
1436}
1437
1438void hlfir::ElementalOp::build(mlir::OpBuilder &builder,
1439 mlir::OperationState &odsState,
1440 mlir::Type resultType, mlir::Value shape,
1441 mlir::Value mold, mlir::ValueRange typeparams,
1442 bool isUnordered) {
1443 odsState.addTypes(resultType);
1444 buildElemental<hlfir::ElementalOp>(builder, odsState, shape, mold, typeparams,
1445 isUnordered);
1446}
1447
1448mlir::Value hlfir::ElementalOp::getElementEntity() {
1449 return mlir::cast<hlfir::YieldElementOp>(getBody()->back()).getElementValue();
1450}
1451
1452mlir::LogicalResult hlfir::ElementalOp::verify() {
1453 mlir::Value mold = getMold();
1454 hlfir::ExprType resultType = mlir::cast<hlfir::ExprType>(getType());
1455 if (!!mold != resultType.isPolymorphic())
1456 return emitOpError("result must be polymorphic when mold is present "
1457 "and vice versa");
1458
1459 return mlir::success();
1460}
1461
1462//===----------------------------------------------------------------------===//
1463// ApplyOp
1464//===----------------------------------------------------------------------===//
1465
1466void hlfir::ApplyOp::build(mlir::OpBuilder &builder,
1467 mlir::OperationState &odsState, mlir::Value expr,
1468 mlir::ValueRange indices,
1469 mlir::ValueRange typeparams) {
1470 mlir::Type resultType = expr.getType();
1471 if (auto exprType = resultType.dyn_cast<hlfir::ExprType>())
1472 resultType = exprType.getElementExprType();
1473 build(builder, odsState, resultType, expr, indices, typeparams);
1474}
1475
1476//===----------------------------------------------------------------------===//
1477// NullOp
1478//===----------------------------------------------------------------------===//
1479
1480void hlfir::NullOp::build(mlir::OpBuilder &builder,
1481 mlir::OperationState &odsState) {
1482 return build(builder, odsState,
1483 fir::ReferenceType::get(builder.getNoneType()));
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// DestroyOp
1488//===----------------------------------------------------------------------===//
1489
1490mlir::LogicalResult hlfir::DestroyOp::verify() {
1491 if (mustFinalizeExpr()) {
1492 mlir::Value expr = getExpr();
1493 hlfir::ExprType exprTy = mlir::cast<hlfir::ExprType>(expr.getType());
1494 mlir::Type elemTy = hlfir::getFortranElementType(exprTy);
1495 if (!mlir::isa<fir::RecordType>(elemTy))
1496 return emitOpError(
1497 "the element type must be finalizable, when 'finalize' is set");
1498 }
1499
1500 return mlir::success();
1501}
1502
1503//===----------------------------------------------------------------------===//
1504// CopyInOp
1505//===----------------------------------------------------------------------===//
1506
1507void hlfir::CopyInOp::build(mlir::OpBuilder &builder,
1508 mlir::OperationState &odsState, mlir::Value var,
1509 mlir::Value var_is_present) {
1510 return build(builder, odsState, {var.getType(), builder.getI1Type()}, var,
1511 var_is_present);
1512}
1513
1514//===----------------------------------------------------------------------===//
1515// ShapeOfOp
1516//===----------------------------------------------------------------------===//
1517
1518void hlfir::ShapeOfOp::build(mlir::OpBuilder &builder,
1519 mlir::OperationState &result, mlir::Value expr) {
1520 hlfir::ExprType exprTy = expr.getType().cast<hlfir::ExprType>();
1521 mlir::Type type = fir::ShapeType::get(builder.getContext(), exprTy.getRank());
1522 build(builder, result, type, expr);
1523}
1524
1525std::size_t hlfir::ShapeOfOp::getRank() {
1526 mlir::Type resTy = getResult().getType();
1527 fir::ShapeType shape = resTy.cast<fir::ShapeType>();
1528 return shape.getRank();
1529}
1530
1531mlir::LogicalResult hlfir::ShapeOfOp::verify() {
1532 mlir::Value expr = getExpr();
1533 hlfir::ExprType exprTy = expr.getType().cast<hlfir::ExprType>();
1534 std::size_t exprRank = exprTy.getShape().size();
1535
1536 if (exprRank == 0)
1537 return emitOpError("cannot get the shape of a shape-less expression");
1538
1539 std::size_t shapeRank = getRank();
1540 if (shapeRank != exprRank)
1541 return emitOpError("result rank and expr rank do not match");
1542
1543 return mlir::success();
1544}
1545
1546mlir::LogicalResult
1547hlfir::ShapeOfOp::canonicalize(ShapeOfOp shapeOf,
1548 mlir::PatternRewriter &rewriter) {
1549 // if extent information is available at compile time, immediately fold the
1550 // hlfir.shape_of into a fir.shape
1551 mlir::Location loc = shapeOf.getLoc();
1552 hlfir::ExprType expr = shapeOf.getExpr().getType().cast<hlfir::ExprType>();
1553
1554 mlir::Value shape = hlfir::genExprShape(rewriter, loc, expr);
1555 if (!shape)
1556 // shape information is not available at compile time
1557 return mlir::LogicalResult::failure();
1558
1559 rewriter.replaceAllUsesWith(shapeOf.getResult(), shape);
1560 rewriter.eraseOp(shapeOf);
1561 return mlir::LogicalResult::success();
1562}
1563
1564//===----------------------------------------------------------------------===//
1565// GetExtent
1566//===----------------------------------------------------------------------===//
1567
1568void hlfir::GetExtentOp::build(mlir::OpBuilder &builder,
1569 mlir::OperationState &result, mlir::Value shape,
1570 unsigned dim) {
1571 mlir::Type indexTy = builder.getIndexType();
1572 mlir::IntegerAttr dimAttr = mlir::IntegerAttr::get(indexTy, dim);
1573 build(builder, result, indexTy, shape, dimAttr);
1574}
1575
1576mlir::LogicalResult hlfir::GetExtentOp::verify() {
1577 fir::ShapeType shapeTy = getShape().getType().cast<fir::ShapeType>();
1578 std::uint64_t rank = shapeTy.getRank();
1579 llvm::APInt dim = getDim();
1580 if (dim.sge(rank))
1581 return emitOpError("dimension index out of bounds");
1582 return mlir::success();
1583}
1584
1585//===----------------------------------------------------------------------===//
1586// RegionAssignOp
1587//===----------------------------------------------------------------------===//
1588
1589/// Add a fir.end terminator to a parsed region if it does not already has a
1590/// terminator.
1591static void ensureTerminator(mlir::Region &region, mlir::Builder &builder,
1592 mlir::Location loc) {
1593 // Borrow YielOp::ensureTerminator MLIR generated implementation to add a
1594 // fir.end if there is no terminator. This has nothing to do with YielOp,
1595 // other than the fact that yieldOp has the
1596 // SingleBlocklicitTerminator<"fir::FirEndOp"> interface that
1597 // cannot be added on other HLFIR operations with several regions which are
1598 // not all terminated the same way.
1599 hlfir::YieldOp::ensureTerminator(region, builder, loc);
1600}
1601
1602mlir::ParseResult hlfir::RegionAssignOp::parse(mlir::OpAsmParser &parser,
1603 mlir::OperationState &result) {
1604 mlir::Region &rhsRegion = *result.addRegion();
1605 if (parser.parseRegion(rhsRegion))
1606 return mlir::failure();
1607 mlir::Region &lhsRegion = *result.addRegion();
1608 if (parser.parseKeyword("to") || parser.parseRegion(lhsRegion))
1609 return mlir::failure();
1610 mlir::Region &userDefinedAssignmentRegion = *result.addRegion();
1611 if (succeeded(parser.parseOptionalKeyword("user_defined_assign"))) {
1612 mlir::OpAsmParser::Argument rhsArg, lhsArg;
1613 if (parser.parseLParen() || parser.parseArgument(rhsArg) ||
1614 parser.parseColon() || parser.parseType(rhsArg.type) ||
1615 parser.parseRParen() || parser.parseKeyword("to") ||
1616 parser.parseLParen() || parser.parseArgument(lhsArg) ||
1617 parser.parseColon() || parser.parseType(lhsArg.type) ||
1618 parser.parseRParen())
1619 return mlir::failure();
1620 if (parser.parseRegion(userDefinedAssignmentRegion, {rhsArg, lhsArg}))
1621 return mlir::failure();
1622 ensureTerminator(userDefinedAssignmentRegion, parser.getBuilder(),
1623 result.location);
1624 }
1625 return mlir::success();
1626}
1627
1628void hlfir::RegionAssignOp::print(mlir::OpAsmPrinter &p) {
1629 p << " ";
1630 p.printRegion(getRhsRegion(), /*printEntryBlockArgs=*/false,
1631 /*printBlockTerminators=*/true);
1632 p << " to ";
1633 p.printRegion(getLhsRegion(), /*printEntryBlockArgs=*/false,
1634 /*printBlockTerminators=*/true);
1635 if (!getUserDefinedAssignment().empty()) {
1636 p << " user_defined_assign ";
1637 mlir::Value userAssignmentRhs = getUserAssignmentRhs();
1638 mlir::Value userAssignmentLhs = getUserAssignmentLhs();
1639 p << " (" << userAssignmentRhs << ": " << userAssignmentRhs.getType()
1640 << ") to (";
1641 p << userAssignmentLhs << ": " << userAssignmentLhs.getType() << ") ";
1642 p.printRegion(getUserDefinedAssignment(), /*printEntryBlockArgs=*/false,
1643 /*printBlockTerminators=*/false);
1644 }
1645}
1646
1647static mlir::Operation *getTerminator(mlir::Region &region) {
1648 if (region.empty() || region.back().empty())
1649 return nullptr;
1650 return &region.back().back();
1651}
1652
1653mlir::LogicalResult hlfir::RegionAssignOp::verify() {
1654 if (!mlir::isa_and_nonnull<hlfir::YieldOp>(getTerminator(getRhsRegion())))
1655 return emitOpError(
1656 "right-hand side region must be terminated by an hlfir.yield");
1657 if (!mlir::isa_and_nonnull<hlfir::YieldOp, hlfir::ElementalAddrOp>(
1658 getTerminator(getLhsRegion())))
1659 return emitOpError("left-hand side region must be terminated by an "
1660 "hlfir.yield or hlfir.elemental_addr");
1661 return mlir::success();
1662}
1663
1664//===----------------------------------------------------------------------===//
1665// YieldOp
1666//===----------------------------------------------------------------------===//
1667
1668static mlir::ParseResult parseYieldOpCleanup(mlir::OpAsmParser &parser,
1669 mlir::Region &cleanup) {
1670 if (succeeded(parser.parseOptionalKeyword("cleanup"))) {
1671 if (parser.parseRegion(cleanup, /*arguments=*/{},
1672 /*argTypes=*/{}))
1673 return mlir::failure();
1674 hlfir::YieldOp::ensureTerminator(cleanup, parser.getBuilder(),
1675 parser.getBuilder().getUnknownLoc());
1676 }
1677 return mlir::success();
1678}
1679
1680template <typename YieldOp>
1681static void printYieldOpCleanup(mlir::OpAsmPrinter &p, YieldOp yieldOp,
1682 mlir::Region &cleanup) {
1683 if (!cleanup.empty()) {
1684 p << "cleanup ";
1685 p.printRegion(cleanup, /*printEntryBlockArgs=*/false,
1686 /*printBlockTerminators=*/false);
1687 }
1688}
1689
1690//===----------------------------------------------------------------------===//
1691// ElementalAddrOp
1692//===----------------------------------------------------------------------===//
1693
1694void hlfir::ElementalAddrOp::build(mlir::OpBuilder &builder,
1695 mlir::OperationState &odsState,
1696 mlir::Value shape, mlir::Value mold,
1697 mlir::ValueRange typeparams,
1698 bool isUnordered) {
1699 buildElemental<hlfir::ElementalAddrOp>(builder, odsState, shape, mold,
1700 typeparams, isUnordered);
1701 // Push cleanUp region.
1702 odsState.addRegion();
1703}
1704
1705mlir::LogicalResult hlfir::ElementalAddrOp::verify() {
1706 hlfir::YieldOp yieldOp =
1707 mlir::dyn_cast_or_null<hlfir::YieldOp>(getTerminator(getBody()));
1708 if (!yieldOp)
1709 return emitOpError("body region must be terminated by an hlfir.yield");
1710 mlir::Type elementAddrType = yieldOp.getEntity().getType();
1711 if (!hlfir::isFortranVariableType(elementAddrType) ||
1712 hlfir::getFortranElementOrSequenceType(elementAddrType)
1713 .isa<fir::SequenceType>())
1714 return emitOpError("body must compute the address of a scalar entity");
1715 unsigned shapeRank = getShape().getType().cast<fir::ShapeType>().getRank();
1716 if (shapeRank != getIndices().size())
1717 return emitOpError("body number of indices must match shape rank");
1718 return mlir::success();
1719}
1720
1721hlfir::YieldOp hlfir::ElementalAddrOp::getYieldOp() {
1722 hlfir::YieldOp yieldOp =
1723 mlir::dyn_cast_or_null<hlfir::YieldOp>(getTerminator(getBody()));
1724 assert(yieldOp && "element_addr is ill-formed");
1725 return yieldOp;
1726}
1727
1728mlir::Value hlfir::ElementalAddrOp::getElementEntity() {
1729 return getYieldOp().getEntity();
1730}
1731
1732mlir::Region *hlfir::ElementalAddrOp::getElementCleanup() {
1733 mlir::Region *cleanup = &getYieldOp().getCleanup();
1734 return cleanup->empty() ? nullptr : cleanup;
1735}
1736
1737//===----------------------------------------------------------------------===//
1738// OrderedAssignmentTreeOpInterface
1739//===----------------------------------------------------------------------===//
1740
1741mlir::LogicalResult hlfir::OrderedAssignmentTreeOpInterface::verifyImpl() {
1742 if (mlir::Region *body = getSubTreeRegion())
1743 if (!body->empty())
1744 for (mlir::Operation &op : body->front())
1745 if (!mlir::isa<hlfir::OrderedAssignmentTreeOpInterface, fir::FirEndOp>(
1746 op))
1747 return emitOpError(
1748 "body region must only contain OrderedAssignmentTreeOpInterface "
1749 "operations or fir.end");
1750 return mlir::success();
1751}
1752
1753//===----------------------------------------------------------------------===//
1754// ForallOp
1755//===----------------------------------------------------------------------===//
1756
1757static mlir::ParseResult parseForallOpBody(mlir::OpAsmParser &parser,
1758 mlir::Region &body) {
1759 mlir::OpAsmParser::Argument bodyArg;
1760 if (parser.parseLParen() || parser.parseArgument(bodyArg) ||
1761 parser.parseColon() || parser.parseType(bodyArg.type) ||
1762 parser.parseRParen())
1763 return mlir::failure();
1764 if (parser.parseRegion(body, {bodyArg}))
1765 return mlir::failure();
1766 ensureTerminator(body, parser.getBuilder(),
1767 parser.getBuilder().getUnknownLoc());
1768 return mlir::success();
1769}
1770
1771static void printForallOpBody(mlir::OpAsmPrinter &p, hlfir::ForallOp forall,
1772 mlir::Region &body) {
1773 mlir::Value forallIndex = forall.getForallIndexValue();
1774 p << " (" << forallIndex << ": " << forallIndex.getType() << ") ";
1775 p.printRegion(body, /*printEntryBlockArgs=*/false,
1776 /*printBlockTerminators=*/false);
1777}
1778
1779/// Predicate implementation of YieldIntegerOrEmpty.
1780static bool yieldsIntegerOrEmpty(mlir::Region &region) {
1781 if (region.empty())
1782 return true;
1783 auto yield = mlir::dyn_cast_or_null<hlfir::YieldOp>(getTerminator(region));
1784 return yield && fir::isa_integer(yield.getEntity().getType());
1785}
1786
1787//===----------------------------------------------------------------------===//
1788// ForallMaskOp
1789//===----------------------------------------------------------------------===//
1790
1791static mlir::ParseResult parseAssignmentMaskOpBody(mlir::OpAsmParser &parser,
1792 mlir::Region &body) {
1793 if (parser.parseRegion(body))
1794 return mlir::failure();
1795 ensureTerminator(body, parser.getBuilder(),
1796 parser.getBuilder().getUnknownLoc());
1797 return mlir::success();
1798}
1799
1800template <typename ConcreteOp>
1801static void printAssignmentMaskOpBody(mlir::OpAsmPrinter &p, ConcreteOp,
1802 mlir::Region &body) {
1803 // ElseWhereOp is a WhereOp/ElseWhereOp terminator that should be printed.
1804 bool printBlockTerminators =
1805 !body.empty() &&
1806 mlir::isa_and_nonnull<hlfir::ElseWhereOp>(body.back().getTerminator());
1807 p.printRegion(body, /*printEntryBlockArgs=*/false, printBlockTerminators);
1808}
1809
1810static bool yieldsLogical(mlir::Region &region, bool mustBeScalarI1) {
1811 if (region.empty())
1812 return false;
1813 auto yield = mlir::dyn_cast_or_null<hlfir::YieldOp>(getTerminator(region));
1814 if (!yield)
1815 return false;
1816 mlir::Type yieldType = yield.getEntity().getType();
1817 if (mustBeScalarI1)
1818 return hlfir::isI1Type(yieldType);
1819 return hlfir::isMaskArgument(yieldType) &&
1820 hlfir::getFortranElementOrSequenceType(yieldType)
1821 .isa<fir::SequenceType>();
1822}
1823
1824mlir::LogicalResult hlfir::ForallMaskOp::verify() {
1825 if (!yieldsLogical(getMaskRegion(), /*mustBeScalarI1=*/true))
1826 return emitOpError("mask region must yield a scalar i1");
1827 mlir::Operation *op = getOperation();
1828 hlfir::ForallOp forallOp =
1829 mlir::dyn_cast_or_null<hlfir::ForallOp>(op->getParentOp());
1830 if (!forallOp || op->getParentRegion() != &forallOp.getBody())
1831 return emitOpError("must be inside the body region of an hlfir.forall");
1832 return mlir::success();
1833}
1834
1835//===----------------------------------------------------------------------===//
1836// WhereOp and ElseWhereOp
1837//===----------------------------------------------------------------------===//
1838
1839template <typename ConcreteOp>
1840static mlir::LogicalResult verifyWhereAndElseWhereBody(ConcreteOp &concreteOp) {
1841 for (mlir::Operation &op : concreteOp.getBody().front())
1842 if (mlir::isa<hlfir::ForallOp>(op))
1843 return concreteOp.emitOpError(
1844 "body region must not contain hlfir.forall");
1845 return mlir::success();
1846}
1847
1848mlir::LogicalResult hlfir::WhereOp::verify() {
1849 if (!yieldsLogical(getMaskRegion(), /*mustBeScalarI1=*/false))
1850 return emitOpError("mask region must yield a logical array");
1851 return verifyWhereAndElseWhereBody(*this);
1852}
1853
1854mlir::LogicalResult hlfir::ElseWhereOp::verify() {
1855 if (!getMaskRegion().empty())
1856 if (!yieldsLogical(getMaskRegion(), /*mustBeScalarI1=*/false))
1857 return emitOpError(
1858 "mask region must yield a logical array when provided");
1859 return verifyWhereAndElseWhereBody(*this);
1860}
1861
1862//===----------------------------------------------------------------------===//
1863// ForallIndexOp
1864//===----------------------------------------------------------------------===//
1865
1866mlir::LogicalResult
1867hlfir::ForallIndexOp::canonicalize(hlfir::ForallIndexOp indexOp,
1868 mlir::PatternRewriter &rewriter) {
1869 for (mlir::Operation *user : indexOp->getResult(0).getUsers())
1870 if (!mlir::isa<fir::LoadOp>(user))
1871 return mlir::failure();
1872
1873 auto insertPt = rewriter.saveInsertionPoint();
1874 llvm::SmallVector<mlir::Operation *> users(indexOp->getResult(0).getUsers());
1875 for (mlir::Operation *user : users)
1876 if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(user)) {
1877 rewriter.setInsertionPoint(loadOp);
1878 rewriter.replaceOpWithNewOp<fir::ConvertOp>(
1879 user, loadOp.getResult().getType(), indexOp.getIndex());
1880 }
1881 rewriter.restoreInsertionPoint(insertPt);
1882 rewriter.eraseOp(indexOp);
1883 return mlir::success();
1884}
1885
1886//===----------------------------------------------------------------------===//
1887// CharExtremumOp
1888//===----------------------------------------------------------------------===//
1889
1890mlir::LogicalResult hlfir::CharExtremumOp::verify() {
1891 if (getStrings().size() < 2)
1892 return emitOpError("must be provided at least two string operands");
1893 unsigned kind = getCharacterKind(getResult().getType());
1894 for (auto string : getStrings())
1895 if (kind != getCharacterKind(string.getType()))
1896 return emitOpError("strings must have the same KIND as the result type");
1897 return mlir::success();
1898}
1899
1900void hlfir::CharExtremumOp::build(mlir::OpBuilder &builder,
1901 mlir::OperationState &result,
1902 hlfir::CharExtremumPredicate predicate,
1903 mlir::ValueRange strings) {
1904
1905 fir::CharacterType::LenType resultTypeLen = 0;
1906 assert(!strings.empty() && "must contain operands");
1907 unsigned kind = getCharacterKind(strings[0].getType());
1908 for (auto string : strings)
1909 if (auto cstLen = getCharacterLengthIfStatic(string.getType())) {
1910 resultTypeLen = std::max(resultTypeLen, *cstLen);
1911 } else {
1912 resultTypeLen = fir::CharacterType::unknownLen();
1913 break;
1914 }
1915 auto resultType = hlfir::ExprType::get(
1916 builder.getContext(), hlfir::ExprType::Shape{},
1917 fir::CharacterType::get(builder.getContext(), kind, resultTypeLen),
1918 false);
1919
1920 build(builder, result, resultType, predicate, strings);
1921}
1922
1923void hlfir::CharExtremumOp::getEffects(
1924 llvm::SmallVectorImpl<
1925 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1926 &effects) {
1927 getIntrinsicEffects(getOperation(), effects);
1928}
1929
1930//===----------------------------------------------------------------------===//
1931// GetLength
1932//===----------------------------------------------------------------------===//
1933
1934mlir::LogicalResult
1935hlfir::GetLengthOp::canonicalize(GetLengthOp getLength,
1936 mlir::PatternRewriter &rewriter) {
1937 mlir::Location loc = getLength.getLoc();
1938 auto exprTy = mlir::cast<hlfir::ExprType>(getLength.getExpr().getType());
1939 auto charTy = mlir::cast<fir::CharacterType>(exprTy.getElementType());
1940 if (!charTy.hasConstantLen())
1941 return mlir::failure();
1942
1943 mlir::Type indexTy = rewriter.getIndexType();
1944 auto cstLen = rewriter.create<mlir::arith::ConstantOp>(
1945 loc, indexTy, mlir::IntegerAttr::get(indexTy, charTy.getLen()));
1946 rewriter.replaceOp(getLength, cstLen);
1947 return mlir::success();
1948}
1949
1950#include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
1951#define GET_OP_CLASSES
1952#include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"
1953#include "flang/Optimizer/HLFIR/HLFIROps.cpp.inc"
1954

source code of flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp